/ haskell / src / BN254.hs
BN254.hs
  1  
  2  -- | The BN254 scalar field
  3  
  4  {-# LANGUAGE Strict, BangPatterns #-} 
  5  module BN254 where
  6  
  7  --------------------------------------------------------------------------------
  8  
  9  import Prelude hiding (div)
 10  import qualified Prelude
 11  
 12  import Data.Bits
 13  import Data.Word
 14  import Data.Ratio
 15  
 16  import System.Random
 17  import Text.Printf
 18  
 19  import Misc
 20  
 21  --------------------------------------------------------------------------------
 22  
 23  fieldPrime :: Integer
 24  fieldPrime = 21888242871839275222246405745257275088548364400416034343698204186575808495617
 25  
 26  modP :: Integer -> Integer
 27  modP x = mod x fieldPrime
 28  
 29  halfPrimePlus1 :: Integer
 30  halfPrimePlus1 = 1 + Prelude.div fieldPrime 2
 31  
 32  --------------------------------------------------------------------------------
 33  
 34  newtype F 
 35    = MkF Integer 
 36    deriving (Eq,Show)
 37  
 38  fromF :: F -> Integer
 39  fromF (MkF x) = x
 40  
 41  -- from the circom docs: @val(z) = z-p  if p/2 +1 <= z < p@
 42  signedFromF :: F -> Integer
 43  signedFromF (MkF x) = if x >= halfPrimePlus1 then x - fieldPrime else x
 44  
 45  toF :: Integer -> F
 46  toF = MkF . modP
 47  
 48  isZero :: F -> Bool
 49  isZero (MkF x) = (x == 0)
 50  
 51  fromBool :: Bool -> F
 52  fromBool b = if b then 1 else 0
 53  
 54  toBool :: F -> Bool
 55  toBool = not . isZero
 56  
 57  --------------------------------------------------------------------------------
 58  
 59  neg :: F -> F
 60  neg (MkF x) = toF (negate x)
 61  
 62  add :: F -> F -> F
 63  add (MkF x) (MkF y) = toF (x+y)
 64  
 65  sub :: F -> F -> F
 66  sub (MkF x) (MkF y) = toF (x-y)
 67  
 68  mul :: F -> F -> F
 69  mul (MkF x) (MkF y) = toF (x*y)
 70  
 71  instance Num F where
 72    fromInteger = toF 
 73    negate = neg
 74    (+) = add
 75    (-) = sub
 76    (*) = mul
 77    abs = id
 78    signum _ = toF 1
 79  
 80  square :: F -> F
 81  square x = x*x
 82  
 83  rndF :: IO F
 84  rndF = MkF <$> randomRIO (0,fieldPrime-1)
 85  
 86  --------------------------------------------------------------------------------
 87  
 88  pow :: F -> Integer -> F
 89  pow x0 exponent
 90    | exponent < 0  = error "power: expecting positive exponent"
 91    | otherwise     = go 1 x0 exponent
 92    where
 93      go !acc _ 0 = acc
 94      go !acc s e = go acc' s' (shiftR e 1) where
 95        s'   = s*s
 96        acc' = if e .&. 1 == 0 then acc else acc*s
 97  
 98  invNaive :: F -> F
 99  invNaive x = pow x (fieldPrime - 2)
100  
-- | Division implemented as multiplication by the 'invNaive' of the divisor.
divNaive :: F -> F -> F
divNaive num den = num * invNaive den
103  
104  --------------------------------------------------------------------------------
105  
-- | Division operators for 'F', backed by the Euclidean 'inv' and 'div'.
-- A rational is converted by mapping numerator and denominator into the
-- field separately and dividing them there.
instance Fractional F where
  recip x        = inv x
  x / y          = div x y
  fromRational r = fromInteger (numerator r) / fromInteger (denominator r)
110  
111  --------------------------------------------------------------------------------
112  
-- | Interpret a little-endian byte string as an integer (see
-- 'integerFromBytesLE') and reduce it into the field.
fromBytesLE :: [Word8] -> F
fromBytesLE bytes = toF (integerFromBytesLE bytes)
115  
-- | Decode a list of bytes as a non-negative integer, least significant byte
-- first.  The empty list decodes to 0.
integerFromBytesLE :: [Word8] -> Integer
integerFromBytesLE = foldr step 0 where
  -- @rest@ is the value of all the later (more significant) bytes
  step byte rest = fromIntegral byte + shiftL rest 8
120  
121  --------------------------------------------------------------------------------
122  
-- | Hexadecimal rendering of an integer, prefixed with @0x@.
instance ShowHex Integer where
  showHex n = printf "0x%x" n

-- | A field element is rendered like its underlying representative.
instance ShowHex F where
  showHex (MkF n) = showHex n
125  
126  --------------------------------------------------------------------------------
127  
-- | Inversion (using Euclid's algorithm).
--
-- Zero has no inverse; instead of failing, this returns 0 for it.
inv :: F -> F
inv (MkF a) =
  if a == 0
    then 0 -- error "field inverse of zero (generic prime)"
    else MkF (euclid 1 0 a fieldPrime)
133  
-- | Division via Euclid's algorithm: the first argument is the dividend,
-- the second the divisor.
--
-- Dividing by zero does not fail; it returns 0.
div :: F -> F -> F
div (MkF a) (MkF b) =
  if b == 0
    then 0 -- error "field division by zero (generic prime)"
    else MkF (euclid a 0 b fieldPrime)
139  
140  --------------------------------------------------------------------------------
141  -- * Euclidean algorithm
142  
-- | Extended binary Euclidean algorithm over the modulus 'fieldPrime'.
--
-- @euclid x1 x2 u v@ runs the binary GCD reduction on the pair @(u,v)@ and
-- mirrors every halving and subtraction on the cofactors @(x1,x2)@, modulo
-- 'fieldPrime'.  It returns @x1@ as soon as @u@ has been reduced to 1, or
-- @x2@ as soon as @v@ has been reduced to 1.
--
-- Within this module it is called as @euclid 1 0 a fieldPrime@ by 'inv' and
-- as @euclid a 0 b fieldPrime@ by 'div'.
--
-- If the reduction reaches @u == 0@ or @v == 0@ while neither operand is 1
-- (a zero operand, or operands sharing a common factor), the result is 0,
-- the same value 'inv' and 'div' use for a zero divisor, rather than a loop
-- that never terminates.
--
-- NOTE(review): operands are assumed to be non-negative; negative inputs are
-- not handled.
euclid :: Integer -> Integer -> Integer -> Integer -> Integer 
euclid !x1 !x2 !u !v = go x1 x2 u v where

  p = fieldPrime

  -- (p+1)/2; adding it to a truncated half turns @(n-1)/2@ into @(n+p)/2@,
  -- which is how an odd cofactor is halved modulo p
  halfp1 = shiftR (p+1) 1

  modp :: Integer -> Integer
  modp n = mod n p

  -- Halve a cofactor modulo p
  halveModP :: Integer -> Integer
  halveModP n = if even n then shiftR n 1 else shiftR n 1 + halfp1

  -- Subtract two cofactors modulo p
  subModP :: Integer -> Integer -> Integer
  subModP m n = if m >= n then modp (m-n) else modp (m+p-n)

  -- Loop head: stop when one operand is 1 (success) or 0 (no inverse)
  go :: Integer -> Integer -> Integer -> Integer -> Integer
  go !c1 !c2 !r !s
    | r == 1           = c1
    | s == 1           = c2
    | r == 0 || s == 0 = 0
    | otherwise        = stepU c1 c2 r s

  -- Strip all factors of two from the first operand
  stepU :: Integer -> Integer -> Integer -> Integer -> Integer
  stepU !c1 !c2 !r !s
    | even r     = stepU (halveModP c1) c2 (shiftR r 1) s
    | otherwise  = stepV c1 c2 r s

  -- Strip all factors of two from the second operand
  stepV :: Integer -> Integer -> Integer -> Integer -> Integer
  stepV !c1 !c2 !r !s
    | even s     = stepV c1 (halveModP c2) r (shiftR s 1)
    | otherwise  = final c1 c2 r s

  -- Both operands are odd here: subtract the smaller from the larger
  final :: Integer -> Integer -> Integer -> Integer -> Integer
  final !c1 !c2 !r !s
    | r >= s     = go (subModP c1 c2) c2 (r-s) s
    | otherwise  = go c1 (subModP c2 c1) r (s-r)
188               x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1)
189           in  go x1  x2' u  v'
190  
191  --------------------------------------------------------------------------------