/ python / _test_.py
_test_.py
  1  import unittest
  2  import hashlib
  3  import csv
  4  import os
  5  import asn1
  6  import shutil
  7  import python.proof as proof
  8  import sys
  9  import json
 10  import secrets
 11  import subprocess
 12  import python.generate as generate
 13  import base64
 14  from unittest import mock
 15  from web3.auto import w3
 16  from eth_account.messages import encode_defunct
 17  from hexbytes import HexBytes
 18  from helpers.common import Metadata
 19  from helpers.merkle import MerkleTree
 20  from cryptography.hazmat.primitives import serialization as crypto_serialization
 21  from cryptography.hazmat.primitives.asymmetric import rsa
 22  from cryptography.hazmat.backends import default_backend as crypto_default_backend
 23  
 24  TEST_FILES_PATH = os.path.join(os. getcwd(), "__test_cache__")
 25  CSV_TEST_PATH = os.path.join(TEST_FILES_PATH, "test.csv")
 26  USERS = ["user_1", "user_2", "user_3"]
 27  USERS_WITH_DOUBLE_KEYS = ["user_4"]
 28  SECOND_KEY_POSTFIX = "_second"
 29  
 30  # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-2.md
 31  SECP256K1_N = int(
 32      0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
 33  )
 34  SECP256K1_HALF_N = int(
 35      SECP256K1_N/2
 36  )
 37  
 38  
 39  def init():
 40      if (os.path.exists(TEST_FILES_PATH)):
 41          shutil.rmtree(TEST_FILES_PATH, ignore_errors=False, onerror=None)
 42  
 43      os.mkdir(TEST_FILES_PATH)
 44  
 45      with open(CSV_TEST_PATH, 'w') as f:
 46          writer = csv.writer(f, delimiter=",")
 47          for user in USERS:
 48              pubKey = create_key(user)
 49              writer.writerow([user, pubKey])
 50  
 51          for user in USERS_WITH_DOUBLE_KEYS:
 52              pubKey = create_key(user)
 53              writer.writerow([user, pubKey])
 54  
 55              pubKey = create_key(user + SECOND_KEY_POSTFIX)
 56              writer.writerow([user, pubKey])
 57      return
 58  
 59  
 60  def create_key(name):
 61      key = rsa.generate_private_key(
 62          backend=crypto_default_backend(),
 63          public_exponent=65537,
 64          key_size=2048
 65      )
 66      privateKey = key.private_bytes(
 67          crypto_serialization.Encoding.PEM,
 68          crypto_serialization.PrivateFormat.PKCS8,
 69          crypto_serialization.NoEncryption()
 70      )
 71  
 72      publicKey = key.public_key().public_bytes(
 73          crypto_serialization.Encoding.OpenSSH,
 74          crypto_serialization.PublicFormat.OpenSSH
 75      )
 76  
 77      with open(os.path.join(TEST_FILES_PATH, name), 'w') as f:
 78          f.write(privateKey.decode())
 79  
 80      with open(os.path.join(TEST_FILES_PATH, name + ".pub"), 'w') as f:
 81          f.write(publicKey.decode())
 82  
 83      return publicKey.decode()
 84  
 85  
 86  def call_generate():
 87      sys.argv = ["", "--output", f"{TEST_FILES_PATH}",
 88                  os.path.join(TEST_FILES_PATH, "test.csv")]
 89  
 90      generate.main()
 91  
 92      with open(os.path.join(TEST_FILES_PATH, "metadata.json"), "r") as f:
 93          return hashlib.sha256(f.read().encode('utf-8')).hexdigest()
 94  
 95  
 96  def call_proof():
 97      sys.argv = ["", os.path.join(TEST_FILES_PATH, "metadata.json")]
 98      proof.main()
 99  
100  
def call_proof_sh(user, address, keyPath):
    """Run the shell implementation ``./proof-sh/proof.sh`` and return stdout.

    The script path is relative to the current working directory. It
    receives the user name, the address, the path of ``metadata.bin`` inside
    ``TEST_FILES_PATH``, and the key path.

    Raises:
        OSError: carrying the ``CompletedProcess`` when the script exits with
            a non-zero status.
    """
    command = [
        "./proof-sh/proof.sh",
        user,
        address,
        os.path.join(TEST_FILES_PATH, "metadata.bin"),
        keyPath,
    ]
    completed = subprocess.run(command, capture_output=True)
    if completed.returncode == 0:
        return completed.stdout.decode()
    raise OSError(completed)
