/ benchmarks / run.jl
run.jl
  1  #!/usr/bin/env julia
  2  
  3  using Pkg
  4  Pkg.activate(@__DIR__)
  5  
  6  using BenchmarkTools
  7  using Random
  8  using JSON
  9  using Dates
 10  using GrothAlgebra
 11  using GrothCurves
 12  using GrothProofs
 13  
# Short aliases for names provided by the Groth* packages loaded above.
# `r` is the exclusive upper bound for the scalars sampled by `rand_scalar*`
# (named `BN254_ORDER_R`, presumably the BN254 prime group order — confirm).
# Binding `r` does not interfere with regex literals such as r"..." (those are `@r_str`).
const Fr = BN254ScalarField
const r = BN254_ORDER_R
# Pairing engine handed to `pairing`, `pairing_batch`, `miller_loop`, `final_exponentiation`.
const ENGINE = BN254_ENGINE
 17  
 18  # -----------------------------------------------------------------------------
 19  # Sampling helpers
 20  # -----------------------------------------------------------------------------
 21  
 22  function rand_scalar(R::AbstractRNG)
 23      # Sample in [0, r)
 24      return BigInt(rand(R, 0:r-1))
 25  end
 26  
 27  function rand_scalar_nonzero(R::AbstractRNG)
 28      # Sample in [1, r-1] to avoid the point at infinity in pairings
 29      return BigInt(rand(R, 1:r-1))
 30  end
 31  
 32  function genscalars(R::AbstractRNG, N::Int)
 33      return [rand_scalar(R) for _ in 1:N]
 34  end
 35  
 36  function genpoints_fixed(g, scalars::Vector{BigInt})
 37      # Naive scalar multiplication loop
 38      return [scalar_mul(g, s) for s in scalars]
 39  end
 40  
 41  # -----------------------------------------------------------------------------
 42  # Trial helpers for consistent JSON output
 43  # -----------------------------------------------------------------------------
 44  
 45  function estimate_to_seconds(est)
 46      # Work with the pretty string so units match printed output
 47      s = replace(string(est), "\u03bc" => "u")  # normalize micro symbol
 48      m = match(r"TrialEstimate\(([0-9\.eE+-]+) ([a-z]+)\)", s)
 49      m === nothing && error("Unable to parse TrialEstimate string: $s")
 50      value = parse(Float64, m.captures[1])
 51      unit = m.captures[2]
 52      factor = unit == "s"  ? 1.0 :
 53               unit == "ms" ? 1e-3 :
 54               unit == "us" ? 1e-6 :
 55               unit == "ns" ? 1e-9 :
 56               error("Unsupported time unit $unit in TrialEstimate")
 57      return value * factor
 58  end
 59  
 60  function trial_summary(tr::BenchmarkTools.Trial)
 61      min_est = minimum(tr)
 62      med_est = median(tr)
 63      Dict(
 64          "min_pretty" => string(min_est),
 65          "median_pretty" => string(med_est),
 66          "min_seconds" => estimate_to_seconds(min_est),
 67          "median_seconds" => estimate_to_seconds(med_est),
 68          "memory_bytes" => memory(tr),
 69      )
 70  end
 71  
 72  function record_result!(results::Dict{Symbol,Any}, group::Symbol, nkey::String, label::String, tr::BenchmarkTools.Trial)
 73      group_dict = get!(results, group, Dict{String,Any}())
 74      entry = get!(group_dict, nkey, Dict{String,Any}())
 75      entry[label] = trial_summary(tr)
 76  end
 77  
 78  function record_simple!(results::Dict{Symbol,Any}, group::Symbol, label::String, tr::BenchmarkTools.Trial)
 79      group_dict = get!(results, group, Dict{String,Any}())
 80      group_dict[label] = trial_summary(tr)
 81  end
 82  
 83  function print_stats(label, tr::BenchmarkTools.Trial)
 84      tmin = minimum(tr)
 85      tmed = median(tr)
 86      tmean = mean(tr)
 87      bytes = memory(tr)
 88      println(rpad(label, 40), " min=$(tmin) med=$(tmed) mean=$(tmean) mem=$(bytes)B")
 89  end
 90  
 91  # -----------------------------------------------------------------------------
 92  # Benchmark families
 93  # -----------------------------------------------------------------------------
 94  
 95  function bench_fixed_base(results)
 96      println("\n== Fixed-base precompute: table build and batch_mul (G1) ==")
 97      g1 = g1_generator()
 98      rng = MersenneTwister(42)
 99      for N in (32, 128, 512)
