# basic.py
# MIT License
#
# Copyright (c) 2015 Brian Warner and other contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import binascii
import hashlib
import itertools

# Q is the field prime 2^255 - 19.  L is the order of the main (prime-order)
# subgroup containing the base point; the whole curve has order 8*L.
Q = 2**255 - 19
L = 2**252 + 27742317777372353535851937790883648493


def inv(x):
    """Return the inverse of x modulo the prime Q (Fermat: x^(Q-2) mod Q)."""
    return pow(x, Q - 2, Q)


# Curve constant d for -x^2 + y^2 = 1 + d*x^2*y^2 (see isoncurve()).  It is
# left unreduced here; every computation that uses it reduces modulo Q.
d = -121665 * inv(121666)
# A square root of -1 modulo Q; xrecover() uses it to fix up its first root.
I = pow(2, (Q - 1) // 4, Q)


def xrecover(y):
    """Return the even x that pairs with y on the curve.

    Computes a square root of (y^2 - 1) / (d*y^2 + 1) with the Q = 5 (mod 8)
    shortcut. If that quantity is not a square, the value returned is NOT a
    valid coordinate, so callers must check the point with isoncurve().
    """
    xx = (y * y - 1) * inv(d * y * y + 1)
    x = pow(xx, (Q + 3) // 8, Q)
    if (x * x - xx) % Q != 0:
        # the first candidate can be off by a factor of sqrt(-1)
        x = (x * I) % Q
    if x % 2 != 0:
        # pick the even ("positive") root
        x = Q - x
    return x


# Base point: y = 4/5, with the even x.
By = 4 * inv(5)
Bx = xrecover(By)
B = [Bx % Q, By % Q]

# Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html


def xform_affine_to_extended(pt):
    """Convert affine (x, y) into extended (X, Y, Z, T) with Z = 1."""
    (x, y) = pt
    return (x % Q, y % Q, 1, (x * y) % Q)  # (X,Y,Z,T)


def xform_extended_to_affine(pt):
    """Convert extended (X, Y, Z, T) back into reduced affine (x, y)."""
    (x, y, z, _) = pt
    # one modular inversion (an exponentiation) serves both coordinates
    z_inv = inv(z)
    return ((x * z_inv) % Q, (y * z_inv) % Q)


def double_element(pt):  # extended->extended
    """Return 2*pt using the dbl-2008-hwcd formulas (twisted Edwards, a=-1)."""
    (X1, Y1, Z1, _) = pt
    A = X1 * X1
    B = Y1 * Y1
    C = 2 * Z1 * Z1
    D = (-A) % Q  # a*A with a = -1
    J = (X1 + Y1) % Q
    E = (J * J - A - B) % Q
    G = (D + B) % Q
    F = (G - C) % Q
    H = (D - B) % Q
    X3 = (E * F) % Q
    Y3 = (G * H) % Q
    Z3 = (F * G) % Q
    T3 = (E * H) % Q
    return (X3, Y3, Z3, T3)


def add_elements(pt1, pt2):  # extended->extended
    """Return pt1 + pt2 using the add-2008-hwcd-3 formulas.

    Slightly slower than add-2008-hwcd-4 (_add_elements_nonunfied), but -3
    is unified (it also handles pt1 == pt2), so it is safe for
    general-purpose addition.
    """
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1 - X1) * (Y2 - X2)) % Q
    B = ((Y1 + X1) * (Y2 + X2)) % Q
    C = T1 * (2 * d) * T2 % Q
    D = Z1 * 2 * Z2 % Q
    E = (B - A) % Q
    F = (D - C) % Q
    G = (D + C) % Q
    H = (B + A) % Q
    X3 = (E * F) % Q
    Y3 = (G * H) % Q
    T3 = (E * H) % Q
    Z3 = (F * G) % Q
    return (X3, Y3, Z3, T3)


def scalarmult_element_safe_slow(pt, n):  # extended->extended
    """Return n*pt (n >= 0) by recursive double-and-add with unified adds.

    This form is slightly slower, but tolerates arbitrary points, including
    those which are not in the main 1*L subgroup. This includes points of
    order 1 (the neutral element Zero), 2, 4, and 8.
    """
    assert n >= 0
    if n == 0:
        return xform_affine_to_extended((0, 1))
    acc = double_element(scalarmult_element_safe_slow(pt, n >> 1))
    return add_elements(acc, pt) if n & 1 else acc