112  
113  
def assert_error(test, fn, err, msg):
    """Assert that ``fn()`` raises ``err`` whose ``str()`` equals ``msg``.

    Exceptions of any other type propagate unchanged. If ``fn()`` returns
    normally, ``test.fail`` is called instead.
    """
    try:
        fn()
    except err as raised:
        test.assertEqual(msg, str(raised))
    else:
        test.fail("Did not raise the exception")
122  
123  
def parse_print_out(address, print_mock):
    """Parse the last line passed to a mocked ``print`` by ``proof.main()``.

    The line is expected to be ``index,tempAddress,signature,proof``, where
    ``proof`` is base64-encoded JSON. ``address`` is not used.

    Returns:
        ``(index, tempAddress, signature, merkleProof)``, with ``index`` as an
        ``int`` and ``merkleProof`` decoded from JSON.
    """
    calls = print_mock.call_args_list
    fields = calls[len(calls) - 1].args[0].split(",")

    index = int(fields[0])
    temp_address, signature = fields[1], fields[2]
    merkle_proof = json.loads(base64.b64decode(fields[3]).decode())

    return index, temp_address, signature, merkle_proof
134  
135  
def parse_sh_out(out):
    """Parse the result line from ``proof.sh`` stdout.

    Splitting on newlines, the second-to-last element is used. That is the
    last line when the output ends with a newline; with a single element,
    the index wraps to that element. The line is expected to be
    ``index,tempAddress,asn1Signature,proof``, where ``proof`` is
    base64-encoded JSON.

    Returns:
        ``(index, tempAddress, asn1Signature, merkleProof)``.
    """
    lines = out.split('\n')
    fields = lines[len(lines) - 2].split(",")

    index = int(fields[0])
    temp_address, asn1_signature = fields[1], fields[2]
    merkle_proof = json.loads(base64.b64decode(fields[3]).decode())

    return index, temp_address, asn1_signature, merkle_proof
146  
147  
def is_valid_asn1_sign(asn1Sign, tempAddress, address):
    """Check a hex-encoded ASN.1 ``(r, s)`` signature against ``tempAddress``.

    The outer sequence is decoded and its two integers are read as ``r`` and
    ``s``. An ``s`` above ``SECP256K1_HALF_N`` is folded to
    ``SECP256K1_N - s``. Ethereum message recovery over the hex string
    ``address`` is then attempted with both recovery bytes 27 and 28,
    because the ASN.1 form carries no recovery id.

    Returns:
        True if either candidate recovers (lower-cased) to ``tempAddress``.
    """
    decoder = asn1.Decoder()
    decoder.start(bytes.fromhex(asn1Sign))
    _, sequence = decoder.read()
    decoder.start(sequence)
    _, r = decoder.read()
    _, s = decoder.read()

    if s > SECP256K1_HALF_N:
        s = SECP256K1_N - s

    rs = r.to_bytes(32, "big") + s.to_bytes(32, "big")

    def recovers_with(recovery_byte):
        signer = w3.eth.account.recover_message(
            encode_defunct(hexstr=address),
            signature=HexBytes(rs + bytes([recovery_byte])),
        )
        return signer.lower() == tempAddress

    # Both candidates are always evaluated, then combined.
    outcomes = [recovers_with(v) for v in (27, 28)]
    return any(outcomes)
172  
173  
def verify_proof(
    test, index, tempAddress, signature, merkleProof, address, metadata
):
    """Assert that a proof printed by ``proof.main()`` is consistent.

    Checks that ``merkleProof`` equals the proof rebuilt from
    ``metadata.addresses`` for ``index``. Also checks that ``signature`` over
    the hex string ``address`` recovers (lower-cased) to ``tempAddress``.
    """
    tree = MerkleTree(metadata.addresses)
    # Named expected_proof so it does not shadow the imported ``proof`` module.
    expected_proof = tree.get_proof(index)
    test.assertEqual(expected_proof, merkleProof)

    signer = w3.eth.account.recover_message(
        encode_defunct(hexstr=address), signature=HexBytes(signature)
    )
    test.assertEqual(signer.lower(), tempAddress)
