generate.py
#!/usr/bin/env python
"""Generate per-user Ethereum keys and age-encrypted key bundles.

Reads a CSV of ``name,ssh_public_key`` rows, generates one random private key
per distinct user name, shuffles the derived addresses, builds a Merkle tree
over them and writes two files into the output directory:

* ``metadata.json`` -- built from ``helpers.common.Metadata`` with the Merkle
  root, the shuffled address list and, per user and SSH key, the private key
  encrypted with ``age``;
* ``metadata.bin`` -- CSV rows of ``username,hex`` where ``hex`` is the
  de-armored age ciphertext of ``index,address,der_private_key,proof``.
"""

import argparse
import base64
import csv
import dataclasses
import json
import os
import secrets
import subprocess
# NOTE(review): `st` is never used and this looks like an accidental IDE
# auto-import; importing `turtle` needs tkinter and can fail on headless
# machines -- consider removing it.
from turtle import st

import cryptography.exceptions as exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import serialization as crypto_serialization
from cryptography.hazmat.primitives.asymmetric import ec, ed25519, rsa
from eth_account import Account

from helpers.common import Metadata
from helpers.merkle import MerkleTree

DEFAULT_OUTPUT_DIR = os.path.join(os.getcwd(), "output")
# Historical progress granularity (0.1% steps). Progress is now computed by
# _progress_percent, which does not divide by a truncated fraction, so inputs
# with fewer than DEFAULT_FRACTION users no longer raise ZeroDivisionError.
# Kept as a public module constant.
DEFAULT_FRACTION = 1000

# ASCII-armor markers that age wraps around the base64 ciphertext.
_AGE_ARMOR_BEGIN = "-----BEGIN AGE ENCRYPTED FILE-----"
_AGE_ARMOR_END = "-----END AGE ENCRYPTED FILE-----"


def _progress_percent(done, total):
    """Return *done* out of *total* as a percentage (100.0 when *total* <= 0)."""
    if total <= 0:
        return 100.0
    return 100.0 * done / total


@dataclasses.dataclass
class User:
    """One CSV row: a user name and one of that user's SSH public keys."""

    name: str
    pubKey: str  # OpenSSH public key line exactly as read from the CSV


def validate_key(key: str) -> bool:
    """Return True if *key* is an acceptable OpenSSH public key.

    Accepted: RSA keys of at least 2048 bits, and Ed25519 keys. Keys that
    fail to parse, use an unsupported algorithm, or are any other key type
    are rejected.
    """
    try:
        public_key = serialization.load_ssh_public_key(
            key.encode(), backend=default_backend()
        )
    except (exceptions.UnsupportedAlgorithm, ValueError):
        return False

    if isinstance(public_key, rsa.RSAPublicKey):
        # Reject weak RSA keys.
        return public_key.key_size >= 2048
    return isinstance(public_key, ed25519.Ed25519PublicKey)


