/ utils / merkleTree.js
merkleTree.js
  1  const { keccak256, bufferToHex } = require('ethereumjs-util');
  2  
  3  class MerkleTree {
  4    constructor(elements) {
  5      // Filter empty strings and hash elements
  6      this.elements = elements.filter(el => el).map(el => keccak256(el));
  7      
  8      // Deduplicate elements
  9      this.elements = this.bufDedup(this.elements);
 10      // Sort elements
 11      this.elements.sort(Buffer.compare);
 12      // Create layers
 13      this.layers = this.getLayers(this.elements);
 14    }
 15  
 16    getProofFlags(els, proofs) {
 17      let ids = els.map((el) => this.bufIndexOf(el, this.elements)).sort((a,b) => a == b ? 0 : a > b ? 1 : -1);
 18      if (!ids.every((idx) => idx != -1)) {
 19        throw new Error("Element does not exist in Merkle tree");
 20      }
 21  
 22      const tested = [];
 23      const flags = [];
 24      for (let index = 0; index < this.layers.length; index++) {
 25        const layer = this.layers[index];
 26        ids = ids.reduce((ids, idx) => {
 27          const skipped = tested.includes(layer[idx]);
 28          if(!skipped) {
 29            const pairElement = this.getPairElement(idx, layer);
 30            const proofUsed = proofs.includes(layer[idx]) || proofs.includes(pairElement);
 31            pairElement && flags.push(!proofUsed);
 32            tested.push(layer[idx]);
 33            tested.push(pairElement);
 34          } 
 35          ids.push(Math.floor(idx / 2));  
 36          return ids;
 37        }, [])
 38      }
 39      return flags;
 40    }
 41  
 42    getElements(els) {
 43      let ids = els.map((el) => this.bufIndexOf(el, this.elements));
 44      if (!ids.every((idx) => idx != -1)) {
 45        throw new Error("Element does not exist in Merkle tree");
 46      }
 47      
 48      const elsH = [];
 49      for (let j = 0; j < ids.length; j++) {
 50        elsH.push(this.layers[0][ids[j]]);
 51      }
 52      return this.bufDedup(elsH).sort(Buffer.compare);
 53    }
 54  
 55    getMultiProof(els) {
 56      let ids = els.map((el) => this.bufIndexOf(el, this.elements)).sort((a,b) => a == b ? 0 : a > b ? 1 : -1);
 57      if (!ids.every((idx) => idx != -1)) {
 58        throw new Error("Element does not exist in Merkle tree");
 59      }
 60      
 61      const hashes = [];
 62      const proof = [];
 63      var nextIds = [];
 64  
 65      for (let index = 0; index < this.layers.length; index++) {
 66        const layer = this.layers[index];
 67        for (let j = 0; j < ids.length; j++) {
 68          const idx = ids[j];
 69          const pairElement = this.getPairElement(idx, layer);
 70          
 71          hashes.push(layer[idx]);
 72          pairElement && proof.push(pairElement)
 73    
 74          nextIds.push(Math.floor(idx / 2));  
 75        }
 76        ids = nextIds.filter((value, index, self) => self.indexOf(value) === index);
 77        nextIds = [];
 78      }
 79      return proof.filter((value) => !hashes.includes(value));
 80    }
 81  
 82    getHexMultiProof(els) {
 83      const multiProof = this.getMultiProof(els);
 84  
 85      return this.bufArrToHex(multiProof);
 86    }
 87  
 88    getLayers(elements) {
 89      if (elements.length == 0) {
 90        return [[""]];
 91      }
 92  
 93      const layers = [];
 94      layers.push(elements);
 95  
 96      // Get next layer until we reach the root
 97      while (layers[layers.length - 1].length > 1) {
 98        layers.push(this.getNextLayer(layers[layers.length - 1]));
 99      }
100  
101      return layers;
102    }
103  
104    getNextLayer(elements) {
105      return elements.reduce((layer, el, idx, arr) => {
106        if (idx % 2 === 0) {
107          // Hash the current element with its pair element
108          layer.push(this.combinedHash(el, arr[idx + 1]));
109        }
110  
111        return layer;
112      }, []);
113    }
114  
115    combinedHash(first, second) {
116      if (!first) { return second; }
117      if (!second) { return first; }
118  
119      return keccak256(this.sortAndConcat(first, second));
120    }
121  
122    getRoot() {
123      return this.layers[this.layers.length - 1][0];
124    }
125  
126    getHexRoot() {
127      return bufferToHex(this.getRoot());
128    }
129  
130    getProof(el) {
131      let idx = this.bufIndexOf(el, this.elements);
132  
133      if (idx === -1) {
134        throw new Error("Element does not exist in Merkle tree");
135      }
136  
137      return this.layers.reduce((proof, layer) => {
138        const pairElement = this.getPairElement(idx, layer);
139  
140        if (pairElement) {
141          proof.push(pairElement);
142        }
143  
144        idx = Math.floor(idx / 2);
145  
146        return proof;
147      }, []);
148    }
149  
150  
151    getHexProof(el) {
152      const proof = this.getProof(el);
153  
154      return this.bufArrToHex(proof);
155    }
156  
157    getPairElement(idx, layer) {
158      const pairIdx = idx % 2 === 0 ? idx + 1 : idx - 1;
159  
160      if (pairIdx < layer.length) {
161        return layer[pairIdx];
162      } else {
163        return null;
164      }
165    }
166  
167    bufIndexOf(el, arr) {
168      let hash;
169  
170      // Convert element to 32 byte hash if it is not one already
171      if (el.length !== 32 || !Buffer.isBuffer(el)) {
172        hash = keccak256(el);
173      } else {
174        hash = el;
175      }
176  
177      for (let i = 0; i < arr.length; i++) {
178        if (hash.equals(arr[i])) {
179          return i;
180        }
181      }
182  
183      return -1;
184    }
185  
186    bufDedup(elements) {
187      return elements.filter((el, idx) => {
188        return this.bufIndexOf(el, elements) === idx;
189      });
190    }
191  
192    bufArrToHex(arr) {
193      if (arr.some(el => !Buffer.isBuffer(el))) {
194        throw new Error("Array is not an array of buffers");
195      }
196      
197      return arr.map(el => '0x' + el.toString('hex'));
198    }
199  
200    sortAndConcat(...args) {
201      return Buffer.concat([...args].sort(Buffer.compare));
202    }
203  }
204  
// Public CommonJS export of the tree class.
Object.assign(exports, { MerkleTree });