100          scalars = genscalars(rng, N)
101          println("N = ", N)
102          # Warmup
103          _ = build_fixed_table(g1)
104          _ = genpoints_fixed(g1, scalars)
105          # Separate table build vs batch_mul
106          tr_build = @benchmark build_fixed_table($g1) seconds=1 samples=10
107          tab = build_fixed_table(g1)
108          tr_naive = @benchmark genpoints_fixed($g1, $scalars) seconds=1 samples=10
109          tr_batch = @benchmark batch_mul($tab, $scalars) seconds=1 samples=10
110          print_stats("G1 build", tr_build)
111          print_stats("G1 naive", tr_naive)
112          print_stats("G1 batch", tr_batch)
113          record_result!(results, :fixed_g1, string(N), "build", tr_build)
114          record_result!(results, :fixed_g1, string(N), "naive", tr_naive)
115          record_result!(results, :fixed_g1, string(N), "batch", tr_batch)
116      end
117  
118      println("\n== Fixed-base precompute: table build and batch_mul (G2) ==")
119      g2 = g2_generator()
120      rng = MersenneTwister(123)
121      for N in (32, 128, 512)
122          scalars = genscalars(rng, N)
123          println("N = ", N)
124          _ = build_fixed_table(g2)
125          _ = genpoints_fixed(g2, scalars)
126          tr_build = @benchmark build_fixed_table($g2) seconds=1 samples=10
127          tab = build_fixed_table(g2)
128          tr_naive = @benchmark genpoints_fixed($g2, $scalars) seconds=1 samples=10
129          tr_batch = @benchmark batch_mul($tab, $scalars) seconds=1 samples=10
130          print_stats("G2 build", tr_build)
131          print_stats("G2 naive", tr_naive)
132          print_stats("G2 batch", tr_batch)
133          record_result!(results, :fixed_g2, string(N), "build", tr_build)
134          record_result!(results, :fixed_g2, string(N), "naive", tr_naive)
135          record_result!(results, :fixed_g2, string(N), "batch", tr_batch)
136      end
137  end
138  
"""
    naive_msm(bases, scalars)

Reference multi-scalar multiplication: the sum of `scalar_mul(bases[i], scalars[i])` over
all `i`, skipping zero scalars. This is the baseline that `multi_scalar_mul` is compared
against.

# Throws
- `ArgumentError` if `bases` is empty (there is no element to take `zero` of).
- `DimensionMismatch` if `bases` and `scalars` do not have the same indices.
"""
function naive_msm(bases, scalars)
    isempty(bases) && throw(ArgumentError("naive_msm needs at least one base point"))
    acc = zero(first(bases))
    # `eachindex(bases, scalars)` verifies both collections share the same indices, which
    # is what makes the `@inbounds` access to `scalars[i]` safe.
    @inbounds for i in eachindex(bases, scalars)
        s = scalars[i]
        iszero(s) && continue
        acc += scalar_mul(bases[i], s)
    end
    return acc
end
148  
"""
    gen_random_bases_g1(rng::AbstractRNG, N::Int)

Return `N` G1 points `scalar_mul(g1_generator(), s)`, one for each scalar `s` drawn by
`genscalars(rng, N)`.
"""
function gen_random_bases_g1(rng::AbstractRNG, N::Int)
    return genpoints_fixed(g1_generator(), genscalars(rng, N))
end
154  
"""
    gen_random_bases_g2(rng::AbstractRNG, N::Int)

Return `N` G2 points `scalar_mul(g2_generator(), s)`, one for each scalar `s` drawn by
`genscalars(rng, N)`.
"""
function gen_random_bases_g2(rng::AbstractRNG, N::Int)
    return genpoints_fixed(g2_generator(), genscalars(rng, N))
