/ test / functional / test_framework / crypto / secp256k1.py
secp256k1.py
  1  # Copyright (c) 2022-present The Bitcoin Core developers
  2  # Distributed under the MIT software license, see the accompanying
  3  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  """Test-only implementation of low-level secp256k1 field and group arithmetic
  6  
  7  It is designed for ease of understanding, not performance.
  8  
  9  WARNING: This code is slow and trivially vulnerable to side channel attacks. Do not use for
 10  anything but tests.
 11  
 12  Exports:
 13  * FE: class for secp256k1 field elements
 14  * GE: class for secp256k1 group elements
 15  * G: the secp256k1 generator point
 16  """
 17  
 18  import unittest
 19  from hashlib import sha256
 20  from test_framework.util import assert_equal, assert_not_equal
 21  
 22  class FE:
 23      """Objects of this class represent elements of the field GF(2**256 - 2**32 - 977).
 24  
 25      They are represented internally in numerator / denominator form, in order to delay inversions.
 26      """
 27  
 28      # The size of the field (also its modulus and characteristic).
 29      SIZE = 2**256 - 2**32 - 977
 30  
 31      def __init__(self, a=0, b=1):
 32          """Initialize a field element a/b; both a and b can be ints or field elements."""
 33          if isinstance(a, FE):
 34              num = a._num
 35              den = a._den
 36          else:
 37              num = a % FE.SIZE
 38              den = 1
 39          if isinstance(b, FE):
 40              den = (den * b._num) % FE.SIZE
 41              num = (num * b._den) % FE.SIZE
 42          else:
 43              den = (den * b) % FE.SIZE
 44          assert_not_equal(den, 0)
 45          if num == 0:
 46              den = 1
 47          self._num = num
 48          self._den = den
 49  
 50      def __add__(self, a):
 51          """Compute the sum of two field elements (second may be int)."""
 52          if isinstance(a, FE):
 53              return FE(self._num * a._den + self._den * a._num, self._den * a._den)
 54          return FE(self._num + self._den * a, self._den)
 55  
 56      def __radd__(self, a):
 57          """Compute the sum of an integer and a field element."""
 58          return FE(a) + self
 59  
 60      def __sub__(self, a):
 61          """Compute the difference of two field elements (second may be int)."""
 62          if isinstance(a, FE):
 63              return FE(self._num * a._den - self._den * a._num, self._den * a._den)
 64          return FE(self._num - self._den * a, self._den)
 65  
 66      def __rsub__(self, a):
 67          """Compute the difference of an integer and a field element."""
 68          return FE(a) - self
 69  
 70      def __mul__(self, a):
 71          """Compute the product of two field elements (second may be int)."""
 72          if isinstance(a, FE):
 73              return FE(self._num * a._num, self._den * a._den)
 74          return FE(self._num * a, self._den)
 75  
 76      def __rmul__(self, a):
 77          """Compute the product of an integer with a field element."""
 78          return FE(a) * self
 79  
 80      def __truediv__(self, a):
 81          """Compute the ratio of two field elements (second may be int)."""
 82          return FE(self, a)
 83  
 84      def __pow__(self, a):
 85          """Raise a field element to an integer power."""
 86          return FE(pow(self._num, a, FE.SIZE), pow(self._den, a, FE.SIZE))
 87  
 88      def __neg__(self):
 89          """Negate a field element."""
 90          return FE(-self._num, self._den)
 91  
 92      def __int__(self):
 93          """Convert a field element to an integer in range 0..p-1. The result is cached."""
 94          if self._den != 1:
 95              self._num = (self._num * pow(self._den, -1, FE.SIZE)) % FE.SIZE
 96              self._den = 1
 97          return self._num
 98  
 99      def sqrt(self):
