# _test_.py
import base64
import csv
import hashlib
import json
import os
import secrets
import shutil
import subprocess
import sys
import unittest
from unittest import mock

import asn1
from cryptography.hazmat.backends import default_backend as crypto_default_backend
from cryptography.hazmat.primitives import serialization as crypto_serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from eth_account.messages import encode_defunct
from hexbytes import HexBytes
from web3.auto import w3

import python.generate as generate
import python.proof as proof
from helpers.common import Metadata
from helpers.merkle import MerkleTree

TEST_FILES_PATH = os.path.join(os.getcwd(), "__test_cache__")
CSV_TEST_PATH = os.path.join(TEST_FILES_PATH, "test.csv")
USERS = ["user_1", "user_2", "user_3"]
USERS_WITH_DOUBLE_KEYS = ["user_4"]
SECOND_KEY_POSTFIX = "_second"

# Order n of the secp256k1 curve; EIP-2 treats s > n/2 as non-canonical.
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-2.md
SECP256K1_N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
# Integer floor division on purpose: `SECP256K1_N / 2` goes through a 53-bit
# float and rounds to 2**255, which is not n/2.
SECP256K1_HALF_N = SECP256K1_N // 2

# Unpatched references, so tests that stub a single path can still answer
# truthfully for every other path (returning mock.DEFAULT from a side_effect
# yields a truthy MagicMock, not the real result).
_REAL_PATH_EXISTS = os.path.exists
_REAL_PATH_ISFILE = os.path.isfile


def init():
    """Recreate the test cache directory with fresh keys and the users CSV.

    Every name in USERS gets one key pair. Every name in
    USERS_WITH_DOUBLE_KEYS gets two; the second is stored under
    ``<user> + SECOND_KEY_POSTFIX``. Each key produces one
    ``user,public_key`` row in CSV_TEST_PATH.
    """
    if os.path.exists(TEST_FILES_PATH):
        shutil.rmtree(TEST_FILES_PATH)
    os.mkdir(TEST_FILES_PATH)

    # newline="" is what the csv module requires; without it the writer's
    # "\r\n" terminator becomes "\r\r\n" on Windows.
    with open(CSV_TEST_PATH, "w", newline="") as f:
        writer = csv.writer(f, delimiter=",")
        for user in USERS:
            writer.writerow([user, create_key(user)])
        for user in USERS_WITH_DOUBLE_KEYS:
            writer.writerow([user, create_key(user)])
            writer.writerow([user, create_key(user + SECOND_KEY_POSTFIX)])


def create_key(name):
    """Generate a 2048-bit RSA key pair and store it in the test cache.

    The private key is written unencrypted (PEM, PKCS8) to ``<name>``. The
    public key is written in OpenSSH format to ``<name>.pub``.

    Returns:
        The OpenSSH public key as a str.
    """
    key = rsa.generate_private_key(
        backend=crypto_default_backend(),
        public_exponent=65537,
        key_size=2048,
    )
    private_key = key.private_bytes(
        crypto_serialization.Encoding.PEM,
        crypto_serialization.PrivateFormat.PKCS8,
        crypto_serialization.NoEncryption(),
    ).decode()
    public_key = key.public_key().public_bytes(
        crypto_serialization.Encoding.OpenSSH,
        crypto_serialization.PublicFormat.OpenSSH,
    ).decode()

    with open(os.path.join(TEST_FILES_PATH, name), "w") as f:
        f.write(private_key)
    with open(os.path.join(TEST_FILES_PATH, name + ".pub"), "w") as f:
        f.write(public_key)

    return public_key


