/ test / functional / test_framework / crypto / ellswift.py
ellswift.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2022-present The Bitcoin Core developers
  3  # Distributed under the MIT software license, see the accompanying
  4  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  5  """Test-only Elligator Swift implementation
  6  
  7  WARNING: This code is slow and uses bad randomness.
  8  Do not use for anything but tests."""
  9  
 10  import csv
 11  import os
 12  import random
 13  import unittest
 14  
 15  from test_framework.crypto.secp256k1 import FE, G, GE
 16  from test_framework.util import assert_equal
 17  
# Precomputed constant square root of -3 (mod p), as returned by FE.sqrt().
# It is used directly in field arithmetic by xswiftec and xswiftec_inv, so this
# relies on -3 having a square root mod p (FE.sqrt() returns None when no root
# exists, as the `is None` checks in xswiftec_inv show).
MINUS_3_SQRT = FE(-3).sqrt()
 20  
 21  def xswiftec(u, t):
 22      """Decode field elements (u, t) to an X coordinate on the curve."""
 23      if u == 0:
 24          u = FE(1)
 25      if t == 0:
 26          t = FE(1)
 27      if u**3 + t**2 + 7 == 0:
 28          t = 2 * t
 29      X = (u**3 + 7 - t**2) / (2 * t)
 30      Y = (X + t) / (MINUS_3_SQRT * u)
 31      for x in (u + 4 * Y**2, (-X / Y - u) / 2, (X / Y - u) / 2):
 32          if GE.is_valid_x(x):
 33              return x
 34      assert False
 35  
 36  def xswiftec_inv(x, u, case):
 37      """Given x and u, find t such that xswiftec(u, t) = x, or return None.
 38  
 39      Case selects which of the up to 8 results to return."""
 40  
 41      if case & 2 == 0:
 42          if GE.is_valid_x(-x - u):
 43              return None
 44          v = x
 45          s = -(u**3 + 7) / (u**2 + u*v + v**2)
 46      else:
 47          s = x - u
 48          if s == 0:
 49              return None
 50          r = (-s * (4 * (u**3 + 7) + 3 * s * u**2)).sqrt()
 51          if r is None:
 52              return None
 53          if case & 1 and r == 0:
 54              return None
 55          v = (-u + r / s) / 2
 56      w = s.sqrt()
 57      if w is None:
 58          return None
 59      if case & 5 == 0:
 60          return -w * (u * (1 - MINUS_3_SQRT) / 2 + v)
 61      if case & 5 == 1:
 62          return w * (u * (1 + MINUS_3_SQRT) / 2 + v)
 63      if case & 5 == 4:
 64          return w * (u * (1 - MINUS_3_SQRT) / 2 + v)
 65      if case & 5 == 5:
 66          return -w * (u * (1 + MINUS_3_SQRT) / 2 + v)
 67  
 68  def xelligatorswift(x):
 69      """Given a field element X on the curve, find (u, t) that encode them."""
 70      assert GE.is_valid_x(x)
 71      while True:
 72          u = FE(random.randrange(1, FE.SIZE))
 73          case = random.randrange(0, 8)
 74          t = xswiftec_inv(x, u, case)
 75          if t is not None:
 76              return u, t
 77  
 78  def ellswift_create():
 79      """Generate a (privkey, ellswift_pubkey) pair."""
 80      priv = random.randrange(1, GE.ORDER)
 81      u, t = xelligatorswift((priv * G).x)
 82      return priv.to_bytes(32, 'big'), u.to_bytes() + t.to_bytes()
 83  
 84  def ellswift_ecdh_xonly(pubkey_theirs, privkey):
 85      """Compute X coordinate of shared ECDH point between ellswift pubkey and privkey."""
 86      u = FE(int.from_bytes(pubkey_theirs[:32], 'big'))
 87      t = FE(int.from_bytes(pubkey_theirs[32:], 'big'))
 88      d = int.from_bytes(privkey, 'big')
 89      return (d * GE.lift_x(xswiftec(u, t))).x.to_bytes()
 90  
 91  
 92  class TestFrameworkEllSwift(unittest.TestCase):
 93      def test_xswiftec(self):
 94          """Verify that xswiftec maps all inputs to the curve."""
 95          for _ in range(32):
 96              u = FE(random.randrange(0, FE.SIZE))
 97              t = FE(random.randrange(0, FE.SIZE))
 98              x = xswiftec(u, t)
 99              self.assertTrue(GE.is_valid_x(x))
