kernel.rs
1 use num_traits::Num; 2 3 pub enum Assert<const CHECK: bool> {} 4 pub trait IsTrue {} 5 impl IsTrue for Assert<true> {} 6 7 #[derive(Clone, Debug)] 8 pub struct Kernel<'a, const N: usize, T: Num + Copy> { 9 kernel: &'a [T; N], 10 } 11 12 pub struct KernelIntoIter<'a, T: Num + Copy> { 13 iter: std::slice::Iter<'a, T>, 14 } 15 16 impl<T: Num + Copy> Iterator for KernelIntoIter<'_, T> { 17 type Item = T; 18 fn next(&mut self) -> Option<Self::Item> { 19 // access fields of a tuple struct numerically 20 self.iter.next().copied() 21 } 22 } 23 24 impl<'a, const N: usize, T: Num + Copy> IntoIterator for Kernel<'a, N, T> { 25 type Item = T; 26 type IntoIter = KernelIntoIter<'a, T>; 27 28 fn into_iter(self) -> Self::IntoIter { 29 KernelIntoIter { 30 iter: self.kernel.iter(), 31 } 32 } 33 } 34 35 pub struct KernelIter<'a, const N: usize, T: Num + Copy> { 36 iter: std::slice::Iter<'a, T>, 37 } 38 39 impl<'a, const N: usize, T: Num + Copy> Iterator for KernelIter<'a, N, T> { 40 type Item = &'a T; 41 42 fn next(&mut self) -> Option<Self::Item> { 43 self.iter.next() 44 } 45 } 46 47 impl<'a, const N: usize, T: Num + Copy + std::fmt::Debug> Kernel<'a, N, T> 48 where 49 Assert<{ N > 0 }>: IsTrue, 50 { 51 pub fn new(kernel: &'a [T; N]) -> Self { 52 Kernel { kernel } 53 } 54 pub fn transition(&mut self, values: &[T; N]) -> [T; N] { 55 // TODO Vectorize 56 println!("{:?} {:?}", self.kernel, self.kernel.iter()); 57 let res: Vec<_> = self 58 .kernel 59 .iter() 60 .zip(values.iter()) 61 .map(|(&a, &b)| a * b) 62 .collect(); 63 // The unwrap is perfectly safe here. 64 let res: [T; N] = res[0..N].try_into().unwrap(); 65 res 66 } 67 pub fn iter(&self) -> KernelIter<'_, N, T> { 68 KernelIter { 69 iter: self.kernel.iter(), 70 } 71 } 72 } 73 74 #[cfg(test)] 75 mod test { 76 use super::*; 77 use quickcheck_macros::quickcheck; 78 79 fn type_of<T>(_: T) -> String { 80 std::any::type_name::<T>().to_string() 81 } 82 #[quickcheck] 83 fn identity(data: Vec<u16>) -> bool { 84 let mut kernel = Kernel::new(&[1, 1, 1, 1, 1]); 85 let mut res = true; 86 for window in data.windows(5) { 87 let window: &[u16; 5] = window.try_into().unwrap(); 88 let out = kernel.transition(window); 89 let expected = window; 90 res = res && &out == expected 91 } 92 res 93 } 94 95 #[quickcheck] 96 fn sma(data: Vec<u16>) -> bool { 97 let mut kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]); 98 let mut res = true; 99 for window in data.windows(5) { 100 let window: Vec<_> = window.iter().map(|&e| e as u32).collect(); 101 let window: &[u32; 5] = &window.try_into().unwrap(); 102 let out = kernel.transition(window); 103 let sum = out.iter().sum::<u32>(); 104 let expected = window 105 .iter() 106 .zip(kernel.iter()) 107 .fold(0, |acc, (x, y)| acc + x * y); 108 println!("{} {}", expected, sum); 109 res = res && sum == expected 110 } 111 res 112 } 113 114 #[test] 115 fn into_iter() { 116 let kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]); 117 let mut res = vec![]; 118 for x in kernel.into_iter() { 119 res.push(x); 120 assert_eq!(type_of(x), "u32"); 121 } 122 123 assert_eq!(res, vec![1, 2, 3, 4, 5]) 124 } 125 126 #[test] 127 fn iter() { 128 let kernel: Kernel<'_, 5, u32> = Kernel::new(&[1, 2, 3, 4, 5]); 129 let mut res = vec![]; 130 for x in kernel.iter() { 131 res.push(x); 132 assert_eq!(type_of(x), "&u32"); 133 } 134 } 135 }