end
160  
"""
    bench_variable_msm(results)

Compare `naive_msm` with `GrothAlgebra.multi_scalar_mul` on random G1 and G2 bases for
N = 32, 128, 512. Trials are recorded under `:msm_g1` / `:msm_g2` with labels "naive" and
"msm".
"""
function bench_variable_msm(results)
    _bench_msm_group!(results,
                      "\n== Variable-base MSM: multi_scalar_mul vs naive (G1) ==",
                      :msm_g1, "G1 naive", "G1 MSM", gen_random_bases_g1,
                      MersenneTwister(7))
    _bench_msm_group!(results,
                      "\n== Variable-base MSM: multi_scalar_mul vs naive (G2) ==",
                      :msm_g2, "G2 naive", "G2 MSM", gen_random_bases_g2,
                      MersenneTwister(9))
    return nothing
end

# Print `header`, then time naive vs library MSM for one group at each N, recording the
# trials under `results[group]`. `genbases(rng, N)` produces the random base points.
function _bench_msm_group!(results, header, group, naive_label, msm_label, genbases, rng)
    println(header)
    for N in (32, 128, 512)
        # Bases consume the RNG before the scalars do.
        bases = genbases(rng, N)
        scalars = genscalars(rng, N)
        println("N = ", N)
        _ = naive_msm(bases, scalars)
        _ = GrothAlgebra.multi_scalar_mul(bases, scalars)
        tr_naive = @benchmark naive_msm($bases, $scalars) seconds=1 samples=10
        tr_msm = @benchmark GrothAlgebra.multi_scalar_mul($bases, $scalars) seconds=1 samples=10
        print_stats(naive_label, tr_naive)
        print_stats(msm_label, tr_msm)
        nkey = string(N)
        record_result!(results, group, nkey, "naive", tr_naive)
        record_result!(results, group, nkey, "msm", tr_msm)
    end
    return nothing
end
194  
"""
    bench_batch_norm(results)

Compare `GrothCurves.batch_to_affine!` against a per-point loop for G1 and G2 at
N = 32, 128, 512. The per-point loop calls `to_affine` on each point and rebuilds it as
`G1Point(x, y, one(BN254Field))` or `G2Point(x, y, one(Fp2Element))`. Trials are recorded
under `:norm_g1` / `:norm_g2` with labels "batch" and "each".

The G1 and G2 halves are spelled out separately on purpose: the timed per-point
expressions use group-specific constructors, and sharing them through a helper would
change the code being measured.
"""
function bench_batch_norm(results)
    println("\n== Batch normalization: batch_to_affine! vs per-point to_affine (G1) ==")
    g = g1_generator()
    rng = MersenneTwister(11)
    for N in (32, 128, 512)
        scalars = genscalars(rng, N)
        proj_pts = [scalar_mul(g, s) for s in scalars]
        println("N = ", N)
        # Warm up (compile) both variants outside the timed region.
        _ = GrothCurves.batch_to_affine!(copy(proj_pts))
        _ = begin
            tmp = copy(proj_pts)
            for i in eachindex(tmp)
                x, y = to_affine(tmp[i])
                tmp[i] = G1Point(x, y, one(BN254Field))
            end
        end
        # Both timed blocks start from a fresh `copy` of `proj_pts`, so each evaluation
        # sees the same input and the copy cost is included equally in both timings.
        # (Presumably `batch_to_affine!` rewrites its argument in place — `!` convention.)
        tr_batch = @benchmark begin
            tmp = copy($proj_pts)
            GrothCurves.batch_to_affine!(tmp)
        end seconds=1 samples=10
        tr_each = @benchmark begin
            tmp = copy($proj_pts)
            for i in eachindex(tmp)
                x, y = to_affine(tmp[i])
                tmp[i] = G1Point(x, y, one(BN254Field))
            end
        end seconds=1 samples=10
        print_stats("G1 batch_norm", tr_batch)
        print_stats("G1 per_point", tr_each)
        record_result!(results, :norm_g1, string(N), "batch", tr_batch)
        record_result!(results, :norm_g1, string(N), "each", tr_each)
    end

    # Same comparison on G2; the per-point rebuild uses `G2Point` and `one(Fp2Element)`.
    println("\n== Batch normalization: batch_to_affine! vs per-point to_affine (G2) ==")
    g = g2_generator()
    rng = MersenneTwister(13)
    for N in (32, 128, 512)
        scalars = genscalars(rng, N)
        proj_pts = [scalar_mul(g, s) for s in scalars]
        println("N = ", N)
        _ = GrothCurves.batch_to_affine!(copy(proj_pts))
        _ = begin
            tmp = copy(proj_pts)
            for i in eachindex(tmp)
                x, y = to_affine(tmp[i])
                tmp[i] = G2Point(x, y, one(Fp2Element))
            end
        end
        tr_batch = @benchmark begin
            tmp = copy($proj_pts)
            GrothCurves.batch_to_affine!(tmp)
        end seconds=1 samples=10
        tr_each = @benchmark begin
            tmp = copy($proj_pts)
            for i in eachindex(tmp)
                x, y = to_affine(tmp[i])
                tmp[i] = G2Point(x, y, one(Fp2Element))
            end
        end seconds=1 samples=10
        print_stats("G2 batch_norm", tr_batch)
        print_stats("G2 per_point", tr_each)
        record_result!(results, :norm_g2, string(N), "batch", tr_batch)
        record_result!(results, :norm_g2, string(N), "each", tr_each)
    end
