/ bin / zkrunner / zkrunner.py
zkrunner.py
  1  #!/usr/bin/env python3
  2  # This file is part of DarkFi (https://dark.fi)
  3  #
  4  # Copyright (C) 2020-2025 Dyne.org foundation
  5  #
  6  # This program is free software: you can redistribute it and/or modify
  7  # it under the terms of the GNU Affero General Public License as
  8  # published by the Free Software Foundation, either version 3 of the
  9  # License, or (at your option) any later version.
 10  #
 11  # This program is distributed in the hope that it will be useful,
 12  # but WITHOUT ANY WARRANTY; without even the implied warranty of
 13  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 14  # GNU Affero General Public License for more details.
 15  #
 16  # You should have received a copy of the GNU Affero General Public License
 17  # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 18  """
 19  Python tool to prototype zkVM proofs given zkas source code and necessary
 20  witness values in JSON format.
 21  """
 22  import json
 23  import sys
 24  from darkfi_sdk.pasta import Fp, Fq, Ep
 25  from darkfi_sdk.zkas import (MockProver, ZkBinary, ZkCircuit, ProvingKey,
 26                               Proof, VerifyingKey)
 27  
 28  def eprint(fstr, *args):
 29      print("error: " + fstr, *args, file=sys.stderr)
 30  
 31  def show_trace(opcodes, trace):
 32      print(f"{'Line':<4} {'Opcode':<22} {'Type':<10} {'Values'}")
 33      for i, (opcode, (optype, args)) in enumerate(zip(opcodes, trace)):
 34          if args:
 35              args = ", ".join([str(arg) for arg in args])
 36              args = f"[{args}]"
 37          else:
 38              args = ""
 39          opcode = str(opcode)
 40          optype = str(optype)
 41          print(f"{i:<4} {opcode:<22} {optype:<10} {args}")
 42  
 43  def load_circuit_witness(circuit, witness_file):
 44      # We attempt to decode the witnesses from the JSON file.
 45      # Refer to the `witness_gen.py` file to see what the format of this
 46      # file should be.
 47      if witness_file == "-":
 48          witness_data = json.load(sys.stdin)
 49      else:
 50          with open(witness_file, "r", encoding="utf-8") as json_file:
 51              witness_data = json.load(json_file)
 52  
 53      # Now we scan through the parsed JSON witness file and
 54      # build our "heap". These will be appended to the initial
 55      # circuit and decide the code path for the prover.
 56      for witness in witness_data["witnesses"]:
 57          assert len(witness) == 1
 58          if (value := witness.get("EcPoint")) is not None:
 59              circuit.witness_ecpoint(Ep(value))
 60  
 61          elif (value := witness.get("EcNiPoint")) is not None:
 62              assert len(value) == 2
 63              xcoord, ycoord = Fp(value[0]), Fp(value[1])
 64              circuit.witness_ecnipoint(Ep(xcoord, ycoord))
 65  
 66          elif (value := witness.get("Base")) is not None:
 67              circuit.witness_base(Fp(value))
 68  
 69          elif (value := witness.get("Scalar")) is not None:
 70              circuit.witness_scalar(Fq(value))
 71  
 72          elif (value := witness.get("MerklePath")) is not None:
 73              path = [Fp(i) for i in value]
 74              assert len(path) == 32
 75              circuit.witness_merklepath(path)
 76  
 77          elif (value := witness.get("SparseMerklePath")) is not None:
 78              path = [Fp(i) for i in value]
 79              assert len(path) == 255
 80              circuit.witness_sparsemerklepath(path)
 81  
 82          elif (value := witness.get("Uint32")) is not None:
 83              print("here")
 84              circuit.witness_uint32(value)
 85  
 86          elif (value := witness.get("Uint64")) is not None:
 87              circuit.witness_uint64(value)
 88  
 89          else:
 90              eprint(f"Invalid Witness type for witness {witness}")
 91              return -1
 92  
 93      # Instances are our public inputs for the proof and they're also
 94      # part of the JSON file.
 95      instances = []
 96      for instance in witness_data["instances"]:
 97          instances.append(Fp(instance))
 98      return instances
 99  
def main(witness_file, source_file, mock=False, trace=False):
    """main zkrunner logic: compile a zkas circuit, load witnesses, prove it.

    Args:
        witness_file: path to the JSON witness file, or ``"-"`` for stdin.
        source_file: path to the zkas source code.
        mock: run the ``MockProver`` instead of building real keys and a
            real proof.
        trace: print the per-opcode debug trace; only supported when
            ``mock`` is false.

    Returns:
        0 on success, -1 if the witnesses could not be decoded, -2 if
        ``trace`` was requested together with ``mock``, -3 if proof
        creation or verification failed.
    """
    # Reject the unsupported option combination before doing any
    # (potentially slow) compilation or key generation.
    if trace and mock:
        eprint("Debug trace can only be enabled with --prove")
        return -2

    # Compile the given zkas code and create a zkVM circuit. This compiling
    # logic happens in the Python bindings' `ZkBinary::new` function, and
    # should be equivalent to the actual `zkas` binary provided in the
    # DarkFi codebase.
    print("Compiling zkas code...")
    with open(source_file, "r", encoding="utf-8") as zkas_file:
        zkas_source = zkas_file.read()

    zkbin = ZkBinary(source_file, zkas_source)

    # Construct the initial circuit object.
    circuit = ZkCircuit(zkbin)
    print("Decoding witnesses...")
    instances = load_circuit_witness(circuit, witness_file)
    # load_circuit_witness reports an unknown witness type by printing an
    # error and returning -1; stop here instead of proving with it.
    if instances == -1:
        return -1

    # If we want to build an actual proof, we'll need a proving key and a
    # verifying key. circuit.verifier_build() is called so that the initial
    # circuit (which contains no witnesses) actually calls
    # empty_witnesses() in order to have the correct code path when the
    # circuit gets synthesized.
    if not mock:
        print("Building proving key...")
        proving_key = ProvingKey.build(zkbin.k(), circuit.verifier_build())

        print("Building verifying key...")
        verifying_key = VerifyingKey.build(zkbin.k(), circuit.verifier_build())

    # circuit.prover_build() actually constructs the circuit with the
    # values witnessed above.
    circuit = circuit.prover_build()
    if trace:
        circuit.enable_trace()

    if not mock:
        # Real proof: prove with the ProvingKey, verify with the
        # VerifyingKey.
        print("Proving knowledge of witnesses...")
        proof = Proof.create(proving_key, [circuit], instances)
        if proof is None:
            eprint("Proof creation failed")
            return -3

        if trace:
            show_trace(zkbin.opcodes(), circuit.opvalues())

        print("Verifying ZK proof...")
        verify_status = proof.verify(verifying_key, instances)

    else:
        # Otherwise, simply run the MockProver.
        print("Running MockProver...")
        proof = MockProver.run(zkbin.k(), circuit, instances)

        print("Verifying MockProver...")
        verify_status = proof.verify()

    if not verify_status:
        eprint("Proof failed to verify")
        return -3

    print("Proof verified successfully!")
    return 0
169  
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(
        prog="zkrunner",
        description="Python util for running zk proofs",
        epilog="This tool is only for prototyping purposes",
    )

    parser.add_argument(
        "SOURCE",
        help="Path to zkas source code",
    )
    # load_circuit_witness treats "-" as stdin, so advertise it here.
    parser.add_argument(
        "-w",
        "--witness",
        required=True,
        help="Path to JSON file holding witnesses ('-' reads from stdin)",
    )
    parser.add_argument(
        "--prove",
        action="store_true",
        help="Actually create a real proof instead of using MockProver",
    )
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Enable debug trace (only works with --prove enabled)",
    )

    args = parser.parse_args()
    # main() returns 0 on success and a negative code on failure.
    sys.exit(main(args.witness, args.SOURCE, mock=not args.prove,
                  trace=args.trace))
203