# RNS/Cryptography/pure25519/basic.py
# MIT License
#
# Copyright (c) 2015 Brian Warner and other contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
 22  
 23  import binascii, hashlib, itertools
 24  
 25  Q = 2**255 - 19
 26  L = 2**252 + 27742317777372353535851937790883648493
 27  
 28  def inv(x):
 29      return pow(x, Q-2, Q)
 30  
 31  d = -121665 * inv(121666)
 32  I = pow(2,(Q-1)//4,Q)
 33  
 34  def xrecover(y):
 35      xx = (y*y-1) * inv(d*y*y+1)
 36      x = pow(xx,(Q+3)//8,Q)
 37      if (x*x - xx) % Q != 0: x = (x*I) % Q
 38      if x % 2 != 0: x = Q-x
 39      return x
 40  
 41  By = 4 * inv(5)
 42  Bx = xrecover(By)
 43  B = [Bx % Q,By % Q]
 44  
 45  # Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
 46  # http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
 47  
 48  def xform_affine_to_extended(pt):
 49      (x, y) = pt
 50      return (x%Q, y%Q, 1, (x*y)%Q) # (X,Y,Z,T)
 51  
 52  def xform_extended_to_affine(pt):
 53      (x, y, z, _) = pt
 54      return ((x*inv(z))%Q, (y*inv(z))%Q)
 55  
 56  def double_element(pt): # extended->extended
 57      # dbl-2008-hwcd
 58      (X1, Y1, Z1, _) = pt
 59      A = (X1*X1)
 60      B = (Y1*Y1)
 61      C = (2*Z1*Z1)
 62      D = (-A) % Q
 63      J = (X1+Y1) % Q
 64      E = (J*J-A-B) % Q
 65      G = (D+B) % Q
 66      F = (G-C) % Q
 67      H = (D-B) % Q
 68      X3 = (E*F) % Q
 69      Y3 = (G*H) % Q
 70      Z3 = (F*G) % Q
 71      T3 = (E*H) % Q
 72      return (X3, Y3, Z3, T3)
 73  
 74  def add_elements(pt1, pt2): # extended->extended
 75      # add-2008-hwcd-3 . Slightly slower than add-2008-hwcd-4, but -3 is
 76      # unified, so it's safe for general-purpose addition
 77      (X1, Y1, Z1, T1) = pt1
 78      (X2, Y2, Z2, T2) = pt2
 79      A = ((Y1-X1)*(Y2-X2)) % Q
 80      B = ((Y1+X1)*(Y2+X2)) % Q
 81      C = T1*(2*d)*T2 % Q
 82      D = Z1*2*Z2 % Q
 83      E = (B-A) % Q
 84      F = (D-C) % Q
 85      G = (D+C) % Q
 86      H = (B+A) % Q
 87      X3 = (E*F) % Q
 88      Y3 = (G*H) % Q
 89      T3 = (E*H) % Q
 90      Z3 = (F*G) % Q
 91      return (X3, Y3, Z3, T3)
 92  
 93  def scalarmult_element_safe_slow(pt, n):
 94      # this form is slightly slower, but tolerates arbitrary points, including
 95      # those which are not in the main 1*L subgroup. This includes points of
 96      # order 1 (the neutral element Zero), 2, 4, and 8.
 97      assert n >= 0
 98      if n==0:
 99          return xform_affine_to_extended((0,1))
100      _ = double_element(scalarmult_element_safe_slow(pt, n>>1))
101      return add_elements(_, pt) if n&1 else _
102  
def _add_elements_nonunfied(pt1, pt2): # extended->extended
    """Return pt1 + pt2 using EFD add-2008-hwcd-4 (a = -1).

    This formula is NOT unified: it must not be used when pt1 == pt2 (use
    double_element() for that). It is somewhat faster than add_elements().
    It is suitable inside scalarmult_element() as long as points of order
    1/2/4/8 are not involved. All output coordinates are reduced mod Q.
    """
    X1, Y1, Z1, T1 = pt1
    X2, Y2, Z2, T2 = pt2
    a = (Y1 - X1) * (Y2 + X2) % Q
    b = (Y1 + X1) * (Y2 - X2) % Q
    cc = 2 * Z1 * T2 % Q
    dd = 2 * T1 * Z2 % Q
    e, h = (dd + cc) % Q, (dd - cc) % Q
    f, g = (b - a) % Q, (b + a) % Q
    return (e * f % Q, g * h % Q, f * g % Q, e * h % Q)
122  
def scalarmult_element(pt, n): # extended->extended
    """Return n*pt (extended -> extended) using the faster non-unified addition.

    This is only correct for points in the main 1*L subgroup. It also works
    for points of order 2*L/4*L/8*L. It gives incorrect answers for points of
    order 1/2/4/8, including Zero; use scalarmult_element_safe_slow() for those.

    The loop is a left-to-right double-and-add over the bits of n. It is
    iterative rather than recursive, so large n cannot hit the recursion limit.
    The operation sequence is identical to the recursive form. n == 0 returns
    the neutral element (0, 1, 1, 0).
    """
    assert n >= 0
    acc = xform_affine_to_extended((0,1))  # neutral element
    for i in range(n.bit_length() - 1, -1, -1):
        acc = double_element(acc)
        if (n >> i) & 1:
            acc = _add_elements_nonunfied(acc, pt)
    return acc
133  
# Points are encoded as 32 bytes, little-endian: bits 0..254 hold y and bit 255 holds the low bit ("sign") of x.
135  
def encodepoint(P):
    """Encode affine point P = (x, y) as 32 little-endian bytes.

    The low 255 bits carry y and the most significant bit carries x & 1.
    y must already lie in [0, 2**255); this is enforced by assert.
    """
    x, y = P[0], P[1]
    assert 0 <= y < (1<<255) # always < 0x7fff..ff
    encoded = y + ((x & 1) << 255)
    big_endian = binascii.unhexlify("%064x" % encoded)
    return big_endian[::-1]
145  
def isoncurve(P):
    """Return True if affine P = (x, y) satisfies -x^2 + y^2 = 1 + d*x^2*y^2 (mod Q)."""
    x, y = P[0], P[1]
    xx, yy = x * x, y * y
    return (yy - xx - 1 - d * xx * yy) % Q == 0
150  
class NotOnCurve(Exception):
    """Raised by decodepoint() when the input does not decode to a curve point."""
153  
def decodepoint(s):
    """Decode a 32-byte point encoding into an affine [x, y] list.

    The encoding is little-endian. Bits 0..254 hold y and bit 255 holds the
    low bit ("sign") of x. Only the first 32 bytes of *s* are read.

    Raises NotOnCurve in three cases:
      * the encoding is not canonical because y >= Q;
      * the encoding is not canonical because the sign bit is set while x == 0
        (see RFC 8032, section 5.1.3);
      * y has no matching x on the curve.
    encodepoint() never produces non-canonical encodings.
    """
    unclamped = int(binascii.hexlify(s[:32][::-1]), 16)
    clamp = (1 << 255) - 1
    y = unclamped & clamp # clear MSB
    sign = bool(unclamped & (1<<255))
    if y >= Q:
        # Accepting y >= Q would allow several encodings of the same point
        # (e.g. y = Q + 1 is another encoding of the neutral element).
        raise NotOnCurve("non-canonical point encoding: y >= Q")
    x = xrecover(y)
    if bool(x & 1) != sign:
        if x == 0:
            # -0 == 0, so no x with the requested sign exists. Flipping would
            # produce x = Q, which is outside [0, Q).
            raise NotOnCurve("non-canonical point encoding: sign bit set for x == 0")
        x = Q-x
    P = [x,y]
    if not isoncurve(P): raise NotOnCurve("decoding point that is not on curve")
    return P
163  
# Scalars are encoded as 32 bytes, little-endian.
165  
def bytes_to_scalar(s):
    """Interpret exactly 32 bytes as a little-endian unsigned integer.

    Raises AssertionError, with the actual length as its message, if len(s) != 32.
    """
    assert len(s) == 32, len(s)
    big_endian = s[::-1]
    return int(binascii.hexlify(big_endian), 16)
169  
def bytes_to_clamped_scalar(s):
    """Decode 32 little-endian bytes as an Ed25519-style clamped secret scalar.

    Clamping does two things:
      * It clears the low three bits, so the scalar is a multiple of 8 (the
        curve's cofactor). Multiplying a point by it therefore removes any
        order-2/4/8 component.
      * It forces the top two bits of the 256-bit value to 01.
    The result always lies in [2**254, 2**255 - 8].
    """
    unclamped = bytes_to_scalar(s)
    keep_mask = (1 << 254) - 1 - 7   # drops bits 255, 254 and 0..2
    top_bit = 1 << 254
    return (unclamped & keep_mask) | top_bit
182  
def random_scalar(entropy_f): # 0..L-1 inclusive
    """Return a scalar in [0, L) derived from 64 bytes drawn via *entropy_f*.

    entropy_f(n) is presumably a function returning n random bytes
    (e.g. os.urandom); confirm with callers. The 512-bit value is reduced
    mod L (about 2**252). Drawing roughly 256 more bits than needed keeps
    the modulo bias negligible.
    """
    wide = entropy_f(64)
    return int(binascii.hexlify(wide), 16) % L
187  
def password_to_scalar(pw):
    """Map *pw* (bytes) to a scalar: SHA-512 digest, read big-endian, reduced mod L.

    NOTE(review): this is a single unsalted hash, not a password-stretching KDF.
    """
    digest = hashlib.sha512(pw).digest()
    as_int = int(binascii.hexlify(digest), 16)
    return as_int % L
191  
def scalar_to_bytes(y):
    """Reduce scalar y mod L and encode it as 32 little-endian bytes."""
    reduced = y % L
    assert 0 <= reduced < 2**256
    big_endian = binascii.unhexlify("%064x" % reduced)
    return big_endian[::-1]
196  
197  # Elements, of various orders
198  
def is_extended_zero(XYTZ):
    """Return True if an extended-coordinate point is the neutral element Zero.

    The argument is laid out as (X, Y, Z, T), despite the parameter name.
    The neutral element has affine coordinates (0, 1), i.e. X == 0 and
    Y == Z != 0 after reduction mod Q.
    """
    (X, Y, Z, _T) = XYTZ
    # Reduce every coordinate used, not only Y and Z. An unreduced X such as
    # Q is still 0 mod Q and must be recognized as such.
    X = X % Q
    Y = Y % Q
    Z = Z % Q
    return X == 0 and Y == Z and Y != 0
207  
class ElementOfUnknownGroup:
    """A curve point whose subgroup membership has not been established.

    Used for points of order 2, 4, 8, 2*L, 4*L or 8*L, and as the base class
    of Element and of the Zero singleton. The point is stored in extended
    coordinates as self.XYTZ, a 4-tuple laid out (X, Y, Z, T).
    """

    def __init__(self, XYTZ):
        assert isinstance(XYTZ, tuple)
        assert len(XYTZ) == 4
        self.XYTZ = XYTZ

    def add(self, other):
        """Return self + other using the unified (always-safe) addition.

        Returns the Zero singleton when the sum is the neutral element.
        Raises TypeError if *other* is not an element.
        """
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
        if is_extended_zero(sum_XYTZ):
            return Zero
        return ElementOfUnknownGroup(sum_XYTZ)

    def scalarmult(self, s):
        """Return s*self for a non-negative integer s.

        The scalar is NOT reduced mod L, because that would be wrong for
        points outside the main subgroup. Like add(), this returns the Zero
        singleton when the product is the neutral element. Raises TypeError
        when s is an element.
        """
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        assert s >= 0
        product = scalarmult_element_safe_slow(self.XYTZ, s)
        # Normalize the neutral element to the singleton so that `is Zero`
        # checks behave the same after scalarmult() as after add().
        if is_extended_zero(product):
            return Zero
        return ElementOfUnknownGroup(product)

    def to_bytes(self):
        """Return the 32-byte encoding of this point."""
        return encodepoint(xform_extended_to_affine(self.XYTZ))

    def __eq__(self, other):
        # Elements compare by their canonical encoding. Non-elements are
        # delegated to Python's default handling instead of raising
        # AttributeError on other.to_bytes().
        if not isinstance(other, ElementOfUnknownGroup):
            return NotImplemented
        return self.to_bytes() == other.to_bytes()

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Consistent with __eq__: equal elements have equal encodings.
        return hash(self.to_bytes())
236  
class Element(ElementOfUnknownGroup):
    """A point of the main order-L subgroup.

    Instances never hold Zero, elements of order 1/2/4/8, or elements of
    order 2*L/4*L/8*L; the neutral element is always the separate Zero
    singleton. This invariant lets scalarmult() reduce scalars mod L and use
    the faster non-unified addition.
    """

    def add(self, other):
        """Return self + other.

        The result depends on *other*:
          * an Element or Zero gives an Element, or Zero when the sum is the
            neutral element;
          * any other element gives a plain ElementOfUnknownGroup.
        Raises TypeError when *other* is not an element.
        """
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        if other is Zero:
            # self + Zero == self; keep the Element type rather than demoting
            # the result to ElementOfUnknownGroup.
            return self
        sum_element = ElementOfUnknownGroup.add(self, other)
        if sum_element is Zero:
            return sum_element
        if isinstance(other, Element):
            # adding two subgroup elements results in another subgroup
            # element, or Zero, and we've already excluded Zero
            return Element(sum_element.XYTZ)
        # not necessarily a subgroup member, so assume not
        return sum_element

    def scalarmult(self, s):
        """Return s*self; s may be any integer, including a negative one.

        Raises TypeError when s is an element.
        """
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        # scalarmult of subgroup members can be done modulo the subgroup
        # order, and using the faster non-unified function.
        s = s % L
        # scalarmult(s=0) gets you Zero
        if s == 0:
            return Zero
        # scalarmult(s=1) gets you self, which is a subgroup member
        # scalarmult(s<grouporder) gets you a different subgroup member
        return Element(scalarmult_element(self.XYTZ, s))

    # negation and subtraction only make sense for the main subgroup
    def negate(self):
        """Return -self.

        On a twisted Edwards curve -(x, y) = (-x, y). In extended coordinates
        that is (-X, Y, Z, -T), so no scalar multiplication is needed. The
        previous implementation multiplied by L-2, which yields -2*self
        rather than -self.
        """
        (X, Y, Z, T) = self.XYTZ
        return Element(((-X) % Q, Y, Z, (-T) % Q))

    def subtract(self, other):
        """Return self - other (other must provide negate(), e.g. Element or Zero)."""
        return self.add(other.negate())
273  
class _ZeroElement(ElementOfUnknownGroup):
    """The neutral (identity) element, exposed as the module-level singleton Zero."""

    def add(self, other):
        # Zero + anything = anything
        return other

    def scalarmult(self, s):
        # any scalar times Zero is Zero
        return self

    def negate(self):
        # -Zero = Zero
        return self

    def subtract(self, other):
        # Zero - other = -other (add() simply hands back its argument)
        return other.negate()
283  
284  
# The neutral (identity) element, affine (0, 1).
Zero = _ZeroElement(xform_affine_to_extended((0,1)))
# Canonical encoding of Zero, used to short-circuit decoding of the identity.
_zero_bytes = Zero.to_bytes()
# Standard generator of the order-L subgroup.
Base = Element(xform_affine_to_extended(B))
289  
290  
def arbitrary_element(seed): # unknown DL
    """Deterministically map *seed* (bytes) to an Element of the order-L subgroup.

    The result has no known discrete logarithm with respect to Base.

    The digest is computed as SHA-512(seed) and reduced mod Q. Starting from
    that digest, successive y values are tried until one yields a curve point
    whose 8-multiple is not the neutral element. That multiple is returned.
    """
    # TODO: if we don't need uniformity, maybe use just sha256 here?
    digest = hashlib.sha512(seed).digest()
    y_start = int(binascii.hexlify(digest), 16) % Q

    for offset in itertools.count(0):
        y = (y_start + offset) % Q
        # Only the even-x root is tried; no attempt is made to use both signs of x.
        candidate = [xrecover(y), y]

        # Roughly half of all y values have no matching x on this curve
        # (presumably those correspond to points on the twist).
        if not isoncurve(candidate):
            continue

        # The curve has order 8*L, so a point on it may have order 1, 2, 4,
        # 8, L, 2*L, 4*L or 8*L. The expected split is about half of random
        # points at order 8*L, a quarter at 4*L, an eighth each at 2*L and L,
        # and a vanishing fraction at 1/2/4/8. Multiplying by the cofactor 8
        # moves any of them into the order-L subgroup, or onto Zero.
        P = ElementOfUnknownGroup(xform_affine_to_extended(candidate))
        P8 = P.scalarmult(8)

        # One of the 8 low-order points collapses to the identity; skip it.
        if is_extended_zero(P8.XYTZ):
            continue

        # Confirm P8 is in the order-L subgroup. This deliberately uses
        # ElementOfUnknownGroup.scalarmult: Element.scalarmult reduces the
        # scalar mod L, which would turn L*P8 into a meaningless 0*P8.
        assert is_extended_zero(P8.scalarmult(L).XYTZ)

        return Element(P8.XYTZ)
    # never reached
351  
def bytes_to_unknown_group_element(bytes):
    """Decode a 32-byte encoding into an element without checking its subgroup.

    All on-curve points are accepted, including Zero and points outside the
    order-L subgroup. Any encoding of the neutral element is returned as the
    Zero singleton.

    Raises NotOnCurve (from decodepoint) when the bytes are not a valid point.
    """
    if bytes == _zero_bytes:
        return Zero
    XYTZ = xform_affine_to_extended(decodepoint(bytes))
    # Guard against any other encoding of the identity that decodepoint()
    # might let through (e.g. the sign bit set on x == 0, or y == Q + 1).
    # Normalizing here keeps `is Zero` checks in callers reliable.
    if is_extended_zero(XYTZ):
        return Zero
    return ElementOfUnknownGroup(XYTZ)
358  
def bytes_to_element(bytes):
    """Decode a 32-byte encoding, accepting only members of the order-L subgroup.

    Raises ValueError if the point is the neutral element (in any encoding)
    or is not in the order-L subgroup. NotOnCurve propagates from decoding.
    """
    P = bytes_to_unknown_group_element(bytes)
    # Check the coordinates as well as identity with the singleton, so a
    # non-canonically encoded identity can never be promoted to Element.
    if P is Zero or is_extended_zero(P.XYTZ):
        raise ValueError("element was Zero")
    if not is_extended_zero(P.scalarmult(L).XYTZ):
        raise ValueError("element is not in the right group")
    # the point is in the expected 1*L subgroup, not in the 2/4/8 groups,
    # or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
    return Element(P.XYTZ)
367      # or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
368      return Element(P.XYTZ)