end
260  
"""
    bench_pairings(results)

Benchmark the pairing engine `ENGINE` in two parts:

1. For batch sizes N = 1, 4, 16: a running product of individual `pairing` calls versus
   one `pairing_batch` call on the same random points (recorded under `:pairing`).
2. Single-pair micro-operations — `pairing`, `miller_loop` and `final_exponentiation` —
   recorded under `:pairing_single`.

Returns the summary recorded last (the "final_exponentiation" entry).
"""
function bench_pairings(results)
    println("\n== Pairing engine: sequential vs batch ==")
    g1 = g1_generator()
    g2 = g2_generator()
    _bench_pairing_products!(results, g1, g2)

    println("\n== Pairing engine micro-ops ==")
    return _bench_pairing_micro!(results, g1, g2)
end

# Sequential product of `pairing` calls vs a single `pairing_batch`, for N = 1, 4, 16.
function _bench_pairing_products!(results, g1, g2)
    for N in (1, 4, 16)
        rng = MersenneTwister(200 + N)
        # All P scalars are drawn from `rng` before any Q scalar.
        p_vec = [scalar_mul(g1, rand_scalar_nonzero(rng)) for _ in 1:N]
        q_vec = [scalar_mul(g2, rand_scalar_nonzero(rng)) for _ in 1:N]
        println("Batch size N = ", N)
        # Warm up both code paths outside the timed region.
        foreach((P, Q) -> pairing(ENGINE, P, Q), p_vec, q_vec)
        pairing_batch(ENGINE, p_vec, q_vec)
        tr_seq = @benchmark begin
            acc = one(GTElement)
            @inbounds for i in eachindex($p_vec)
                acc *= pairing($ENGINE, $p_vec[i], $q_vec[i])
            end
            acc
        end seconds=2 samples=8
        tr_batch = @benchmark pairing_batch($ENGINE, $p_vec, $q_vec) seconds=2 samples=8
        print_stats("Pairing sequential (N=$(N))", tr_seq)
        print_stats("Pairing batch (N=$(N))", tr_batch)
        nkey = string(N)
        record_result!(results, :pairing, nkey, "sequential", tr_seq)
        record_result!(results, :pairing, nkey, "batch", tr_batch)
    end
    return nothing
end

# Time `pairing`, `miller_loop` and `final_exponentiation` on the fixed pair
# (scalar_mul(g1, 5), scalar_mul(g2, 7)); returns the last recorded summary.
function _bench_pairing_micro!(results, g1, g2)
    P1 = scalar_mul(g1, BigInt(5))
    Q1 = scalar_mul(g2, BigInt(7))
    # Miller-loop output used as the input to the final-exponentiation benchmark.
    f_pre = miller_loop(ENGINE, P1, Q1)
    tr_single = @benchmark pairing($ENGINE, $P1, $Q1) seconds=2 samples=10
    tr_miller = @benchmark miller_loop($ENGINE, $P1, $Q1) seconds=2 samples=10
    tr_final = @benchmark final_exponentiation($ENGINE, $f_pre) seconds=2 samples=10
    print_stats("Pairing single", tr_single)
    print_stats("Miller loop", tr_miller)
    print_stats("Final exponentiation", tr_final)
    record_simple!(results, :pairing_single, "pairing", tr_single)
    record_simple!(results, :pairing_single, "miller_loop", tr_miller)
    return record_simple!(results, :pairing_single, "final_exponentiation", tr_final)
