/ test-vectors / witness_calculator.js
witness_calculator.js
  1  /* globals WebAssembly */
  2  /*
  3  
  4  Copyright 2020 0KIMS association.
  5  
  6  Licensed under the Apache License, Version 2.0 (the "License");
  7  you may not use this file except in compliance with the License.
  8  You may obtain a copy of the License at
  9  
 10      http://www.apache.org/licenses/LICENSE-2.0
 11  
 12  Unless required by applicable law or agreed to in writing, software
 13  distributed under the License is distributed on an "AS IS" BASIS,
 14  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15  See the License for the specific language governing permissions and
 16  limitations under the License.
 17  
 18  */
 19  
 20  const bigInt = require("big-integer");
 21  
 22  const fnv = require("fnv-plus");
 23  
 24  function flatArray(a) {
 25      var res = [];
 26      fillArray(res, a);
 27      return res;
 28  
 29      function fillArray(res, a) {
 30          if (Array.isArray(a)) {
 31              for (let i=0; i<a.length; i++) {
 32                  fillArray(res, a[i]);
 33              }
 34          } else {
 35              res.push(bigInt(a));
 36          }
 37      }
 38  }
 39  
 40  function fnvHash(str) {
 41      return fnv.hash(str, 64).hex();
 42  }
 43  
 44  
 45  module.exports = async function builder(code, options) {
 46  
 47      options = options || {};
 48  
 49      const memory = new WebAssembly.Memory({initial:20000});
 50      const wasmModule = await WebAssembly.compile(code);
 51  
 52      let wc;
 53  
 54      const instance = await WebAssembly.instantiate(wasmModule, {
 55          env: {
 56              "memory": memory
 57          },
 58          runtime: {
 59              error: function(code, pstr, a,b,c,d) {
 60                  let errStr;
 61                  if (code == 7) {
 62                      errStr=p2str(pstr) + " " + wc.getFr(b).toString() + " != " + wc.getFr(c).toString() + " " +p2str(d);
 63                  } else {
 64                      errStr=p2str(pstr)+ " " + a + " " + b + " " + c + " " + d;
 65                  }
 66                  console.log("ERROR: ", code, errStr);
 67                  throw new Error(errStr);
 68              },
 69              log: function(a) {
 70                  console.log(wc.getFr(a).toString());
 71              },
 72              logGetSignal: function(signal, pVal) {
 73                  if (options.logGetSignal) {
 74                      options.logGetSignal(signal, wc.getFr(pVal) );
 75                  }
 76              },
 77              logSetSignal: function(signal, pVal) {
 78                  if (options.logSetSignal) {
 79                      options.logSetSignal(signal, wc.getFr(pVal) );
 80                  }
 81              },
 82              logStartComponent: function(cIdx) {
 83                  if (options.logStartComponent) {
 84                      options.logStartComponent(cIdx);
 85                  }
 86              },
 87              logFinishComponent: function(cIdx) {
 88                  if (options.logFinishComponent) {
 89                      options.logFinishComponent(cIdx);
 90                  }
 91              }
 92          }
 93      });
 94  
 95      const sanityCheck =
 96          options &&
 97          (
 98              options.sanityCheck ||
 99              options.logGetSignal ||
100              options.logSetSignal ||
101              options.logStartComponent ||
102              options.logFinishComponent
103          );
104  
105      wc = new WitnessCalculator(memory, instance, sanityCheck);
106      return wc;
107  
108      function p2str(p) {
109          const i8 = new Uint8Array(memory.buffer);
110  
111          const bytes = [];
112  
113          for (let i=0; i8[p+i]>0; i++)  bytes.push(i8[p+i]);
114  
115          return String.fromCharCode.apply(null, bytes);
116      }
117  };
118  
/**
 * Wrapper around an instantiated witness-calculator WASM module.
 *
 * Field elements live in the module's linear memory. As read/written by
 * getFr/setFr below, each element is an 8-byte header followed by n32 32-bit
 * limbs (least-significant limb first):
 *   word 0   - "short" value; when bit 31 is set it encodes prime + (word - 2^32)
 *   word 1   - flags: 0x80000000 = value is held in the limbs ("long" form),
 *                     0x40000000 = limbs are in Montgomery form (value * R mod prime)
 *   word 2.. - n32 limbs (only meaningful in long form)
 *
 * i32[0] is used as a bump-allocator pointer by allocInt/allocFr; the
 * calculate* methods save it before a run and restore it afterwards.
 *
 * NOTE(review): several methods print intermediate values (prime, R, signal
 * hashes/offsets, limbs) with console.log. Given the file's test-vectors/ path
 * these look like intentional reference dumps — confirm before removing them.
 */
