/ src / zk / gadget / less_than.rs
less_than.rs
  1  /* This file is part of DarkFi (https://dark.fi)
  2   *
  3   * Copyright (C) 2020-2025 Dyne.org foundation
  4   *
  5   * This program is free software: you can redistribute it and/or modify
  6   * it under the terms of the GNU Affero General Public License as
  7   * published by the Free Software Foundation, either version 3 of the
  8   * License, or (at your option) any later version.
  9   *
 10   * This program is distributed in the hope that it will be useful,
 11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
 12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13   * GNU Affero General Public License for more details.
 14   *
 15   * You should have received a copy of the GNU Affero General Public License
 16   * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17   */
 18  
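//! Less-Than Gadget
//!
//! Given two values:
//!     - `a`, a NUM_OF_BITS-length value, and
//!     - `b`, a field element,
//! this gadget constrains them in the following way:
//!     - in `strict` mode, `a` is constrained to be strictly less than `b`;
//!     - else, `a` is constrained to be less than or equal to `b`.
//!
//! It works by witnessing `a_offset = a + 2^NUM_OF_BITS - b` (minus one more in
//! non-strict mode) and range-checking both `a` and `a_offset` to NUM_OF_BITS bits.
//! The constraint is sound for any `b` as long as NUM_OF_BITS <= 253, so that
//! `a_offset` can never wrap around the Pallas modulus.
//!
//! It can only be *satisfied*, however, when `b` is close enough to `a` (as integers):
//!     - `b - a <= 2^NUM_OF_BITS` in `strict` mode;
//!     - `b - a < 2^NUM_OF_BITS` otherwise.
//! In particular, whenever the relation holds and `b < 2^NUM_OF_BITS`, a valid witness
//! exists.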
 27  
 28  use halo2_proofs::{
 29      arithmetic::Field,
 30      circuit::{AssignedCell, Chip, Layouter, Region, Value},
 31      pasta::pallas,
 32      plonk::{Advice, Column, ConstraintSystem, Error, Expression, Selector, TableColumn},
 33      poly::Rotation,
 34  };
 35  
 36  use super::native_range_check::{NativeRangeCheckChip, NativeRangeCheckConfig};
 37  
 38  #[derive(Clone, Debug)]
 39  pub struct LessThanConfig<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> {
 40      pub s_lt: Selector,
 41      pub s_leq: Selector,
 42      pub a: Column<Advice>,
 43      pub b: Column<Advice>,
 44      pub a_offset: Column<Advice>,
 45      pub range_a_config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_OF_BITS>,
 46      pub range_a_offset_config: NativeRangeCheckConfig<WINDOW_SIZE, NUM_OF_BITS>,
 47      pub k_values_table: TableColumn,
 48  }
 49  
 50  #[derive(Clone, Debug)]
 51  pub struct LessThanChip<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> {
 52      config: LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>,
 53  }
 54  
 55  impl<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> Chip<pallas::Base>
 56      for LessThanChip<WINDOW_SIZE, NUM_OF_BITS>
 57  {
 58      type Config = LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>;
 59      type Loaded = ();
 60  
 61      fn config(&self) -> &Self::Config {
 62          &self.config
 63      }
 64  
 65      fn loaded(&self) -> &Self::Loaded {
 66          &()
 67      }
 68  }
 69  
 70  impl<const WINDOW_SIZE: usize, const NUM_OF_BITS: usize> LessThanChip<WINDOW_SIZE, NUM_OF_BITS> {
 71      pub fn construct(config: LessThanConfig<WINDOW_SIZE, NUM_OF_BITS>) -> Self {
 72          Self { config }
 73      }
 74  
 75      pub fn configure(
 76          meta: &mut ConstraintSystem<pallas::Base>,
 77          a: Column<Advice>,
 78          b: Column<Advice>,
 79          a_offset: Column<Advice>,
 80          z1: Column<Advice>,
 81          z2: Column<Advice>,
 82          k_values_table: TableColumn,
 83      ) -> LessThanConfig<WINDOW_SIZE, NUM_OF_BITS> {
 84          let s_lt = meta.selector();
 85          let s_leq = meta.selector();
 86  
 87          meta.enable_equality(a);
 88          meta.enable_equality(b);
 89          meta.enable_equality(a_offset);
 90          meta.enable_equality(z1);
 91          meta.enable_equality(z2);
 92  
 93          // configure range check for `a` and `offset`
 94          let range_a_config =
 95              NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::configure(meta, z1, k_values_table);
 96  
 97          let range_a_offset_config =
 98              NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::configure(meta, z2, k_values_table);
 99  
100          let config = LessThanConfig {
101              s_lt,
102              s_leq,
103              a,
104              b,
105              a_offset,
106              range_a_config,
107              range_a_offset_config,
108              k_values_table,
109          };
110  
111          meta.create_gate("a_offset", |meta| {
112              let s_lt = meta.query_selector(config.s_lt);
113              let s_leq = meta.query_selector(config.s_leq);
114              let a = meta.query_advice(config.a, Rotation::cur());
115              let b = meta.query_advice(config.b, Rotation::cur());
116              let a_offset = meta.query_advice(config.a_offset, Rotation::cur());
117              let two_pow_m =
118                  Expression::Constant(pallas::Base::from(2).pow([NUM_OF_BITS as u64, 0, 0, 0]));
119  
120              // If strict, a_offset = a + 2^m - b
121              let strict_check =
122                  s_lt * (a_offset.clone() - two_pow_m.clone() + b.clone() - a.clone());
123              // If leq, a_offset = a + 2^m - b - 1
124              let leq_check =
125                  s_leq * (a_offset - two_pow_m + b - a + Expression::Constant(pallas::Base::one()));
126  
127              vec![strict_check, leq_check]
128          });
129  
130          config
131      }
132  
133      pub fn witness_less_than(
134          &self,
135          mut layouter: impl Layouter<pallas::Base>,
136          a: Value<pallas::Base>,
137          b: Value<pallas::Base>,
138          offset: usize,
139          strict: bool,
140      ) -> Result<(), Error> {
141          let (a, _, a_offset) = layouter.assign_region(
142              || "a less than b",
143              |mut region: Region<'_, pallas::Base>| {
144                  let a = region.assign_advice(|| "a", self.config.a, offset, || a)?;
145                  let b = region.assign_advice(|| "b", self.config.b, offset, || b)?;
146                  let a_offset = self.less_than(region, a.clone(), b.clone(), offset, strict)?;
147                  Ok((a, b, a_offset))
148              },
149          )?;
150  
151          self.less_than_range_check(layouter, a, a_offset)?;
152  
153          Ok(())
154      }
155  
156      /*
157      pub fn witness_less_than2(
158          &self,
159          mut layouter: impl Layouter<pallas::Base>,
160          a: Value<pallas::Base>,
161          b: Value<pallas::Base>,
162          offset: usize,
163          strict: bool,
164      ) -> Result<AssignedCell<pallas::Base, pallas::Base>, Error> {
165          let (a, _, a_offset) = layouter.assign_region(
166              || "a less than b",
167              |mut region: Region<'_, pallas::Base>| {
168                  let a = region.assign_advice(|| "a", self.config.a, offset, || a)?;
169                  let b = region.assign_advice(|| "b", self.config.b, offset, || b)?;
170                  let a_offset = self.less_than(region, a.clone(), b.clone(), offset)?;
171                  Ok((a, b, a_offset))
172              },
173          )?;
174  
175          self.less_than_range_check(layouter, a, a_offset.clone(), strict)?;
176  
177          Ok(a_offset)
178      }
179      */
180  
181      pub fn copy_less_than(
182          &self,
183          mut layouter: impl Layouter<pallas::Base>,
184          a: AssignedCell<pallas::Base, pallas::Base>,
185          b: AssignedCell<pallas::Base, pallas::Base>,
186          offset: usize,
187          strict: bool,
188      ) -> Result<(), Error> {
189          let (a, _, a_offset) = layouter.assign_region(
190              || "a less than b",
191              |mut region: Region<'_, pallas::Base>| {
192                  let a = a.copy_advice(|| "a", &mut region, self.config.a, offset)?;
193                  let b = b.copy_advice(|| "b", &mut region, self.config.b, offset)?;
194                  let a_offset = self.less_than(region, a.clone(), b.clone(), offset, strict)?;
195                  Ok((a, b, a_offset))
196              },
197          )?;
198  
199          self.less_than_range_check(layouter, a, a_offset)?;
200  
201          Ok(())
202      }
203  
204      pub fn less_than_range_check(
205          &self,
206          mut layouter: impl Layouter<pallas::Base>,
207          a: AssignedCell<pallas::Base, pallas::Base>,
208          a_offset: AssignedCell<pallas::Base, pallas::Base>,
209      ) -> Result<(), Error> {
210          let range_a_chip = NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::construct(
211              self.config.range_a_config.clone(),
212          );
213          let range_a_offset_chip = NativeRangeCheckChip::<WINDOW_SIZE, NUM_OF_BITS>::construct(
214              self.config.range_a_offset_config.clone(),
215          );
216  
217          range_a_chip.copy_range_check(layouter.namespace(|| "a copy_range_check"), a)?;
218  
219          range_a_offset_chip
220              .copy_range_check(layouter.namespace(|| "a_offset copy_range_check"), a_offset)?;
221  
222          Ok(())
223      }
224  
225      pub fn less_than(
226          &self,
227          mut region: Region<'_, pallas::Base>,
228          a: AssignedCell<pallas::Base, pallas::Base>,
229          b: AssignedCell<pallas::Base, pallas::Base>,
230          offset: usize,
231          strict: bool,
232      ) -> Result<AssignedCell<pallas::Base, pallas::Base>, Error> {
233          if strict {
234              // enable `less_than` selector
235              self.config.s_lt.enable(&mut region, offset)?;
236          } else {
237              self.config.s_leq.enable(&mut region, offset)?;
238          }
239  
240          let two_pow_m = pallas::Base::from(2).pow([NUM_OF_BITS as u64, 0, 0, 0]);
241          let a_offset = if strict {
242              a.value().zip(b.value()).map(|(a, b)| *a + (two_pow_m - b))
243          } else {
244              a.value().zip(b.value()).map(|(a, b)| *a + (two_pow_m - b) - pallas::Base::one())
245          };
246          let a_offset =
247              region.assign_advice(|| "a_offset", self.config.a_offset, offset, || a_offset)?;
248  
249          Ok(a_offset)
250      }
251  }
252  
#[cfg(test)]
mod tests {
    use super::*;
    use darkfi_sdk::crypto::pasta_prelude::PrimeField;
    use halo2_proofs::{
        circuit::floor_planner,
        dev::{CircuitLayout, MockProver},
        plonk::Circuit,
    };

    /// `p - 1`, the largest element of the Pallas base field.
    const P_MINUS_1: pallas::Base = pallas::Base::from_raw([
        0x992d30ed00000000,
        0x224698fc094cf91b,
        0x0000000000000000,
        0x4000000000000000,
    ]);

    /// 2^253 - 1. This is the maximum we can check with `NUM_OF_BITS = 253`.
    const MAX_253: pallas::Base = pallas::Base::from_raw([
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0x1FFFFFFFFFFFFFFF,
    ]);

    /// Random `u64` that is at least 2, so `(1, r)` is a valid pair for both `<` and `<=`.
    ///
    /// A bare `rand::random::<u64>()` would make those pairs (rarely) invalid when it
    /// returns 0 or 1.
    fn random_u64_above_one() -> u64 {
        rand::random::<u64>().max(2)
    }

    /// Defines a one-comparison circuit and checks it against valid and invalid pairs.
    ///
    /// It also renders the circuit layout to a PNG under `target/`. The file name includes
    /// both the mode and the bit width, so tests running in parallel do not overwrite each
    /// other's output.
    macro_rules! test_circuit {
        ($k:expr, $strict:expr, $window_size:expr, $num_bits:expr, $valid_pairs:expr, $invalid_pairs:expr) => {
            #[derive(Default)]
            struct LessThanCircuit {
                a: Value<pallas::Base>,
                b: Value<pallas::Base>,
            }

            impl Circuit<pallas::Base> for LessThanCircuit {
                type Config = (LessThanConfig<$window_size, $num_bits>, Column<Advice>);
                type FloorPlanner = floor_planner::V1;
                type Params = ();

                fn without_witnesses(&self) -> Self {
                    Self { a: Value::unknown(), b: Value::unknown() }
                }

                fn configure(meta: &mut ConstraintSystem<pallas::Base>) -> Self::Config {
                    let w = meta.advice_column();
                    meta.enable_equality(w);

                    let a = meta.advice_column();
                    let b = meta.advice_column();
                    let a_offset = meta.advice_column();
                    let z1 = meta.advice_column();
                    let z2 = meta.advice_column();

                    let k_values_table = meta.lookup_table_column();

                    let constants = meta.fixed_column();
                    meta.enable_constant(constants);

                    (
                        LessThanChip::<$window_size, $num_bits>::configure(
                            meta,
                            a,
                            b,
                            a_offset,
                            z1,
                            z2,
                            k_values_table,
                        ),
                        w,
                    )
                }

                fn synthesize(
                    &self,
                    config: Self::Config,
                    mut layouter: impl Layouter<pallas::Base>,
                ) -> Result<(), Error> {
                    let less_than_chip =
                        LessThanChip::<$window_size, $num_bits>::construct(config.0.clone());

                    NativeRangeCheckChip::<$window_size, $num_bits>::load_k_table(
                        &mut layouter,
                        config.0.k_values_table,
                    )?;

                    less_than_chip.witness_less_than(
                        layouter.namespace(|| "a < b"),
                        self.a,
                        self.b,
                        0,
                        $strict,
                    )?;

                    Ok(())
                }
            }

            use plotters::prelude::*;
            let circuit = LessThanCircuit {
                a: Value::known(pallas::Base::zero()),
                b: Value::known(pallas::Base::one()),
            };
            let mode = if $strict { "lt" } else { "leq" };
            let file_name =
                format!("target/lessthan_check_{}_{:?}_circuit_layout.png", mode, $num_bits);
            let root = BitMapBackend::new(file_name.as_str(), (3840, 2160)).into_drawing_area();
            CircuitLayout::default().render($k, &circuit, &root).unwrap();

            let check = if $strict { "<" } else { "<=" };
            for (a, b) in $valid_pairs {
                println!("{:?} bit (valid) {:?} {} {:?} check", $num_bits, a, check, b);
                let circuit = LessThanCircuit { a: Value::known(a), b: Value::known(b) };
                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
                prover.assert_satisfied();
            }

            for (a, b) in $invalid_pairs {
                println!("{:?} bit (invalid) {:?} {} {:?} check", $num_bits, a, check, b);
                let circuit = LessThanCircuit { a: Value::known(a), b: Value::known(b) };
                let prover = MockProver::run($k, &circuit, vec![]).unwrap();
                assert!(
                    prover.verify().is_err(),
                    "expected {:?} {} {:?} to be rejected",
                    a,
                    check,
                    b
                );
            }
        };
    }

    #[test]
    fn leq_64() {
        let k = 5;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 64;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ONE),
            (pallas::Base::from(13), pallas::Base::from(15)),
            (pallas::Base::ZERO, pallas::Base::from(u64::MAX)),
            (pallas::Base::ONE, pallas::Base::from(random_u64_above_one())),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX)),
        ];

        let invalid_pairs = [
            (pallas::Base::from(14), pallas::Base::from(11)),
            (pallas::Base::from(u64::MAX), pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ZERO),
        ];
        test_circuit!(k, false, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn less_than_64() {
        let k = 5;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 64;

        let valid_pairs = [
            (pallas::Base::from(13), pallas::Base::from(15)),
            (pallas::Base::ZERO, pallas::Base::from(u64::MAX)),
            (pallas::Base::ONE, pallas::Base::from(random_u64_above_one())),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::from(14), pallas::Base::from(11)),
            (pallas::Base::from(u64::MAX), pallas::Base::ZERO),
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ONE),
            (pallas::Base::ONE, pallas::Base::ZERO),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX)),
        ];
        test_circuit!(k, true, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn leq_253() {
        let k = 7;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 253;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ZERO, pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (
                pallas::Base::from_u128(u128::MAX),
                pallas::Base::from_u128(u128::MAX) + pallas::Base::ONE,
            ),
            (MAX_253, MAX_253),
            (MAX_253 - pallas::Base::from(2), MAX_253 - pallas::Base::ONE),
            (MAX_253 - pallas::Base::ONE, MAX_253),
            (MAX_253, MAX_253 + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::ONE, pallas::Base::ZERO),
            (P_MINUS_1 - pallas::Base::ONE, P_MINUS_1),
            (P_MINUS_1, pallas::Base::ZERO),
            (P_MINUS_1, P_MINUS_1),
            (MAX_253, pallas::Base::ZERO),
            (MAX_253, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ZERO),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::from(2)),
        ];

        test_circuit!(k, false, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }

    #[test]
    fn less_than_253() {
        let k = 7;
        const WINDOW_SIZE: usize = 3;
        const NUM_OF_BITS: usize = 253;

        let valid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ONE),
            (pallas::Base::from(u64::MAX), pallas::Base::from(u64::MAX) + pallas::Base::ONE),
            (
                pallas::Base::from_u128(u128::MAX),
                pallas::Base::from_u128(u128::MAX) + pallas::Base::ONE,
            ),
            (MAX_253 - pallas::Base::from(2), MAX_253 - pallas::Base::ONE),
            (MAX_253 - pallas::Base::ONE, MAX_253),
            (MAX_253, MAX_253 + pallas::Base::ONE),
        ];

        let invalid_pairs = [
            (pallas::Base::ZERO, pallas::Base::ZERO),
            (pallas::Base::ONE, pallas::Base::ZERO),
            (P_MINUS_1 - pallas::Base::ONE, P_MINUS_1),
            (P_MINUS_1, P_MINUS_1),
            (P_MINUS_1, pallas::Base::ZERO),
            (MAX_253, MAX_253),
            (MAX_253, pallas::Base::ZERO),
            (MAX_253, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ZERO),
            (MAX_253 + pallas::Base::ONE, pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::ONE),
            (MAX_253 + pallas::Base::ONE, MAX_253 + pallas::Base::from(2)),
        ];

        test_circuit!(k, true, WINDOW_SIZE, NUM_OF_BITS, valid_pairs, invalid_pairs);
    }
}