def call_generate():
    """Run ``generate.main`` on the test CSV and fingerprint its output.

    ``sys.argv`` is patched only for the duration of the call, so the fake
    command line does not leak into later tests.

    Returns:
        Hex SHA-256 digest of the produced ``metadata.json``.
    """
    argv = ["", "--output", TEST_FILES_PATH, CSV_TEST_PATH]
    with mock.patch.object(sys, "argv", argv):
        generate.main()

    with open(os.path.join(TEST_FILES_PATH, "metadata.json"), "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


def call_proof():
    """Run ``proof.main`` against the generated ``metadata.json``.

    Callers supply the interactive answers by patching ``builtins.input``.
    """
    argv = ["", os.path.join(TEST_FILES_PATH, "metadata.json")]
    with mock.patch.object(sys, "argv", argv):
        proof.main()


def call_proof_sh(user, address, keyPath):
    """Run the shell implementation ``./proof-sh/proof.sh``.

    The script path is relative. The tests therefore assume the working
    directory contains ``proof-sh/``; TEST_FILES_PATH also derives from
    the working directory.

    Returns:
        The script's decoded stdout.

    Raises:
        OSError: wrapping the CompletedProcess when the script exits non-zero.
    """
    result = subprocess.run(
        [
            "./proof-sh/proof.sh",
            user,
            address,
            os.path.join(TEST_FILES_PATH, "metadata.bin"),
            keyPath,
        ],
        capture_output=True,
    )
    if result.returncode != 0:
        raise OSError(result)
    return result.stdout.decode()


def assert_error(test, fn, err, msg):
    """Assert that ``fn()`` raises ``err`` whose ``str()`` equals ``msg``."""
    with test.assertRaises(err, msg="Did not raise the exception") as ctx:
        fn()
    test.assertEqual(msg, str(ctx.exception))


def _parse_proof_line(line):
    """Split an ``index,tempAddress,signature,base64(json proof)`` line.

    Returns:
        (index as int, temp address, signature, decoded Merkle proof).
    """
    fields = line.split(",")
    merkle_proof = json.loads(base64.b64decode(fields[3]).decode())
    return int(fields[0]), fields[1], fields[2], merkle_proof


def parse_print_out(address, print_mock):
    """Parse the result that ``proof.main`` passed to its last ``print`` call.

    ``address`` is unused; it is kept so existing call sites keep working.
    """
    return _parse_proof_line(print_mock.call_args_list[-1].args[0])


def parse_sh_out(out):
    """Parse the last line of ``proof.sh`` stdout (same format as print)."""
    # Drop the trailing newline(s) so the result line is always the last one.
    return _parse_proof_line(out.rstrip("\n").split("\n")[-1])


def is_valid_asn1_sign(asn1Sign, tempAddress, address):
    """Check an ASN.1-encoded ECDSA signature emitted by ``proof.sh``.

    The hex string is decoded as a SEQUENCE of (r, s). s is normalised to
    the low half of the curve order per EIP-2, and the pair is turned into
    the 65-byte r||s||v form web3 expects. The ASN.1 form carries no
    recovery id, so both v=27 and v=28 are tried.

    Returns:
        True if either recovery, lower-cased, equals ``tempAddress`` for
        the message ``address``.
    """
    decoder = asn1.Decoder()
    decoder.start(bytes.fromhex(asn1Sign))
    _, sequence = decoder.read()
    decoder.start(sequence)
    _, r = decoder.read()
    _, s = decoder.read()

    if s > SECP256K1_HALF_N:
        s = SECP256K1_N - s

    rs = r.to_bytes(32, "big") + s.to_bytes(32, "big")
    message = encode_defunct(hexstr=address)
    return any(
        w3.eth.account.recover_message(
            message, signature=HexBytes(rs + bytes([v]))
        ).lower() == tempAddress
        for v in (27, 28)
    )


def verify_proof(test, index, tempAddress, signature, merkleProof, address, metadata):
    """Assert that a proof printed by ``proof.main`` is correct.

    Checks two things. First, ``merkleProof`` must equal the proof that
    MerkleTree recomputes from ``metadata.addresses`` at ``index``. Second,
    ``signature`` over ``address`` must recover to ``tempAddress``.
    """
    expected_proof = MerkleTree(metadata.addresses).get_proof(index)
    test.assertEqual(expected_proof, merkleProof)

    recovered = w3.eth.account.recover_message(
        encode_defunct(hexstr=address), signature=HexBytes(signature)
    )
    test.assertEqual(recovered.lower(), tempAddress)


def verify_proof_by_user(test, user, address, keyFileName, print_mock, metadata):
    """Prove ``user``'s ownership with both implementations and compare them.

    ``proof.main`` is run with user, address and key path fed through a
    patched ``input``, and its output is verified. ``proof.sh`` is then run
    with the same key. The two must agree on index, temporary address and
    Merkle proof, and the shell's ASN.1 signature must be valid for the
    same temporary address.
    """
    key_path = os.path.join(TEST_FILES_PATH, keyFileName)
    with mock.patch("builtins.input", side_effect=[user, address, key_path]):
        call_proof()

    py_index, py_temp_address, py_signature, py_merkle_proof = parse_print_out(
        address, print_mock
    )
    verify_proof(
        test,
        py_index,
        py_temp_address,
        py_signature,
        py_merkle_proof,
        address,
        metadata,
    )

    sh_index, sh_temp_address, sh_signature, sh_merkle_proof = parse_sh_out(
        call_proof_sh(user, address, key_path)
    )
    test.assertEqual(py_index, sh_index)
    test.assertEqual(py_temp_address, sh_temp_address)
    test.assertTrue(is_valid_asn1_sign(sh_signature, py_temp_address, address))
    test.assertEqual(py_merkle_proof, sh_merkle_proof)


def setUpModule():
    """Build the key/CSV fixtures once, before any test in this module runs.

    Unlike calling init() only under ``__main__``, this also works under
    ``python -m unittest`` and pytest.
    """
    init()


@mock.patch("builtins.print")
class Test(unittest.TestCase):
    # Patch decorators inject mocks bottom-up, and the class-level print
    # mock is applied last, so it is always the final argument.

    @mock.patch(
        "subprocess.run",
        side_effect=[subprocess.CompletedProcess([], 1, "", "age test error")],
    )
    def test_generate_error(self, subprocess_mock, print_mock):
        assert_error(self, call_generate, OSError, "age test error")

    def test_generate(self, print_mock):
        # Two runs over the same input must not produce identical metadata
        # (presumably the generator is randomised).
        self.assertNotEqual(call_generate(), call_generate())

    @mock.patch("builtins.input", side_effect=["user_999"])
    def test_proof_user_not_found(self, input_mock, print_mock):
        assert_error(self, call_proof, ValueError, "User not found")

    @mock.patch("builtins.input", side_effect=[USERS[0], "invalidTestAddress"])
    def test_proof_invalid_eth_address(self, input_mock, print_mock):
        assert_error(self, call_proof, ValueError, "Invalid Ethereum address")

    @mock.patch(
        "builtins.input",
        side_effect=[USERS[0], "0x0000000000000000000000000000000000000006", "key-path"],
    )
    @mock.patch(
        "os.path.exists",
        side_effect=lambda p: False if p == "key-path" else _REAL_PATH_EXISTS(p),
    )
    def test_proof_file_is_not_exist(self, path_exists_mock, input_mock, print_mock):
        assert_error(self, call_proof, FileNotFoundError, "File is not exist")

    @mock.patch(
        "builtins.input",
        side_effect=[USERS[0], "0x0000000000000000000000000000000000000006", "key-path"],
    )
    @mock.patch(
        "os.path.exists",
        side_effect=lambda p: True if p == "key-path" else _REAL_PATH_EXISTS(p),
    )
    @mock.patch(
        "os.path.isfile",
        side_effect=lambda p: False if p == "key-path" else _REAL_PATH_ISFILE(p),
    )
    def test_proof_file_is_not_ssh_key(
        self, isfile_mock, path_exists_mock, input_mock, print_mock
    ):
        assert_error(self, call_proof, ValueError, "File is not a SSH key")

    def test_proof(self, print_mock):
        # NOTE: reads the metadata.json left by an earlier test_generate run.
        # unittest orders methods alphabetically, so that test runs first.
        with open(os.path.join(TEST_FILES_PATH, "metadata.json"), "r") as f:
            metadata = Metadata.from_json(f.read())

        def check(user, key_file_name):
            # A fresh random 20-byte hex address for every proof.
            verify_proof_by_user(
                self,
                user,
                secrets.token_bytes(20).hex(),
                key_file_name,
                print_mock,
                metadata,
            )

        for user in USERS:
            check(user, user)
        for user in USERS_WITH_DOUBLE_KEYS:
            check(user, user)
            check(user, user + SECOND_KEY_POSTFIX)


if __name__ == "__main__":
    unittest.main()