end
302  
"""
    bench_groth16(results)

Benchmark the Groth16 pipeline on the sum-of-products example circuit. The stages are:

- R1CS → QAP conversion;
- `setup_full` and `prove_full`;
- `verify_full`;
- `prepare_verifying_key` and `prepare_inputs`;
- `verify_with_prepared`.

Every trial is recorded under `results[:groth16]`.

NOTE(review): the benchmarks do not assert that the proof actually verifies; confirm
`verify_full` / `verify_with_prepared` accept it, or the verify timings may describe a
rejection path.
"""
function bench_groth16(results)
    println("\n== Groth16 end-to-end pipeline ==")
    r1cs = create_r1cs_example_sum_of_products()
    tr_r1cs_to_qap = @benchmark r1cs_to_qap($r1cs) seconds=2 samples=10
    print_stats("R1CS -> QAP", tr_r1cs_to_qap)
    record_simple!(results, :groth16, "r1cs_to_qap", tr_r1cs_to_qap)

    qap = r1cs_to_qap(r1cs)
    witness = create_witness_sum_of_products(3, 5, 7, 11)
    # NOTE(review): assumes the first `num_public` witness values are the public inputs
    # expected by `verify_full` / `prepare_inputs` — confirm against GrothProofs.
    public_inputs = witness.values[1:r1cs.num_public]

    # The seeded RNG is built in `setup`, so its construction (and the memory its state
    # allocates) is not charged to setup/prove. `evals=1` makes `setup` run before every
    # evaluation, so each evaluation still starts from the same seed.
    tr_setup = @benchmark setup_full($qap; rng=rng) setup=(rng = MersenneTwister(42)) evals=1 seconds=5 samples=5
    print_stats("Groth16 setup", tr_setup)
    record_simple!(results, :groth16, "setup", tr_setup)

    keypair = setup_full(qap; rng=MersenneTwister(42))
    pk, vk = keypair.pk, keypair.vk
    proof = prove_full(pk, qap, witness; rng=MersenneTwister(1337))

    tr_prove = @benchmark prove_full($pk, $qap, $witness; rng=rng) setup=(rng = MersenneTwister(1337)) evals=1 seconds=5 samples=5
    print_stats("Groth16 prove", tr_prove)
    record_simple!(results, :groth16, "prove", tr_prove)

    tr_verify = @benchmark verify_full($vk, $proof, $public_inputs) seconds=5 samples=15
    print_stats("Groth16 verify", tr_verify)
    record_simple!(results, :groth16, "verify_full", tr_verify)

    tr_prepare_vk = @benchmark prepare_verifying_key($vk) seconds=2 samples=10
    print_stats("Groth16 prepare_vk", tr_prepare_vk)
    record_simple!(results, :groth16, "prepare_vk", tr_prepare_vk)

    pvk = prepare_verifying_key(vk)
    tr_prepare_inputs = @benchmark prepare_inputs($pvk, $public_inputs) seconds=2 samples=10
    print_stats("Groth16 prepare_inputs", tr_prepare_inputs)
    record_simple!(results, :groth16, "prepare_inputs", tr_prepare_inputs)

    prepared_inputs = prepare_inputs(pvk, public_inputs)
    tr_verify_prepared = @benchmark verify_with_prepared($pvk, $proof, $prepared_inputs) seconds=5 samples=15
    print_stats("Groth16 verify prepared", tr_verify_prepared)
    record_simple!(results, :groth16, "verify_prepared", tr_verify_prepared)
end
343  
344  # -----------------------------------------------------------------------------
345  # Entry point
346  # -----------------------------------------------------------------------------
347  
"""
    main()

Run every benchmark family in order, sharing one results dictionary between them. Then
write the results as JSON to `results_<yyyy-mm-dd_HHMMSS>.json` next to this script and
print the output path.
"""
function main()
    println("GrothBenchmarks — Julia $(VERSION)")
    results = Dict{Symbol,Any}()
    families = (bench_fixed_base, bench_variable_msm, bench_batch_norm,
                bench_pairings, bench_groth16)
    for family in families
        family(results)
    end

    stamp = Dates.format(Dates.now(), "yyyy-mm-dd_HHMMSS")
    path = joinpath(@__DIR__, string("results_", stamp, ".json"))
    open(io -> JSON.print(io, results), path, "w")
    println("\nSaved results to ", path)
    return nothing
end
364  
# Entry point: runs the whole suite unconditionally, whether this file is executed as a
# script or `include`d.
main()