/ src / kernel.rs
kernel.rs
  1  use num_traits::Num;
  2  
  3  pub enum Assert<const CHECK: bool> {}
  4  pub trait IsTrue {}
  5  impl IsTrue for Assert<true> {}
  6  
  7  #[derive(Clone, Debug)]
  8  pub struct Kernel<'a, const N: usize, T: Num + Copy> {
  9      kernel: &'a [T; N],
 10  }
 11  
 12  pub struct KernelIntoIter<'a, T: Num + Copy> {
 13      iter: std::slice::Iter<'a, T>,
 14  }
 15  
 16  impl<T: Num + Copy> Iterator for KernelIntoIter<'_, T> {
 17      type Item = T;
 18      fn next(&mut self) -> Option<Self::Item> {
 19          // access fields of a tuple struct numerically
 20          self.iter.next().copied()
 21      }
 22  }
 23  
 24  impl<'a, const N: usize, T: Num + Copy> IntoIterator for Kernel<'a, N, T> {
 25      type Item = T;
 26      type IntoIter = KernelIntoIter<'a, T>;
 27  
 28      fn into_iter(self) -> Self::IntoIter {
 29          KernelIntoIter {
 30              iter: self.kernel.iter(),
 31          }
 32      }
 33  }
 34  
 35  pub struct KernelIter<'a, const N: usize, T: Num + Copy> {
 36      iter: std::slice::Iter<'a, T>,
 37  }
 38  
 39  impl<'a, const N: usize, T: Num + Copy> Iterator for KernelIter<'a, N, T> {
 40      type Item = &'a T;
 41  
 42      fn next(&mut self) -> Option<Self::Item> {
 43          self.iter.next()
 44      }
 45  }
 46  
 47  impl<'a, const N: usize, T: Num + Copy + std::fmt::Debug> Kernel<'a, N, T>
 48  where
 49      Assert<{ N > 0 }>: IsTrue,
 50  {
 51      pub fn new(kernel: &'a [T; N]) -> Self {
 52          Kernel { kernel }
 53      }
 54      pub fn transition(&mut self, values: &[T; N]) -> [T; N] {
 55          // TODO Vectorize
 56          println!("{:?} {:?}", self.kernel, self.kernel.iter());
 57          let res: Vec<_> = self
 58              .kernel
 59              .iter()
 60              .zip(values.iter())
 61              .map(|(&a, &b)| a * b)
 62              .collect();
 63          // The unwrap is perfectly safe here.
 64          let res: [T; N] = res[0..N].try_into().unwrap();
 65          res
 66      }
 67      pub fn iter(&self) -> KernelIter<'_, N, T> {
 68          KernelIter {
 69              iter: self.kernel.iter(),
 70          }
 71      }
 72  }
 73  
 74  #[cfg(test)]
 75  mod test {
 76      use super::*;
 77      use quickcheck_macros::quickcheck;
 78  
 79      fn type_of<T>(_: T) -> String {
 80          std::any::type_name::<T>().to_string()
 81      }
 82      #[quickcheck]
 83      fn identity(data: Vec<u16>) -> bool {
 84          let mut kernel = Kernel::new(&[1, 1, 1, 1, 1]);
 85          let mut res = true;
 86          for window in data.windows(5) {
 87              let window: &[u16; 5] = window.try_into().unwrap();
 88              let out = kernel.transition(window);
 89              let expected = window;
 90              res = res && &out == expected
 91          }
 92          res
 93      }
 94  
 95      #[quickcheck]
 96      fn sma(data: Vec<u16>) -> bool {
 97          let mut kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]);
 98          let mut res = true;
 99          for window in data.windows(5) {
100              let window: Vec<_> = window.iter().map(|&e| e as u32).collect();
101              let window: &[u32; 5] = &window.try_into().unwrap();
102              let out = kernel.transition(window);
103              let sum = out.iter().sum::<u32>();
104              let expected = window
105                  .iter()
106                  .zip(kernel.iter())
107                  .fold(0, |acc, (x, y)| acc + x * y);
108              println!("{} {}", expected, sum);
109              res = res && sum == expected
110          }
111          res
112      }
113  
114      #[test]
115      fn into_iter() {
116          let kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]);
117          let mut res = vec![];
118          for x in kernel.into_iter() {
119              res.push(x);
120              assert_eq!(type_of(x), "u32");
121          }
122  
123          assert_eq!(res, vec![1, 2, 3, 4, 5])
124      }
125  
126      #[test]
127      fn iter() {
128          let kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]);
129          let mut res = vec![];
130          for x in kernel.iter() {
131              res.push(x);
132              assert_eq!(type_of(x), "&u32");
133          }
134      }
135  }