/ python / generate.py
generate.py
  1  #!/usr/bin/env python
  2  
  3  import base64
  4  import csv
  5  import argparse
  6  import dataclasses
  7  import os
  8  import subprocess
  9  import secrets
 10  import json
 11  from turtle import st
 12  from helpers.merkle import MerkleTree
 13  from eth_account import Account
 14  from helpers.common import Metadata
 15  from cryptography.hazmat.primitives.asymmetric import ec
 16  from cryptography.hazmat.primitives import serialization as crypto_serialization
 17  from cryptography.hazmat.primitives import serialization
 18  from cryptography.hazmat.primitives.asymmetric import rsa, ed25519
 19  from cryptography.hazmat.backends import default_backend
 20  import csv
 21  import cryptography.exceptions as exceptions
 22  
 23  DEFAULT_OUTPUT_DIR = os.path.join(os. getcwd(), "output")
 24  DEFAULT_FRACTION = 1000
 25  
 26  
 27  @dataclasses.dataclass
 28  class User:
 29      name: str
 30      pubKey: str
 31  
 32  
 33  def validate_key(key: str):
 34      try:
 35          public_key = serialization.load_ssh_public_key(
 36              key.encode(),
 37              backend=default_backend()
 38          )
 39      except exceptions.UnsupportedAlgorithm as e:
 40          return False
 41      except ValueError as e:
 42          return False
 43  
 44      if isinstance(public_key, rsa.RSAPublicKey):
 45          if public_key.key_size < 2048:
 46              return False
 47          return True
 48      elif isinstance(public_key, ed25519.Ed25519PublicKey):
 49          return True
 50      else:
 51          return False
 52  
 53  
 54  def read_csv(filename):
 55      users = []
 56      with open(filename, 'r') as csvfile:
 57          reader = csv.reader(csvfile, delimiter=',')
 58          i = 1
 59          totalInvalidCount = 0
 60          totalSkippedCount = 0
 61          for row in reader:
 62              print(
 63                  f"Progress: {i} valid, {totalSkippedCount} skipped, {totalInvalidCount} invalid keys",
 64                  end="\r")
 65  
 66              if len(row) != 2:
 67                  totalSkippedCount += 1
 68                  continue
 69  
 70              is_valid = validate_key(row[1])
 71              if not is_valid:
 72                  totalInvalidCount += 1
 73                  continue
 74  
 75              i += 1
 76              users.append(User(name=row[0], pubKey=row[1]))
 77  
 78      print(f"Total skipped keys: {totalSkippedCount}")
 79      print(f"Total invalid keys: {totalInvalidCount}")
 80      return users
 81  
 82  
 83  def gen_eth_keys(users):
 84      addresses = []
 85      privateKeys = {}
 86  
 87      i = 0
 88      fraction001 = int(len(users) / DEFAULT_FRACTION)
 89      for user in users:
 90          username = user.name
 91          privateKey = privateKeys.get(username)
 92  
 93          print(f"Progress: {(i / fraction001 / 10):.2f}%", end="\r")
 94  
 95          i += 1
 96          if privateKey is not None:
 97              continue
 98  
 99          privateKey = "0x" + secrets.token_hex(32)
100          privateKeys[username] = privateKey
101          addresses.append(Account.from_key(privateKey).address.lower())
102  
103      return addresses, privateKeys
104  
105  
def encrypt_data_with_ssh(data, sshPubKey, env=None):
    """Encrypt *data* to an SSH public key with the ``age`` CLI.

    Runs ``age --encrypt --recipient <sshPubKey> -o - --armor`` with *data*
    on stdin and returns the ASCII-armored ciphertext.

    Args:
        data: Plaintext string; it is encoded with ``str.encode()`` (UTF-8)
            before being piped to age.
        sshPubKey: OpenSSH public key line used as the age recipient.
        env: Optional environment mapping for the ``age`` subprocess.
            ``None`` (the default) inherits the current environment.

    Returns:
        The armored ciphertext written by age to stdout, decoded as text.

    Raises:
        OSError: if age exits with a non-zero status (the exception carries
            age's stderr), or if the ``age`` executable cannot be started.
    """
    result = subprocess.run(
        ["age", "--encrypt", "--recipient", sshPubKey, "-o", "-", "--armor"],
        capture_output=True,
        input=data.encode(),
        # Explicit parameter; the previous module-level ``env`` reference was
        # never defined and raised NameError on every call.
        env=env,
    )
    if result.returncode != 0:
        raise OSError(result.stderr)
    return result.stdout.decode()