class WitnessCalculator {
    /**
     * @param {WebAssembly.Memory} memory - memory imported by `instance` as env.memory.
     * @param {WebAssembly.Instance} instance - the instantiated calculator module.
     * @param {*} sanityCheck - if truthy, init() is always called with 1.
     */
    constructor(memory, instance, sanityCheck) {
        this.memory = memory;
        // NOTE(review): bound to the current buffer; if the module ever grows its
        // memory, the old buffer is detached and this view would need recreating.
        this.i32 = new Uint32Array(memory.buffer);
        this.instance = instance;

        // getFrLen() is presumably the byte size of one element (header included,
        // matching allocFr); subtract the two header words to get the limb count.
        this.n32 = (this.instance.exports.getFrLen() >> 2) - 2;
        const pRawPrime = this.instance.exports.getPRawPrime();
        console.log("pRawPrime:", pRawPrime);

        // The prime is stored as n32 little-endian 32-bit limbs at pRawPrime.
        this.prime = bigInt(0);
        for (let i=this.n32-1; i>=0; i--) {
            this.prime = this.prime.shiftLeft(32);
            this.prime = this.prime.add(bigInt(this.i32[(pRawPrime >> 2) + i]));
        }
        console.log("prime:", this.prime);

        this.mask32 = bigInt("FFFFFFFF", 16);
        console.log("mask32:", this.mask32);
        this.NVars = this.instance.exports.getNVars();
        console.log("NVars:", this.NVars);
        // 64-bit words needed for the prime; also the per-element size
        // (n64 * 8 bytes) used when slicing the binary witness buffer.
        this.n64 = Math.floor((this.prime.bitLength() - 1) / 64)+1;
        console.log("n64:", this.n64);
        // Montgomery radix R = 2^(64*n64) and its inverse; getFr multiplies by
        // RInv to convert Montgomery-form limbs back to plain values.
        this.R = bigInt.one.shiftLeft(this.n64*64);
        console.log("R:", this.R);
        this.RInv = this.R.modInv(this.prime);
        console.log("RInv:", this.RInv);
        this.sanityCheck = sanityCheck;
    }

    /**
     * Initialises the module and writes every input signal into it.
     *
     * Each key of `input` is hashed with fnvHash; the two 32-bit halves of the
     * hash are passed to the module's getSignalOffset32 export, which writes the
     * signal's offset to a scratch int. The (possibly nested) value is flattened
     * and written one field element per consecutive offset.
     *
     * @param {Object} input - map from signal name to a value or nested array of values.
     * @param {*} [sanityCheck] - OR-ed with the constructor flag for init().
     */
    async _doCalculateWitness(input, sanityCheck) {
        this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
        const pSigOffset = this.allocInt();
        console.log("pSigOffset:", pSigOffset);
        const pFr = this.allocFr();
        console.log("pFr:", pFr);
        // Own enumerable keys only; for...in would also visit inherited properties.
        for (const k of Object.keys(input || {})) {
            const h = fnvHash(k);
            const hMSB = parseInt(h.slice(0,8), 16);
            const hLSB = parseInt(h.slice(8,16), 16);
            console.log("h(", k, ") =", h, " = ", hMSB, hLSB);
            this.instance.exports.getSignalOffset32(pSigOffset, 0, hMSB, hLSB);
            const sigOffset = this.getInt(pSigOffset);
            console.log("sigOffset:", sigOffset);
            const fArr = flatArray(input[k]);
            for (let i=0; i<fArr.length; i++) {
                this.setFr(pFr, fArr[i]);
                this.instance.exports.setSignal(0, 0, sigOffset + i, pFr);
            }
        }
    }

    /**
     * Runs the calculator and returns all NVars witness values.
     *
     * The allocator pointer i32[0] is restored even if the run throws, so a
     * failed call does not leak scratch memory.
     *
     * @returns {Promise<Array>} bigInt values read via getPWitness(0..NVars-1).
     */
    async calculateWitness(input, sanityCheck) {
        const old0 = this.i32[0];
        try {
            await this._doCalculateWitness(input, sanityCheck);

            const w = [];
            for (let i=0; i<this.NVars; i++) {
                const pWitness = this.instance.exports.getPWitness(i);
                w.push(this.getFr(pWitness));
            }
            return w;
        } finally {
            this.i32[0] = old0;
        }
    }

