/ reference / haskell / src / Poseidon2 / Permutation.hs
Permutation.hs
 1  
-- | The Poseidon2 permutation for a state of three elements of the BN128
-- scalar field 'Fr', using the x^5 S-box.
 3  
 4  module Poseidon2.Permutation where
 5  
 6  --------------------------------------------------------------------------------
 7  
import Data.List (foldl')

import ZK.Algebra.Curves.BN128.Fr.Mont (Fr)

import Poseidon2.RoundConsts
11  
12  --------------------------------------------------------------------------------
13  
-- | The Poseidon2 S-box: raise a field element to the fifth power.
--
-- Uses two squarings followed by one multiplication,
-- i.e. @x^5 = (x^2)^2 * x@.
sbox :: Fr -> Fr
sbox x =
  let sq   = x * x
      quad = sq * sq
  in  quad * x
18  
-- | One internal (partial) round on a width-3 state.
--
-- The round constant is added to the first lane only, and only that lane
-- goes through the S-box. The state is then multiplied by
--
-- > [ 2 1 1 ]
-- > [ 1 2 1 ]
-- > [ 1 1 3 ]
internalRound :: Fr -> (Fr,Fr,Fr) -> (Fr,Fr,Fr)
internalRound c (x,y,z) =
  let u = sbox (x + c)   -- the only non-linear lane in this round
  in  ( 2*u +   y +   z
      ,   u + 2*y +   z
      ,   u +   y + 3*z
      )
27  
-- | One external (full) round on a width-3 state.
--
-- Each lane gets its own round constant and goes through the S-box. The
-- state is then multiplied by the circulant matrix @circ(2,1,1)@, i.e.
-- every lane has the sum of all three lanes added to it.
externalRound :: (Fr,Fr,Fr) -> (Fr,Fr,Fr) -> (Fr,Fr,Fr)
externalRound (cx,cy,cz) (x,y,z) =
  let a = sbox (x + cx)
      b = sbox (y + cy)
      d = sbox (z + cz)
      t = a + b + d
  in  (a + t, b + t, d + t)
34  
-- | Multiply a width-3 state by the circulant matrix @circ(2,1,1)@,
-- i.e. add the sum of all three lanes to each lane.
linearLayer :: (Fr,Fr,Fr) -> (Fr,Fr,Fr)
linearLayer (x,y,z) =
  let total = x + y + z
  in  (x + total, y + total, z + total)
37  
38  --------------------------------------------------------------------------------
39  
-- | The full Poseidon2 permutation on a width-3 state.
--
-- Reading the composition right to left, it applies:
--
--   1. the initial linear layer ('linearLayer'),
--   2. one 'externalRound' per entry of 'initialRoundConsts',
--   3. one 'internalRound' per entry of 'internalRoundConsts',
--   4. one 'externalRound' per entry of 'finalRoundConsts'.
--
-- Rounds are run with a strict left fold, and every intermediate state is
-- forced lane by lane. This keeps a chain of unevaluated field operations
-- from building up across rounds, as it would with a lazy 'foldl'.
permutation :: (Fr,Fr,Fr) -> (Fr,Fr,Fr)
permutation
  = runRounds externalRound finalRoundConsts
  . runRounds internalRound internalRoundConsts
  . runRounds externalRound initialRoundConsts
  . linearLayer
  where
    -- Apply @step c@ for each round constant @c@ in order, forcing the
    -- state after every round.
    runRounds
      :: Foldable f
      => (c -> (Fr,Fr,Fr) -> (Fr,Fr,Fr))
      -> f c
      -> (Fr,Fr,Fr)
      -> (Fr,Fr,Fr)
    runRounds step consts state0 =
      foldl' (\state c -> forceState (step c state)) state0 consts

    -- Evaluate all three lanes to WHNF, then return the state unchanged.
    -- ('foldl'' alone would only force the tuple constructor.)
    forceState :: (Fr,Fr,Fr) -> (Fr,Fr,Fr)
    forceState st@(a, b, d) = a `seq` b `seq` d `seq` st
46  
47  --------------------------------------------------------------------------------