120  
121  
def random_sort(addresses):
    """Shuffle *addresses* in place into a uniformly random order.

    Uses the Fisher–Yates algorithm driven by ``secrets.randbelow``, so every
    permutation is equally likely and the randomness comes from the OS CSPRNG.
    Swapping each slot with an arbitrary slot (the previous approach) gives
    n**n equally likely swap sequences. Those cannot map evenly onto the n!
    permutations, so that approach biased the resulting order.

    Args:
        addresses: Mutable sequence to shuffle; it is modified in place.

    Returns:
        None.
    """
    for i in range(len(addresses) - 1, 0, -1):
        # Pick j uniformly from the not-yet-fixed prefix [0, i].
        j = secrets.randbelow(i + 1)
        addresses[i], addresses[j] = addresses[j], addresses[i]
128  
129  
def encrypt_for_standart_output(users, privateKeys):
    """Encrypt each user's private key to every SSH key listed for that user.

    Args:
        users: Sequence of ``User`` records (``.name`` and ``.pubKey`` are read).
        privateKeys: Mapping from user name to ``0x``-prefixed private key;
            must contain every name in *users*.

    Returns:
        ``{username: {sshPubKey: armored_ciphertext}}``. Encryptions that fail
        with ``OSError`` are reported on stdout and left out of the result.

    Raises:
        KeyError: if a user's name is missing from *privateKeys*.
    """
    encryptedKeys: dict[str, dict[str, str]] = {}

    total = len(users)
    totalFailedCount = 0
    for i, user in enumerate(users):
        # Percentage computed from the total so inputs smaller than
        # DEFAULT_FRACTION no longer divide by zero.
        print(f"Progress: {(i * 100 / total):.2f}%", end="\r")

        username = user.name
        sshPubKey = user.pubKey
        privateKey = privateKeys[username]

        try:
            encrypted_data = encrypt_data_with_ssh(privateKey, sshPubKey)
        except OSError as e:
            totalFailedCount += 1
            print(f"Failed to encrypt key for {username} with {sshPubKey}")
            print(e)
            continue

        # Only successful encryptions create the user's entry.
        encryptedKeys.setdefault(username, {})[sshPubKey] = encrypted_data

    # Finish the progress line and surface the failure count, which was
    # previously tracked but never reported.
    print(f"Progress: {100:.2f}%")
    print(f"Total failed encryptions: {totalFailedCount}")
    return encryptedKeys
158  
159  
def encrypt_for_sh_output(tree, users, addresses, privateKeys):
    """Build per-user encrypted payloads for the sh-script output file.

    For every ``(user, SSH key)`` pair, the plaintext
    ``"<index>,<address>,<0x DER private key>,<base64 JSON proof>"`` is
    encrypted with ``age``. The armor markers are removed, and the binary
    ciphertext is stored hex-encoded.

    Args:
        tree: Merkle tree built over *addresses*. ``tree.get_proof(index)``
            must return something ``json.dumps`` can serialise.
        users: Sequence of ``User`` records.
        addresses: The address list the tree was built from.
        privateKeys: Mapping from user name to ``0x``-prefixed hex key.

    Returns:
        ``{username: [hex_ciphertext, ...]}``. Every processed user gets an
        entry; it stays empty if all of that user's encryptions failed.

    Raises:
        KeyError: if a user's name is missing from *privateKeys*, or the
            derived address is not present in *addresses*.
    """
    # Position of each address in the list; this is the index passed to
    # tree.get_proof and embedded in the payload.
    indexes = {address: i for i, address in enumerate(addresses)}

    encryptedKeys = {}
    total = len(users)
    totalFailedCount = 0
    for i, user in enumerate(users):
        # Progress over users (the loop variable), computed from their count.
        # The old divisor used len(addresses), could be 0, and could exceed
        # 100% when names repeat.
        print(f"Progress: {(i * 100 / total):.2f}%", end="\r")

        username = user.name
        sshPubKey = user.pubKey
        privateKey = privateKeys[username]
        address = Account.from_key(privateKey).address.lower()

        userKeys = encryptedKeys.setdefault(username, [])

        index = indexes[address]
        proof = base64.b64encode(
            json.dumps(tree.get_proof(index)).encode()
        ).decode()

        # Re-encode the hex key as a DER (TraditionalOpenSSL) secp256k1
        # private key, hex-encoded with a 0x prefix.
        key = ec.derive_private_key(int(privateKey, base=16), ec.SECP256K1())
        openSSLPrivKey = "0x" + key.private_bytes(
            crypto_serialization.Encoding.DER,
            crypto_serialization.PrivateFormat.TraditionalOpenSSL,
            crypto_serialization.NoEncryption()
        ).hex()

        try:
            armored = encrypt_data_with_ssh(
                f"{index},{address},{openSSLPrivKey},{proof}", sshPubKey)
        except OSError as e:
            totalFailedCount += 1
            print(f"Failed to encrypt key for {username} with {sshPubKey}")
            print(e)
            continue

        # Drop the armor markers. The non-validating b64decode skips the
        # remaining line breaks and yields the raw age ciphertext.
        body = armored \
            .replace("-----BEGIN AGE ENCRYPTED FILE-----", "") \
            .replace("-----END AGE ENCRYPTED FILE-----", "")
        userKeys.append(base64.b64decode(body).hex())

    print(f"Progress: {100:.2f}%")
    print(f"Total failed encryptions: {totalFailedCount}")
    return encryptedKeys