    /**
     * Runs the calculator and returns a copy of the module's witness buffer
     * (NVars elements of n64 * 8 bytes each).
     *
     * The allocator pointer i32[0] is restored even if the run throws.
     *
     * @returns {Promise<ArrayBuffer>}
     */
    async calculateBinWitness(input, sanityCheck) {
        const old0 = this.i32[0];
        let pWitnessBuffer;
        try {
            await this._doCalculateWitness(input, sanityCheck);
            pWitnessBuffer = this.instance.exports.getWitnessBuffer();
        } finally {
            this.i32[0] = old0;
        }

        return this.memory.buffer.slice(pWitnessBuffer, pWitnessBuffer + (this.NVars * this.n64 * 8));
    }

    /** Reserves 8 bytes at the bump pointer i32[0] and returns their address. */
    allocInt() {
        const p = this.i32[0];
        this.i32[0] = p+8;
        return p;
    }

    /** Reserves one field element (8-byte header + n32 limbs) and returns its address. */
    allocFr() {
        const p = this.i32[0];
        this.i32[0] = p+this.n32*4 + 8;
        return p;
    }

    /** Reads the unsigned 32-bit word at byte address `p`. */
    getInt(p) {
        return this.i32[p>>2];
    }

    /** Writes `v` as an unsigned 32-bit word at byte address `p`. */
    setInt(p, v) {
        this.i32[p>>2] = v;
    }

    /**
     * Decodes the field element at byte address `p` into a bigInt.
     *
     * Long form: limbs are assembled (and converted out of Montgomery form when
     * flag 0x40000000 is set). Short form: word 0 is a non-negative value, or,
     * when bit 31 is set, represents prime + (word - 2^32).
     */
    getFr(p) {
        const self = this;
        const idx = (p>>2);

        if (self.i32[idx + 1] & 0x80000000) {
            let res= bigInt(0);
            for (let i=self.n32-1; i>=0; i--) {
                res = res.shiftLeft(32);
                res = res.add(bigInt(self.i32[idx+2+i]));
            }
            if (self.i32[idx + 1] & 0x40000000) {
                return fromMontgomery(res);
            } else {
                return res;
            }

        } else {
            if (self.i32[idx] & 0x80000000) {
                return self.prime.add( bigInt(self.i32[idx]).minus(bigInt(0x100000000)) );
            } else {
                return bigInt(self.i32[idx]);
            }
        }

        function fromMontgomery(n) {
            return n.times(self.RInv).mod(self.prime);
        }
    }

    /**
     * Encodes `v` as a field element at byte address `p`.
     *
     * `v` is first reduced into [0, prime), so negative values and values >= prime
     * are stored as their canonical residue instead of silently wrapping in a
     * 32-bit word. Values below 2^31, or within 2^31 of the prime, use the compact
     * short form; everything else is written as n32 plain (non-Montgomery) limbs.
     *
     * @param {number} p - byte address of the element.
     * @param {*} v - anything accepted by bigInt().
     */
    setFr(p, v) {
        const self = this;
        const half = bigInt("80000000", 16);

        // big-integer's mod keeps the dividend's sign, so lift negatives into range.
        let a = bigInt(v).mod(self.prime);
        if (a.lt(bigInt(0))) {
            a = a.add(self.prime);
        }

        if (a.lt(half)) {
            return setShortPositive(a);
        }
        if (a.geq(self.prime.minus(half))) {
            return setShortNegative(a);
        }
        return setLongNormal(a);

        function setShortPositive(x) {
            self.i32[(p >> 2)] = parseInt(x.toString(), 10);
            self.i32[(p >> 2) + 1] = 0;
        }

        // Stores x as the word x - prime + 2^32 (bit 31 set); getFr maps it back
        // via prime + (word - 2^32).
        function setShortNegative(x) {
            const b = half.add(x.minus(self.prime.minus(half)));
            self.i32[(p >> 2)] = parseInt(b.toString(), 10);
            self.i32[(p >> 2) + 1] = 0;
        }

        function setLongNormal(x) {
            self.i32[(p >> 2)] = 0;
            self.i32[(p >> 2) + 1] = 0x80000000;
            for (let i=0; i<self.n32; i++) {
                self.i32[(p >> 2) + 2 + i] = x.shiftRight(i*32).and(self.mask32);
            }
            console.log(">>>", self.i32[(p >> 2)] , self.i32[(p >> 2) + 1]);
            console.log(">>>", self.i32.slice((p >> 2) + 2, (p >> 2) + 2 + self.n32));
        }
    }
}