def _add_elements_nonunfied(pt1, pt2):  # extended->extended
    """Return pt1 + pt2 using add-2008-hwcd-4.

    NOT unified: only valid for pt1 != pt2. About 10% faster than the
    (unified) add-2008-hwcd-3, and safe to use inside scalarmult if you
    aren't using points of order 1/2/4/8.
    """
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1 - X1) * (Y2 + X2)) % Q
    B = ((Y1 + X1) * (Y2 - X2)) % Q
    C = (Z1 * 2 * T2) % Q
    D = (T1 * 2 * Z2) % Q
    E = (D + C) % Q
    F = (B - A) % Q
    G = (B + A) % Q
    H = (D - C) % Q
    X3 = (E * F) % Q
    Y3 = (G * H) % Q
    Z3 = (F * G) % Q
    T3 = (E * H) % Q
    return (X3, Y3, Z3, T3)


def scalarmult_element(pt, n):  # extended->extended
    """Return n*pt (n >= 0) by recursive double-and-add with non-unified adds.

    This form only works properly when given points that are a member of
    the main 1*L subgroup. It will give incorrect answers when called with
    the points of order 1/2/4/8, including point Zero. (It will also work
    properly when given points of order 2*L/4*L/8*L.)
    """
    assert n >= 0
    if n == 0:
        return xform_affine_to_extended((0, 1))
    acc = double_element(scalarmult_element(pt, n >> 1))
    return _add_elements_nonunfied(acc, pt) if n & 1 else acc


# Points are encoded as 32 bytes, little-endian: bits 0..254 hold y, and
# bit 255 holds the low bit ("sign") of x.


def encodepoint(P):
    """Encode affine point P = (x, y), with 0 <= y < 2^255, into 32 bytes."""
    x = P[0]
    y = P[1]
    # MSB of output equals x.b0 (=x&1)
    # rest of output is little-endian y
    assert 0 <= y < (1 << 255)  # always < 0x7fff..ff
    if x & 1:
        y += 1 << 255
    return binascii.unhexlify("%064x" % y)[::-1]


def isoncurve(P):
    """Return True if affine P = (x, y) satisfies -x^2 + y^2 = 1 + d*x^2*y^2."""
    x = P[0]
    y = P[1]
    return (-x * x + y * y - 1 - d * x * x * y * y) % Q == 0


class NotOnCurve(Exception):
    """Raised by decodepoint() when the encoded point is not on the curve."""


def decodepoint(s):
    """Decode the first 32 bytes of *s* into an affine point [x, y].

    Raises NotOnCurve if the encoded y has no matching x on the curve.
    """
    unclamped = int(binascii.hexlify(s[:32][::-1]), 16)
    clamp = (1 << 255) - 1
    y = unclamped & clamp  # clear MSB
    x = xrecover(y)
    # xrecover() returns the even root; negate it if the sign bit is set
    if bool(x & 1) != bool(unclamped & (1 << 255)):
        x = Q - x
    P = [x, y]
    if not isoncurve(P):
        raise NotOnCurve("decoding point that is not on curve")
    return P


# Scalars are encoded as 32 bytes, little-endian.


def bytes_to_scalar(s):
    """Decode 32 little-endian bytes into an integer (not reduced mod L)."""
    assert len(s) == 32, len(s)
    return int(binascii.hexlify(s[::-1]), 16)


def bytes_to_clamped_scalar(s):
    """Decode 32 bytes into an Ed25519-style clamped scalar.

    Ed25519 private keys clamp the scalar to ensure two things:
      1: bit 254 is set and everything above it is clear, so the value
         always lies in 2^254 .. 2^255-8 (intended to avoid
         small-logarithm non-wraparound)
      2: the low-order 3 bits are zero, so a small-subgroup attack won't
         learn any information
    i.e. set the top two bits to 01, and the bottom three to 000.
    """
    a_unclamped = bytes_to_scalar(s)
    AND_CLAMP = (1 << 254) - 1 - 7
    OR_CLAMP = 1 << 254
    a_clamped = (a_unclamped & AND_CLAMP) | OR_CLAMP
    return a_clamped