186  
187  
def verify_proof_by_user(test, user, address, keyFileName, print_mock, metadata):
    """Prove ownership for ``user`` with both implementations and compare them.

    ``proof.main()`` is driven with mocked ``input`` answers (user, address,
    key path). Its last printed line is checked with ``verify_proof``.
    ``proof.sh`` is then run with the same inputs. Its index, temporary
    address and Merkle proof must match the Python result, and its ASN.1
    signature must recover to the same temporary address.
    """
    key_path = os.path.join(TEST_FILES_PATH, keyFileName)

    with mock.patch("builtins.input", side_effect=[user, address, key_path]):
        call_proof()

    py_index, py_temp, py_sig, py_proof = parse_print_out(address, print_mock)
    verify_proof(test, py_index, py_temp, py_sig, py_proof, address, metadata)

    sh_output = call_proof_sh(user, address, key_path)
    sh_index, sh_temp, sh_asn1_sig, sh_proof = parse_sh_out(sh_output)

    test.assertEqual(py_index, sh_index)
    test.assertEqual(py_temp, sh_temp)
    test.assertTrue(is_valid_asn1_sign(sh_asn1_sig, py_temp, address))
    test.assertEqual(py_proof, sh_proof)
222  
223  
@mock.patch('builtins.print')
class Test(unittest.TestCase):
    """End-to-end tests for ``generate`` and ``proof`` (Python and shell).

    The class-level ``builtins.print`` patch is applied outermost, so its mock
    is passed as the LAST positional argument of every test method. Any
    method-level ``mock.patch`` mocks come before it, in bottom-up decorator
    order.
    """

    @mock.patch('subprocess.run', side_effect=[subprocess.CompletedProcess([], 1, "", "age test error")])
    def test_generate_error(self, subprocess_mock, print_mock):
        # A failing subprocess (exit 1, stderr "age test error") must surface
        # as an OSError carrying that stderr text.
        assert_error(self, call_generate, OSError, "age test error")

    def test_generate(self, print_mock):
        # Two generate runs over the same CSV must not yield identical metadata.
        self.assertNotEqual(call_generate(), call_generate())

    @mock.patch('builtins.input', side_effect=["user_999"])
    def test_proof_user_not_found(self, input_mock, print_mock):
        assert_error(self, call_proof, ValueError, "User not found")

    @mock.patch('builtins.input', side_effect=[USERS[0], "invalidTestAddress"])
    def test_proof_invalid_eth_address(self, input_mock, print_mock):
        assert_error(self, call_proof, ValueError, "Invalid Ethereum address")

    # NOTE(review): a side_effect that returns mock.DEFAULT makes the mock
    # return its own return_value (a truthy MagicMock); it does NOT fall back
    # to the real os.path.exists / os.path.isfile for other paths.
    @mock.patch('builtins.input', side_effect=[USERS[0], "0x0000000000000000000000000000000000000006", "key-path"])
    @mock.patch('os.path.exists', side_effect=lambda x: mock.DEFAULT if (x != "key-path") else False)
    def test_proof_file_is_not_exist(self, path_exists_mock, input_mock, print_mock):
        assert_error(self, call_proof, FileNotFoundError, "File is not exist")

    @mock.patch('builtins.input', side_effect=[USERS[0], "0x0000000000000000000000000000000000000006", "key-path"])
    @mock.patch('os.path.exists', side_effect=lambda x: mock.DEFAULT if (x != "key-path") else True)
    @mock.patch('os.path.isfile', side_effect=lambda x: mock.DEFAULT if (x != "key-path") else False)
    def test_proof_file_is_not_ssh_key(self, is_file_mock, path_exists_mock, input_mock, print_mock):
        assert_error(self, call_proof, ValueError, "File is not a SSH key")

    def test_proof(self, print_mock):
        # Reads the metadata.json left by an earlier generate run.
        # NOTE(review): this relies on unittest's alphabetical ordering
        # running test_generate* before test_proof.
        with open(os.path.join(TEST_FILES_PATH, "metadata.json"), "r") as f:
            metadata = Metadata.from_json(f.read())

        # (user, key file name) pairs; double-key users are checked once per key.
        key_files = [(user, user) for user in USERS]
        for user in USERS_WITH_DOUBLE_KEYS:
            key_files.append((user, user))
            key_files.append((user, user + SECOND_KEY_POSTFIX))

        for user, key_file_name in key_files:
            verify_proof_by_user(
                self,
                user,
                secrets.token_bytes(20).hex(),
                key_file_name,
                print_mock,
                metadata,
            )
285  
286  
if __name__ == "__main__":
    # Fixtures (key pairs and the users CSV) are built only here, before
    # unittest runs. Importing this module does not call init().
    init()
    unittest.main()