206  
207  
def write_output(filePath, root, addresses, encryptedKeys):
    """Serialise the Merkle root, address list and encrypted keys as JSON.

    Args:
        filePath: Destination path; the file is created or truncated.
        root: Merkle root value to store (from ``tree.get_root()`` in ``main``).
        addresses: The shuffled address list.
        encryptedKeys: Per-user encrypted key material.
    """
    metadata = Metadata(root=root, addresses=addresses,
                        encryptedKeys=encryptedKeys)

    # ensure_ascii=False writes non-ASCII characters (e.g. in user names)
    # verbatim, so pin the encoding instead of relying on the locale default.
    with open(filePath, 'w', encoding='utf-8') as f:
        json.dump(metadata.to_dict(), f, ensure_ascii=False, indent=4)
214  
215  
def write_output_for_sh_script(filePath, encryptedKeys):
    """Write one ``username,encrypted_key`` CSV row per encrypted key.

    Args:
        filePath: Destination path; the file is created or truncated.
        encryptedKeys: Mapping from user name to a list of encoded keys.
            Users with an empty list produce no rows.
    """
    # newline='' lets csv.writer emit its own "\r\n" terminators untouched;
    # without it, text-mode translation yields "\r\r\n" on Windows.
    with open(filePath, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter=",")
        for username, keys in encryptedKeys.items():
            writer.writerows([username, key] for key in keys)
222  
223  
def main():
    """Command-line entry point.

    Steps:
    1. Read ``name,ssh-public-key`` rows from the input CSV.
    2. Generate one Ethereum key per user name.
    3. Shuffle the addresses and build a Merkle tree over them.
    4. Write ``metadata.json`` and ``metadata.bin`` into the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str,
                        help="CSV file with name,ssh-public-key rows")
    parser.add_argument('-o', '--output', type=str, default=DEFAULT_OUTPUT_DIR,
                        help="output directory (default: %(default)s)")

    args = parser.parse_args()
    inputFilePath = args.input
    outputFilePath = args.output

    metadataFilePath = os.path.join(outputFilePath, "metadata.json")
    shOutputFILEPath = os.path.join(outputFilePath, "metadata.bin")

    # Create the directory and any missing parents; no error if it exists.
    os.makedirs(outputFilePath, exist_ok=True)

    print(f"Reading from {inputFilePath}")
    users = read_csv(inputFilePath)

    print(f"Generating keys for {len(users)} users")
    addresses, privateKeys = gen_eth_keys(users)

    print("Shuffling addresses")
    random_sort(addresses)

    print("Generating Merkle tree")
    tree = MerkleTree(addresses)

    write_output(metadataFilePath, tree.get_root(), addresses,
                 encrypt_for_standart_output(users, privateKeys))
    write_output_for_sh_script(shOutputFILEPath, encrypt_for_sh_output(
        tree, users, addresses, privateKeys))

    print(f"Metadata file written to {metadataFilePath}")
    print(f"Shell-script data written to {shOutputFILEPath}")
256      print(f"Metadata file written to {metadataFilePath}")
257  
258  
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()