# GrothProofs/test/random_circuits.jl
  1  using Random
  2  using GrothAlgebra: BN254ScalarField, bn254_scalar
  3  
  4  """
  5      GeneratedCircuit
  6  
  7  Simple container for randomly generated R1CS data.
  8  
  9  - `r1cs`: the constructed `R1CS` instance
 10  - `witness`: satisfying `Witness`
 11  - `public_indices`: indices (1-based) of public inputs
 12  - `description`: brief string describing constraints for debugging
 13  """
 14  struct GeneratedCircuit{F}
 15      r1cs::R1CS{F}
 16      witness::Witness{F}
 17      public_indices::Vector{Int}
 18      used_public::Vector{Int}
 19      description::Vector{String}
 20  end
 21  
 22  """
 23      generate_small_r1cs(rng; max_constraints=4, max_public=5, value_range=-20:20, retries=20)
 24  
 25  Generate a compact R1CS/witness pair using basic gate templates.
 26  
 27  The construction keeps circuits intentionally small:
 28  - number of constraints ≤ `max_constraints`
 29  - variables are introduced gradually, each constraint either multiplies two existing
 30    values or forms an affine combination.
 31  - public inputs include the constant 1 and the first `public_count` variable slots.
 32  
 33  Returns a `GeneratedCircuit` or throws if unable to find a valid witness within `retries` attempts.
 34  """
 35  function generate_small_r1cs(rng::AbstractRNG; max_constraints::Int=4, max_public::Int=5, value_range=-20:20, retries::Int=20)
 36      max_constraints >= 2 || error("expect at least two constraints")
 37      for attempt in 1:retries
 38  
 39          n_constraints = rand(rng, 2:max_constraints)
 40          # Start with 1 (constant) plus r variable and at least one auxiliary slot
 41          num_vars = 1 + 1 + n_constraints + 2
 42          num_public = min(max_public, num_vars)
 43  
 44          F = BN254ScalarField
 45          L = zeros(F, n_constraints, num_vars)
 46          R = zeros(F, n_constraints, num_vars)
 47          O = zeros(F, n_constraints, num_vars)
 48  
 49          values = zeros(Int, num_vars)
 50          values[1] = 1
 51  
 52          idx_r = 2
 53          descriptions = String[]
 54  
 55          base_values = [rand(rng, value_range) for _ in 1:num_vars]
 56          for i in 2:num_public
 57              values[i] = base_values[i]
 58          end
 59  
 60          next_var = num_public + 1
 61          available = collect(3:num_public)  # exclude constant and r
 62          if isempty(available)
 63              continue
 64          end
 65          used_public = Int[]
 66  
 67          # Build constraints sequentially; accumulate new intermediate variables
 68          for c in 1:n_constraints
 69              if c == n_constraints || next_var > num_vars
 70                  isempty(available) && break
 71                  a_idx = rand(rng, available)
 72                  b_idx = rand(rng, available)
 73                  L[c, a_idx] = one(F)
 74                  R[c, b_idx] = one(F)
 75                  O[c, idx_r] = one(F)
 76                  values[idx_r] = values[a_idx] * values[b_idx]
 77                  push!(descriptions, "v$idx_r = v$a_idx * v$b_idx")
 78                  if a_idx <= num_public
 79                      push!(used_public, a_idx)
 80                  end
 81                  if b_idx <= num_public
 82                      push!(used_public, b_idx)
 83                  end
 84                  break
 85              end
 86  
 87              gate_type = rand(rng, (:mul, :affine))
 88              if gate_type == :mul
 89                  if isempty(available) || next_var > num_vars
 90                      continue
 91                  end
 92                  a_idx = rand(rng, available)
 93                  b_idx = rand(rng, available)
 94                  L[c, a_idx] = one(F)
 95                  R[c, b_idx] = one(F)
 96                  O[c, next_var] = one(F)
 97                  values[next_var] = values[a_idx] * values[b_idx]
 98                  push!(descriptions, "v$next_var = v$a_idx * v$b_idx")
 99                  if a_idx <= num_public
100                      push!(used_public, a_idx)
101                  end
102                  if b_idx <= num_public
103                      push!(used_public, b_idx)
104                  end
105              else
106                  if isempty(available) || next_var > num_vars
107                      continue
108                  end
109                  a_idx = rand(rng, available)
110                  b_idx = rand(rng, available)
111                  α = rand(rng, value_range)
112                  β = rand(rng, value_range)
113                  γ = rand(rng, value_range)
114                  if α == 0 && β == 0 && γ == 0
115                      γ = 1
116                  end
117                  L[c, a_idx] = bn254_scalar(α)
118                  L[c, b_idx] += bn254_scalar(β)
119                  L[c, 1] += bn254_scalar(γ)
120                  R[c, 1] = one(F)
121                  O[c, next_var] = one(F)
122                  values[next_var] = α * values[a_idx] + β * values[b_idx] + γ
123                  push!(descriptions, "v$next_var = $α*v$a_idx + $β*v$b_idx + $γ")
124                  if a_idx <= num_public
125                      push!(used_public, a_idx)
126                  end
127                  if b_idx <= num_public
128                      push!(used_public, b_idx)
129                  end
130              end
131              push!(available, next_var)
132              next_var += 1
133          end
134  
135          if values[idx_r] == 0
136              if length(available) >= 1
137                  a_idx = last(available)
138                  values[idx_r] = values[a_idx] * values[a_idx]
139              else
140                  continue
141              end
142          end
143  
144          # Convert witness values into field elements
145          witness_vals = Vector{BN254ScalarField}(undef, num_vars)
146          for i in 1:num_vars
147              witness_vals[i] = bn254_scalar(values[i])
148          end
149  
150          r1cs = R1CS{BN254ScalarField}(num_vars, n_constraints, num_public, L, R, O)
151          witness = Witness{BN254ScalarField}(witness_vals)
152  
153          if is_satisfied(r1cs, witness)
154              public_indices = collect(1:num_public)
155              unique!(used_public)
156              filter!(i -> 2 <= i <= num_public, used_public)
157              return GeneratedCircuit(r1cs, witness, public_indices, used_public, descriptions)
158          end
159      end
160  
161      error("failed to generate satisfying R1CS after $(retries) attempts")
162  end