100  
101          # Check that inputs which are considered undefined in the original
102          # SwiftEC paper can also be decoded successfully (by remapping)
103          undefined_inputs = [
104              (FE(0), FE(23)),  # u = 0
105              (FE(42), FE(0)),  # t = 0
106              (FE(5), FE(-132).sqrt()),  # u^3 + t^2 + 7 = 0
107          ]
108          assert_equal(undefined_inputs[-1][0]**3 + undefined_inputs[-1][1]**2 + 7, 0)
109          for u, t in undefined_inputs:
110              x = xswiftec(u, t)
111              self.assertTrue(GE.is_valid_x(x))
112  
113      def test_elligator_roundtrip(self):
114          """Verify that encoding using xelligatorswift decodes back using xswiftec."""
115          for _ in range(32):
116              while True:
117                  # Loop until we find a valid X coordinate on the curve.
118                  x = FE(random.randrange(1, FE.SIZE))
119                  if GE.is_valid_x(x):
120                      break
121              # Encoding it to (u, t), decode it back, and compare.
122              u, t = xelligatorswift(x)
123              x2 = xswiftec(u, t)
124              self.assertEqual(x2, x)
125  
126      def test_ellswift_ecdh_xonly(self):
127          """Verify that shared secret computed by ellswift_ecdh_xonly match."""
128          for _ in range(32):
129              privkey1, encoding1 = ellswift_create()
130              privkey2, encoding2 = ellswift_create()
131              shared_secret1 = ellswift_ecdh_xonly(encoding1, privkey2)
132              shared_secret2 = ellswift_ecdh_xonly(encoding2, privkey1)
133              self.assertEqual(shared_secret1, shared_secret2)
134  
135      def test_elligator_encode_testvectors(self):
136          """Implement the BIP324 test vectors for ellswift encoding (read from xswiftec_inv_test_vectors.csv)."""
137          vectors_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'xswiftec_inv_test_vectors.csv')
138          with open(vectors_file, newline='') as csvfile:
139              reader = csv.DictReader(csvfile)
140              for row in reader:
141                  u = FE.from_bytes(bytes.fromhex(row['u']))
142                  x = FE.from_bytes(bytes.fromhex(row['x']))
143                  for case in range(8):
144                      ret = xswiftec_inv(x, u, case)
145                      if ret is None:
146                          self.assertEqual(row[f"case{case}_t"], "")
147                      else:
148                          self.assertEqual(row[f"case{case}_t"], ret.to_bytes().hex())
149                          self.assertEqual(xswiftec(u, ret), x)
150  
151      def test_elligator_decode_testvectors(self):
152          """Implement the BIP324 test vectors for ellswift decoding (read from ellswift_decode_test_vectors.csv)."""
153          vectors_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ellswift_decode_test_vectors.csv')
154          with open(vectors_file, newline='') as csvfile:
155              reader = csv.DictReader(csvfile)
156              for row in reader:
157                  encoding = bytes.fromhex(row['ellswift'])
158                  assert_equal(len(encoding), 64)
159                  expected_x = FE(int(row['x'], 16))
160                  u = FE(int.from_bytes(encoding[:32], 'big'))
161                  t = FE(int.from_bytes(encoding[32:], 'big'))
162                  x = xswiftec(u, t)
163                  self.assertEqual(x, expected_x)
164                  self.assertTrue(GE.is_valid_x(x))