def read_csv(filename):
    """Read ``name,ssh_public_key`` rows from *filename*.

    Rows without exactly two columns are counted as skipped; rows whose key
    fails :func:`validate_key` are counted as invalid. Returns a list of
    :class:`User` for the remaining rows, in file order.
    """
    users = []
    valid_count = 0
    skipped_count = 0
    invalid_count = 0

    # newline="" is what the csv module expects, so quoted fields that
    # contain newlines are parsed correctly.
    with open(filename, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        for row in reader:
            if len(row) != 2:
                skipped_count += 1
            elif not validate_key(row[1]):
                invalid_count += 1
            else:
                valid_count += 1
                users.append(User(name=row[0], pubKey=row[1]))

            # Report after handling the row so the counters are accurate.
            print(
                f"Progress: {valid_count} valid, {skipped_count} skipped, {invalid_count} invalid keys",
                end="\r")

    # Finish the "\r" progress line so the totals do not overwrite it.
    print()
    print(f"Total skipped keys: {skipped_count}")
    print(f"Total invalid keys: {invalid_count}")
    return users


def gen_eth_keys(users):
    """Generate one random Ethereum private key per distinct user name.

    Returns ``(addresses, privateKeys)``: the lower-cased addresses, in the
    order their user names first appear, and a ``{username: "0x<hex>"}``
    map. Further rows with an already-seen name share that name's key.
    """
    addresses = []
    privateKeys = {}
    total = len(users)

    for done, user in enumerate(users):
        print(f"Progress: {_progress_percent(done, total):.2f}%", end="\r")
        if user.name in privateKeys:
            continue
        # 32 random bytes from the OS CSPRNG, hex-encoded with a 0x prefix.
        privateKey = "0x" + secrets.token_hex(32)
        privateKeys[user.name] = privateKey
        addresses.append(Account.from_key(privateKey).address.lower())

    print()
    return addresses, privateKeys


def encrypt_data_with_ssh(data, sshPubKey, env=None):
    """Encrypt *data* to *sshPubKey* with the ``age`` CLI and return the armor.

    *env* is passed through to :func:`subprocess.run`; the default ``None``
    lets ``age`` inherit this process's environment.

    Raises OSError if ``age`` cannot be started (e.g. FileNotFoundError when
    it is not on PATH) or exits with a non-zero status.
    """
    # Arguments are passed as a list (no shell), so the key text is never
    # interpreted by a shell.
    result = subprocess.run(
        ["age", "--encrypt", "--recipient", sshPubKey, "-o", "-", "--armor"],
        capture_output=True,
        input=data.encode(),
        env=env,
    )
    if result.returncode != 0:
        raise OSError(result.stderr.decode(errors="replace"))
    return result.stdout.decode()


def random_sort(addresses):
    """Shuffle *addresses* in place, uniformly at random, using ``secrets``.

    Fisher-Yates: position i is swapped with a position drawn from [0, i],
    which makes every permutation equally likely. Drawing the partner from
    the whole list at every step is not uniform.
    """
    for i in range(len(addresses) - 1, 0, -1):
        j = secrets.randbelow(i + 1)
        addresses[i], addresses[j] = addresses[j], addresses[i]


def encrypt_for_standart_output(users, privateKeys):
    """Encrypt each user's private key to each of their SSH keys.

    Returns ``{username: {sshPubKey: armored_ciphertext}}``. Rows whose
    encryption fails are reported and left out; the number of failures is
    printed at the end.
    """
    encryptedKeys: dict[str, dict[str, str]] = {}
    total = len(users)
    totalFailedCount = 0

    for done, user in enumerate(users):
        username = user.name
        sshPubKey = user.pubKey
        privateKey = privateKeys[username]

        print(f"Progress: {_progress_percent(done, total):.2f}%", end="\r")

        try:
            encrypted_data = encrypt_data_with_ssh(privateKey, sshPubKey)
        except OSError as e:
            totalFailedCount += 1
            print(f"Failed to encrypt key for {username} with {sshPubKey}")
            print(e)
            continue

        encryptedKeys.setdefault(username, {})[sshPubKey] = encrypted_data

    print()
    if totalFailedCount:
        print(f"Total failed encryptions: {totalFailedCount}")
    return encryptedKeys


def encrypt_for_sh_output(tree, users, addresses, privateKeys):
    """Build per-user encrypted bundles for the shell-script output.

    For every user row, the plaintext ``index,address,der_key,proof`` is
    encrypted with age to that row's SSH key. The armor markers are then
    stripped and the base64 body is re-encoded as hex. ``index`` is the
    address's position in *addresses*, ``der_key`` is the private key as
    ``0x``-prefixed hex of its traditional-OpenSSL DER encoding, and
    ``proof`` is ``tree.get_proof(index)`` serialized as base64 JSON.

    Returns ``{username: [hex, ...]}``. Every user name gets an entry, which
    is empty if all of its encryptions failed.
    """
    encryptedKeys = {}
    indexes = {address: index for index, address in enumerate(addresses)}
    # Progress is measured against users (not addresses): several rows may
    # share one name and therefore one address.
    total = len(users)

    for done, user in enumerate(users):
        username = user.name
        sshPubKey = user.pubKey
        privateKey = privateKeys[username]
        address = Account.from_key(privateKey).address.lower()

        print(f"Progress: {_progress_percent(done, total):.2f}%", end="\r")

        userKeys = encryptedKeys.setdefault(username, [])

        index = indexes[address]
        proof = base64.b64encode(
            json.dumps(tree.get_proof(index)).encode()
        ).decode()

        # int() accepts the "0x" prefix when base=16.
        key = ec.derive_private_key(int(privateKey, base=16), ec.SECP256K1())
        openSSLPrivKey = "0x" + key.private_bytes(
            crypto_serialization.Encoding.DER,
            crypto_serialization.PrivateFormat.TraditionalOpenSSL,
            crypto_serialization.NoEncryption(),
        ).hex()

        try:
            armored = encrypt_data_with_ssh(
                f"{index},{address},{openSSLPrivKey},{proof}", sshPubKey)
        except OSError as e:
            print(f"Failed to encrypt key for {username} with {sshPubKey}")
            print(e)
            continue

        # Drop the armor markers; b64decode (non-validating) ignores the
        # remaining line breaks.
        body = armored.replace(_AGE_ARMOR_BEGIN, "").replace(_AGE_ARMOR_END, "")
        userKeys.append(base64.b64decode(body).hex())

    print()
    return encryptedKeys


def write_output(filePath, root, addresses, encryptedKeys):
    """Serialize ``Metadata(root, addresses, encryptedKeys)`` to *filePath* as JSON."""
    metadata = Metadata(root=root, addresses=addresses,
                        encryptedKeys=encryptedKeys)

    # ensure_ascii=False may emit non-ASCII user names, so pin UTF-8 instead
    # of relying on the platform default encoding.
    with open(filePath, 'w', encoding='utf-8') as f:
        json.dump(metadata.to_dict(), f, ensure_ascii=False, indent=4)


def write_output_for_sh_script(filePath, encryptedKeys):
    """Write one ``username,hex`` CSV row per encrypted bundle to *filePath*."""
    # newline="" lets csv.writer control line endings (no doubled CR on Windows).
    with open(filePath, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter=",")
        for username, keys in encryptedKeys.items():
            writer.writerows([username, key] for key in keys)


def main():
    """Command-line entry point; see the module docstring for the outputs."""
    parser = argparse.ArgumentParser(
        description="Generate Ethereum keys, a Merkle tree and age-encrypted "
                    "bundles for a CSV of users and SSH public keys.")
    parser.add_argument('input', type=str,
                        help="CSV file with name,ssh_public_key rows")
    parser.add_argument('-o', '--output', type=str, default=DEFAULT_OUTPUT_DIR,
                        help="output directory (created if missing)")

    args = parser.parse_args()
    inputFilePath = args.input
    outputFilePath = args.output

    metadataFilePath = os.path.join(outputFilePath, "metadata.json")
    shOutputFILEPath = os.path.join(outputFilePath, "metadata.bin")

    # makedirs also creates missing parent directories; exist_ok avoids a
    # separate exists-then-create check.
    os.makedirs(outputFilePath, exist_ok=True)

    print(f"Reading from {inputFilePath}")
    users = read_csv(inputFilePath)

    print(f"Generating keys for {len(users)} users")
    addresses, privateKeys = gen_eth_keys(users)

    print("Suffling addresses")
    random_sort(addresses)

    print("Generating Merkle tree")
    tree = MerkleTree(addresses)

    write_output(metadataFilePath, tree.get_root(), addresses,
                 encrypt_for_standart_output(users, privateKeys))
    write_output_for_sh_script(shOutputFILEPath, encrypt_for_sh_output(
        tree, users, addresses, privateKeys))

    print(f"Metadata file written to {metadataFilePath}")


if __name__ == '__main__':
    main()