def random_scalar(entropy_f):  # 0..L-1 inclusive
    """Return a scalar in 0..L-1 built from entropy_f(64) random bytes.

    Generating 256 more bits than needed reduces the modulo-L bias to a
    safe level.
    """
    oversized = int(binascii.hexlify(entropy_f(32 + 32)), 16)
    return oversized % L


def password_to_scalar(pw):
    """Hash the bytes *pw* with SHA-512 and reduce the digest modulo L."""
    oversized = hashlib.sha512(pw).digest()
    return int(binascii.hexlify(oversized), 16) % L


def scalar_to_bytes(y):
    """Encode y (first reduced modulo L) as 32 little-endian bytes."""
    y = y % L
    assert 0 <= y < 2**256
    return binascii.unhexlify("%064x" % y)[::-1]


# Elements, of various orders


def is_extended_zero(XYTZ):
    """Return True if the extended point (X, Y, Z, T) is Zero (x=0, y=1).

    That is: X == 0 and Y == Z != 0, all compared modulo Q.
    """
    (X, Y, Z, T) = XYTZ
    Y = Y % Q
    Z = Z % Q
    return X % Q == 0 and Y == Z and Y != 0


class ElementOfUnknownGroup:
    """A curve point whose subgroup is not known.

    This is used for points of order 2,4,8,2*L,4*L,8*L. XYTZ holds the
    extended coordinates as the tuple (X, Y, Z, T).
    """

    def __init__(self, XYTZ):
        assert isinstance(XYTZ, tuple)
        assert len(XYTZ) == 4
        self.XYTZ = XYTZ

    def add(self, other):
        """Return self + other (unified addition); Zero if the sum is Zero."""
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
        if is_extended_zero(sum_XYTZ):
            return Zero
        return ElementOfUnknownGroup(sum_XYTZ)

    def scalarmult(self, s):
        """Return s*self for s >= 0, without reducing s modulo L."""
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        assert s >= 0
        product = scalarmult_element_safe_slow(self.XYTZ, s)
        return ElementOfUnknownGroup(product)

    def to_bytes(self):
        """Return the 32-byte encoding of this point."""
        return encodepoint(xform_extended_to_affine(self.XYTZ))

    def __eq__(self, other):
        # compare canonical encodings, since extended coordinates are not
        # unique; defer to Python's default for non-elements
        if not isinstance(other, ElementOfUnknownGroup):
            return NotImplemented
        return self.to_bytes() == other.to_bytes()

    def __ne__(self, other):
        return not self == other


class Element(ElementOfUnknownGroup):
    """A point in the main 1*L subgroup.

    It never holds Zero, or elements of order 1/2/4/8, or 2*L/4*L/8*L.
    """

    def add(self, other):
        """Return self + other, keeping the Element type when it is known."""
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_element = ElementOfUnknownGroup.add(self, other)
        if sum_element is Zero:
            return sum_element
        if isinstance(other, Element):
            # adding two subgroup elements results in another subgroup
            # element, or Zero, and we've already excluded Zero
            return Element(sum_element.XYTZ)
        # not necessarily a subgroup member, so assume not
        return sum_element

    def scalarmult(self, s):
        """Return s*self; s may be any integer (it is reduced modulo L)."""
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        # scalarmult of subgroup members can be done modulo the subgroup
        # order, and using the faster non-unified function.
        s = s % L
        # scalarmult(s=0) gets you Zero
        if s == 0:
            return Zero
        # scalarmult(s=1) gets you self, which is a subgroup member
        # scalarmult(s<grouporder) gets you a different subgroup member
        return Element(scalarmult_element(self.XYTZ, s))

    # negation and subtraction only make sense for the main subgroup
    def negate(self):
        """Return -self."""
        # On a twisted Edwards curve -(x, y) = (-x, y); in extended
        # coordinates that is (-X, Y, Z, -T) because T = X*Y/Z. This is
        # exact and far cheaper than computing (L-1)*self.
        (X, Y, Z, T) = self.XYTZ
        return Element(((-X) % Q, Y, Z, (-T) % Q))

    def subtract(self, other):
        """Return self - other."""
        return self.add(other.negate())