100          """Compute the square root of a field element if it exists (None otherwise).
101  
102          Due to the fact that our modulus is of the form (p % 4) == 3, the Tonelli-Shanks
103          algorithm (https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm) is simply
104          raising the argument to the power (p + 1) / 4.
105  
106          To see why: (p-1) % 2 = 0, so 2 divides the order of the multiplicative group,
107          and thus only half of the non-zero field elements are squares. An element a is
108          a (nonzero) square when Euler's criterion, a^((p-1)/2) = 1 (mod p), holds. We're
109          looking for x such that x^2 = a (mod p). Given a^((p-1)/2) = 1, that is equivalent
110          to x^2 = a^(1 + (p-1)/2) mod p. As (1 + (p-1)/2) is even, this is equivalent to
111          x = a^((1 + (p-1)/2)/2) mod p, or x = a^((p+1)/4) mod p."""
112          v = int(self)
113          s = pow(v, (FE.SIZE + 1) // 4, FE.SIZE)
114          if s**2 % FE.SIZE == v:
115              return FE(s)
116          return None
117  
118      def is_square(self):
119          """Determine if this field element has a square root."""
120          # A more efficient algorithm is possible here (Jacobi symbol).
121          return self.sqrt() is not None
122  
123      def is_even(self):
124          """Determine whether this field element, represented as integer in 0..p-1, is even."""
125          return int(self) & 1 == 0
126  
127      def __eq__(self, a):
128          """Check whether two field elements are equal (second may be an int)."""
129          if isinstance(a, FE):
130              return (self._num * a._den - self._den * a._num) % FE.SIZE == 0
131          return (self._num - self._den * a) % FE.SIZE == 0
132  
133      def to_bytes(self):
134          """Convert a field element to a 32-byte array (BE byte order)."""
135          return int(self).to_bytes(32, 'big')
136  
137      @staticmethod
138      def from_bytes(b):
139          """Convert a 32-byte array to a field element (BE byte order, no overflow allowed)."""
140          v = int.from_bytes(b, 'big')
141          if v >= FE.SIZE:
142              return None
143          return FE(v)
144  
145      def __str__(self):
146          """Convert this field element to a 64 character hex string."""
147          return f"{int(self):064x}"
148  
149      def __repr__(self):
150          """Get a string representation of this field element."""
151          return f"FE(0x{int(self):x})"
152  
153  
class GE:
    """Objects of this class represent secp256k1 group elements (curve points or infinity)

    Normal points on the curve have fields:
    * x: the x coordinate (a field element)
    * y: the y coordinate (a field element, satisfying y^2 = x^3 + 7)
    * infinity: False

    The point at infinity has field:
    * infinity: True
    """

    # Order of the group (number of points on the curve, plus 1 for infinity)
    ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

    # Number of valid distinct x coordinates on the curve. ORDER is odd, so this equals
    # (ORDER - 1) / 2: the finite points come in pairs (x, y) and (x, -y).
    ORDER_HALF = ORDER // 2

    def __init__(self, x=None, y=None):
        """Initialize a group element with specified x and y coordinates, or infinity.

        x and y may be ints or field elements. When given, they must satisfy the curve
        equation y^2 = x^3 + 7 (checked with assert_equal). Omitting both yields infinity.
        """
        if x is None:
            # Initialize as infinity.
            assert y is None
            self.infinity = True
        else:
            # Initialize as point on the curve (and check that it is).
            fx = FE(x)
            fy = FE(y)
            assert_equal(fy**2, fx**3 + 7)
            self.infinity = False
            self.x = fx
            self.y = fy

    def __add__(self, a):
        """Add two group elements together using affine chord/tangent formulas."""
        # Deal with infinity: a + infinity == infinity + a == a.
        if self.infinity:
            return a
        if a.infinity:
            return self
        if self.x == a.x:
            if self.y != a.y:
                # A point added to its own negation is infinity.
                assert_equal(self.y + a.y, 0)
                return GE()
            else:
                # For identical inputs, use the tangent (doubling formula).
                lam = (3 * self.x**2) / (2 * self.y)
        else:
            # For distinct inputs, use the line through both points (adding formula).
            lam = (self.y - a.y) / (self.x - a.x)
        # Determine point opposite to the intersection of that line with the curve.
        x = lam**2 - (self.x + a.x)
        y = lam * (self.x - x) - self.y
        return GE(x, y)

    @staticmethod
    def mul(*aps):
        """Compute a (batch) scalar group element multiplication.

        GE.mul((a1, p1), (a2, p2), (a3, p3)) is identical to a1*p1 + a2*p2 + a3*p3,
        but more efficient: all terms share a single double-and-add pass over the 256
        bit positions. With no arguments the result is the point at infinity."""
        # Reduce all the scalars modulo order first (so we can deal with negatives etc).
        naps = [(a % GE.ORDER, p) for a, p in aps]
        # Start with point at infinity.
        r = GE()
        # Iterate over all bit positions, from high to low.
        for i in range(255, -1, -1):
            # Double what we have so far.
            r = r + r
            # Then add the points for which the corresponding scalar bit is set.
            for (a, p) in naps:
                if (a >> i) & 1:
                    r += p
        return r

    def __rmul__(self, a):
        """Multiply an integer with a group element."""
        # GE defines no __eq__, so this is an identity test: only the module-level G
        # object itself (not an equal point built separately) uses the precomputed table.
        if self is G:
            return FAST_G.mul(a)
        return GE.mul((a, self))

    def __neg__(self):
        """Compute the negation of a group element (infinity is its own negation)."""
        if self.infinity:
            return self
        return GE(self.x, -self.y)

    def to_bytes_compressed(self):
        """Convert a non-infinite group element to 33-byte compressed encoding."""
        assert not self.infinity
        # Prefix 0x02 for even y, 0x03 for odd y (is_even() is a bool, i.e. 1 or 0).
        return bytes([3 - self.y.is_even()]) + self.x.to_bytes()

    def to_bytes_uncompressed(self):
        """Convert a non-infinite group element to 65-byte uncompressed encoding."""
        assert not self.infinity
        return b'\x04' + self.x.to_bytes() + self.y.to_bytes()

    def to_bytes_xonly(self):
        """Convert (the x coordinate of) a non-infinite group element to 32-byte xonly encoding."""
        assert not self.infinity
        return self.x.to_bytes()

    @staticmethod
    def lift_x(x):
        """Return group element with specified x coordinate (and even y), or None.

        x may be an int or a field element. None is returned when x^3 + 7 has no square
        root, i.e. when no curve point has this x coordinate."""
        y = (FE(x)**3 + 7).sqrt()
        if y is None:
            return None
        if not y.is_even():
            y = -y
        return GE(x, y)

    @staticmethod
    def from_bytes(b):
        """Convert a compressed or uncompressed encoding to a group element.

        b must be 33 bytes (prefix 0x02/0x03 + x) or 65 bytes (prefix 0x04 + x + y).
        Returns None for a wrong prefix byte, a coordinate that is not below the field
        size, or coordinates that do not describe a point on the curve."""
        assert len(b) in (33, 65)
        if len(b) == 33:
            if b[0] != 2 and b[0] != 3:
                return None
            x = FE.from_bytes(b[1:])
            if x is None:
                return None
            r = GE.lift_x(x)
            if r is None:
                return None
            if b[0] == 3:
                r = -r
            return r
        else:
            if b[0] != 4:
                return None
            x = FE.from_bytes(b[1:33])
            y = FE.from_bytes(b[33:])
            # FE.from_bytes returns None for values >= field size; reject such encodings
            # instead of failing with a TypeError in the curve check below.
            if x is None or y is None:
                return None
            if y**2 != x**3 + 7:
                return None
            return GE(x, y)

    @staticmethod
    def from_bytes_xonly(b):
        """Convert a point given in 32-byte xonly encoding to a group element (even y), or None."""
        assert_equal(len(b), 32)
        x = FE.from_bytes(b)
        if x is None:
            return None
        return GE.lift_x(x)

    @staticmethod
    def is_valid_x(x):
        """Determine whether the provided field element (or int) is a valid X coordinate."""
        return (FE(x)**3 + 7).is_square()

    def __str__(self):
        """Convert this group element to a string."""
        if self.infinity:
            return "(inf)"
        return f"({self.x},{self.y})"

    def __repr__(self):
        """Get a string representation for this group element."""
        if self.infinity:
            return "GE()"
        return f"GE(0x{int(self.x):x},0x{int(self.y):x})"
317  
# The secp256k1 generator point.
# GE.lift_x always returns the point with the even y coordinate, so this relies on the
# standard generator's y coordinate being even (NOTE(review): true per SEC 2 — confirm
# if the constant is ever changed).
G = GE.lift_x(0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798)
320  
321  
class FastGEMul:
    """Doubling table that accelerates scalar multiplication by one fixed group element.

    For a fixed point P the table stores its multiples by successive powers of two:

        table = [P, 2*P, 4*P, (2^3)*P, (2^4)*P, ..., (2^255)*P]

    Multiplying by a scalar then only sums the entries whose index is a set bit of the
    (reduced) scalar, i.e. on average ~128 point additions take place.
    """

    def __init__(self, p):
        # Invariant: doubles[k] == (2^k) * p, built up to k == 255.
        doubles = [p]
        while len(doubles) < 256:
            last = doubles[-1]
            doubles.append(last + last)
        self.table = doubles

    def mul(self, a):
        """Return a * P for this table's point P; a is first reduced modulo GE.ORDER."""
        acc = GE()
        k = a % GE.ORDER
        bit = 0
        # Walk the bits of k from least to most significant.
        while k:
            if k & 1:
                acc += self.table[bit]
            k >>= 1
            bit += 1
        return acc
347  
# Precomputed table with multiples of G for fast multiplication.
# Built once at import time: FastGEMul(G) performs 255 successive point doublings of G.
# GE.__rmul__ consults this table when its operand is the G object itself.
FAST_G = FastGEMul(G)
350  
class TestFrameworkSecp256k1(unittest.TestCase):
    """Unit tests for the secp256k1 test-only arithmetic."""

    def test_H(self):
        """sha256 of the uncompressed generator must be a valid x coordinate with a known value.

        NOTE(review): the expected hash appears to be the x coordinate of the BIP341 NUMS
        point H — confirm against BIP341 before relying on that interpretation.
        """
        H = sha256(G.to_bytes_uncompressed()).digest()
        point = GE.lift_x(FE.from_bytes(H))
        # unittest assertions instead of a bare `assert`, which `python -O` strips.
        self.assertIsNotNone(point)
        self.assertEqual(point.to_bytes_xonly(), H)
        self.assertEqual(H.hex(), "50929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0")