class _ZeroElement(ElementOfUnknownGroup):
    """The neutral element; use the module-level singleton Zero."""

    def add(self, other):
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        return other  # zero+anything = anything

    def scalarmult(self, s):
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        return self  # zero*anything = zero

    def negate(self):
        return self  # -zero = zero

    def subtract(self, other):
        return self.add(other.negate())


Base = Element(xform_affine_to_extended(B))
Zero = _ZeroElement(xform_affine_to_extended((0, 1)))  # the neutral (identity) element

_zero_bytes = Zero.to_bytes()


def arbitrary_element(seed):  # unknown DL
    """Deterministically map the bytes *seed* to an Element of unknown DL."""
    # TODO: if we don't need uniformity, maybe use just sha256 here?
    hseed = hashlib.sha512(seed).digest()
    y = int(binascii.hexlify(hseed), 16) % Q

    # we try successive Y values until we find a valid point
    for plus in itertools.count(0):
        y_plus = (y + plus) % Q
        x = xrecover(y_plus)
        Pa = [x, y_plus]  # no attempt to use both "positive" and "negative" X

        # only about 50% of Y coordinates map to valid curve points (I think
        # the other half give you points on the "twist").
        if not isoncurve(Pa):
            continue

        P = ElementOfUnknownGroup(xform_affine_to_extended(Pa))
        # even if the point is on our curve, it may not be in our particular
        # (order=L) subgroup. The curve has order 8*L, so an arbitrary point
        # could have order 1,2,4,8,1*L,2*L,4*L,8*L (everything which divides
        # the group order).

        # [I MAY BE COMPLETELY WRONG ABOUT THIS, but my brief statistical
        # tests suggest it's not too far off] There are phi(x) points with
        # order x, so:
        #  1 element of order 1: [(x=0,y=1)=Zero]
        #  1 element of order 2 [(x=0,y=-1)]
        #  2 elements of order 4
        #  4 elements of order 8
        #  L-1 elements of order L (including Base)
        #  L-1 elements of order 2*L
        #  2*(L-1) elements of order 4*L
        #  4*(L-1) elements of order 8*L

        # So 50% of random points will have order 8*L, 25% will have order
        # 4*L, 12.5% order 2*L, and 12.5% will have our desired order 1*L
        # (and a vanishingly small fraction will have 1/2/4/8). If we
        # multiply any of the 8*L points by 2, we're sure to get an 4*L point
        # (and multiplying a 4*L point by 2 gives us a 2*L point, and so on).
        # Multiplying a 1*L point by 2 gives us a different 1*L point. So
        # multiplying by 8 gets us from almost any point into a uniform point
        # on the correct 1*L subgroup.

        P8 = P.scalarmult(8)

        # if we got really unlucky and picked one of the 8 low-order points,
        # multiplying by 8 will get us to the identity (Zero), which we check
        # for explicitly.
        if is_extended_zero(P8.XYTZ):
            continue

        # Test that we're finally in the right group. We want to scalarmult
        # by L, and we want to *not* use the trick in Element.scalarmult()
        # which does x%L, because that would bypass the check we care about.
        # P8 is still an ElementOfUnknownGroup, which doesn't use x%L because
        # that's not correct for points outside the main group.
        assert is_extended_zero(P8.scalarmult(L).XYTZ)

        return Element(P8.XYTZ)
    # never reached


def bytes_to_unknown_group_element(bytes):
    """Decode 32 bytes into Zero or an ElementOfUnknownGroup.

    This accepts all elements, including Zero and wrong-subgroup ones.
    Raises NotOnCurve (from decodepoint) for invalid encodings.
    """
    if bytes == _zero_bytes:
        return Zero
    XYTZ = xform_affine_to_extended(decodepoint(bytes))
    return ElementOfUnknownGroup(XYTZ)


def bytes_to_element(bytes):
    """Decode 32 bytes into an Element of the main 1*L subgroup.

    This strictly only accepts elements in the right subgroup. Raises
    ValueError for Zero or for points outside that subgroup.
    """
    P = bytes_to_unknown_group_element(bytes)
    if P is Zero:
        raise ValueError("element was Zero")
    if not is_extended_zero(P.scalarmult(L).XYTZ):
        raise ValueError("element is not in the right group")
    # the point is in the expected 1*L subgroup, not in the 2/4/8 groups,
    # or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
    return Element(P.XYTZ)