/ src / toprf.c
toprf.c
  1  #include <string.h>
  2  #include "oprf.h"
  3  #include "toprf.h"
  4  #include <arpa/inet.h>
  5  #ifdef UNIT_TEST
  6  #include "utils.h"
  7  #endif
  8  
  9  /*
 10      @copyright 2023, Stefan Marsiske toprf@ctrlc.hu
 11      This file is part of liboprf.
 12  
 13      liboprf is free software: you can redistribute it and/or
 14      modify it under the terms of the GNU Lesser General Public License
 15      as published by the Free Software Foundation, either version 3 of
 16      the License, or (at your option) any later version.
 17  
 18      liboprf is distributed in the hope that it will be useful,
 19      but WITHOUT ANY WARRANTY; without even the implied warranty of
 20      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 21      GNU Lesser General Public License for more details.
 22  
 23      You should have received a copy of the License
 24      along with liboprf. If not, see <http://www.gnu.org/licenses/>.
 25  */
 26  
 27  // implements TOPRF from https://eprint.iacr.org/2017/363
 28  // quote from page 9 (first line is last on page 8)
 29  
// The underlying PRF, f_k(x) = H2(x, (H1(x))^k), remains unchanged, but the
// key k is shared using Shamir secret-sharing across n servers, where server S_i
// stores the key share k_i. The initialization of such secret-sharing can be done via
// a Distributed Key Generation (DKG) for discrete-log-based systems, e.g. [16],
// and in Figure 2 we assume it is done with a UC functionality F_DKG which we
// discuss further below. For evaluation, given any subset SE of t + 1 servers, the
// user U sends to each of them the same message a = (H′(x))^r for random r,
// exactly as in the single-server OPRF protocol 2HashDH. If each server S_i in SE
// returned b_i = a^{k_i} then U could reconstruct the value a^k using standard Lagrange
// interpolation in the exponent, i.e. a^k = ∏_{i∈SE} b_i^{λ_i} with the Lagrange coefficients
// λ_i computed using the indexes of servers in SE. After computing a^k, the value
// of f_k(x) is computed by U by deblinding a^k exactly as in the case of protocol
// 2HashDH. Note that this takes a single exponentiation for each server and two
// exponentiations for the user (to compute a and to deblind a^k) plus one multi-
// exponentiation by U to compute the Lagrange interpolation on the b_i values.
 45  
 46  // run with
 47  // gcc -o toprf -g -Wall toprf.c -lsodium liboprf.a
 48  
 49  typedef struct {
 50    uint8_t index;
 51    uint8_t value[crypto_core_ristretto255_BYTES];
 52  } __attribute((packed)) TOPRF_Part;
 53  
 54  void __attribute__((visibility("hidden"))) lcoeff(const uint8_t index, const uint8_t x, const size_t degree, const uint8_t peers[degree], uint8_t result[crypto_scalarmult_ristretto255_SCALARBYTES]) {
 55    uint8_t xscalar[crypto_scalarmult_ristretto255_SCALARBYTES]={0};
 56    xscalar[0]=x;
 57  
 58    uint8_t iscalar[crypto_scalarmult_ristretto255_SCALARBYTES]={0};
 59    iscalar[0]=index;
 60  
 61    uint8_t divident[crypto_scalarmult_ristretto255_SCALARBYTES]={0};
 62    divident[0]=1;
 63  
 64    uint8_t divisor[crypto_scalarmult_ristretto255_SCALARBYTES]={0};
 65    divisor[0]=1;
 66  
 67    for(size_t j=0;j<degree;j++) {
 68      if(peers[j]==index) continue;
 69      uint8_t tmp[crypto_scalarmult_ristretto255_SCALARBYTES]={0};
 70      tmp[0]=peers[j];
 71      //divident*=x-peers[j];
 72      crypto_core_ristretto255_scalar_sub(tmp, xscalar, tmp);
 73      crypto_core_ristretto255_scalar_mul(divident, divident, tmp);
 74      //divisor*=peers[j]-i;
 75      memset(tmp, 0, sizeof tmp);
 76      tmp[0]=peers[j];
 77      crypto_core_ristretto255_scalar_sub(tmp, iscalar, tmp);
 78      crypto_core_ristretto255_scalar_mul(divisor, divisor, tmp);
 79    }
 80    crypto_core_ristretto255_scalar_invert(divisor, divisor);
 81    crypto_core_ristretto255_scalar_mul(result, divisor, divident);
 82  }
 83  
 84  // interpolates a polynomial of degree t at point x: y = f(x), given t shares of the polynomial
 85  void __attribute__((visibility("hidden"))) interpolate(const uint8_t x, const uint8_t t, const TOPRF_Share shares[t], uint8_t y[crypto_scalarmult_ristretto255_SCALARBYTES]) {
 86    memset(y,0,crypto_scalarmult_ristretto255_SCALARBYTES);
 87    uint8_t l[crypto_scalarmult_ristretto255_SCALARBYTES];
 88  
 89    uint8_t indexes[t];
 90    for(size_t i=0;i<t;i++) {
 91      indexes[i]=shares[i].index;
 92    }
 93    //dump(indexes, sizeof indexes, "indexes");
 94  
 95    for(unsigned i=0;i<t;i++) {
 96      lcoeff(indexes[i], x, t, indexes, l);
 97      //dump(l, sizeof l, "l %d,%d", i+1, x);
 98      uint8_t tmp[crypto_scalarmult_ristretto255_BYTES];
 99      //dump(shares[i].value, 32, "share %d", shares[i].index);
100      crypto_core_ristretto255_scalar_mul(tmp, l, shares[i].value);
101      crypto_core_ristretto255_scalar_add(y, y, tmp);
102      //dump(y, 32, "result ");
103    }
104  }
105  
106  void __attribute__((visibility("hidden"))) coeff(const uint8_t index, const size_t peers_len, const uint8_t peers[peers_len], uint8_t result[crypto_scalarmult_ristretto255_SCALARBYTES]) {
107    lcoeff(index,0,peers_len,peers,result);
108  }
109  
110  void toprf_create_shares(const uint8_t secret[crypto_core_ristretto255_SCALARBYTES],
111                     const uint8_t n,
112                     const uint8_t threshold,
113                     uint8_t _shares[n][TOPRF_Share_BYTES]) {
114    TOPRF_Share *shares= (TOPRF_Share*)_shares;
115  
116    uint8_t a[threshold-1][crypto_core_ristretto255_SCALARBYTES];
117    uint8_t i;
118    for(i=0;i<threshold-1;i++) {
119  #ifdef UNIT_TEST
120      debian_rng_scalar(a[i]);
121      dump(a[i],crypto_core_ristretto255_SCALARBYTES,"\t");
122  #else
123      crypto_core_ristretto255_scalar_random(a[i]);
124  #endif
125    }
126    for(i=1;i<=n;i++) {
127      //f(x) = a_0 + a_1*x + a_2*x^2 + a_3*x^3 + ⋯ + a_(k−1)*x^(k−1)
128      shares[i-1].index=i;
129      uint8_t x[crypto_core_ristretto255_SCALARBYTES]={0};
130      x[0]=i;
131      memcpy(shares[i-1].value, secret, crypto_core_ristretto255_SCALARBYTES);
132      for(int j=0;j<threshold-1;j++) {
133        // a_j^j
134        uint8_t tmp[crypto_core_ristretto255_SCALARBYTES];
135        crypto_core_ristretto255_scalar_mul(tmp, a[j], x);
136        for(int exp=0;exp<j;exp++) {
137          crypto_core_ristretto255_scalar_mul(tmp, tmp, x);
138        }
139        crypto_core_ristretto255_scalar_add(shares[i-1].value, shares[i-1].value, tmp);
140      }
141  #ifdef UNIT_TEST
142      dump(shares[i-1].value,32,"f(%d)", i-1);
143  #endif
144    }
145  }
146  
147  static void sort_parts(const int n, const TOPRF_Part parts[n], uint8_t indexes[n]) {
148    uint8_t arr[n];
149    for(uint8_t i=0;i<n;i++) {
150      arr[i]=parts[i].index;
151      indexes[i]=i;
152    }
153  
154    for (uint8_t c = 1 ; c <= n - 1; c++) {
155      uint8_t d = c, t, t1;
156      while(d > 0 && arr[d] < arr[d-1]) {
157        t = arr[d];
158        t1 = indexes[d];
159        arr[d] = arr[d-1];
160        indexes[d] = indexes[d-1];
161        arr[d-1] = t;
162        indexes[d-1] = t1;
163        d--;
164      }
165    }
166  }
167  
168  int toprf_thresholdmult(const size_t response_len,
169                          const uint8_t _responses[response_len][TOPRF_Part_BYTES],
170                          uint8_t result[crypto_scalarmult_ristretto255_BYTES]) {
171    const TOPRF_Part *responses=(TOPRF_Part*) _responses;
172    uint8_t lpoly[crypto_scalarmult_ristretto255_SCALARBYTES];
173    uint8_t gki[crypto_scalarmult_ristretto255_BYTES];
174    memset(result,0,crypto_scalarmult_ristretto255_BYTES);
175  
176    // sort the responses by their indexes
177    uint8_t indexed_indexes[response_len];
178    if(response_len>255) return 1;
179    sort_parts((uint8_t) response_len, (TOPRF_Part*) responses, indexed_indexes);
180  
181    uint8_t indexes[response_len];
182    for(size_t i=0;i<response_len;i++) {
183      indexes[indexed_indexes[i]]=responses[indexed_indexes[i]].index;
184    }
185    for(size_t i=0;i<response_len;i++) {
186      coeff(responses[indexed_indexes[i]].index, response_len, indexes, lpoly);
187  
188      // betaki = g^{k_i}^{lpoly}
189      if(crypto_scalarmult_ristretto255(gki, lpoly, responses[indexed_indexes[i]].value)) {
190        return 1;
191      }
192      crypto_core_ristretto255_add(result,result,gki);
193    }
194    return 0;
195  }
196  
197  int toprf_Evaluate(const uint8_t _k[TOPRF_Share_BYTES],
198                     const uint8_t blinded[crypto_core_ristretto255_BYTES],
199                     const uint8_t self, const uint8_t *indexes, const uint16_t index_len,
200                     uint8_t _Z[TOPRF_Part_BYTES]) {
201    uint8_t lpoly[crypto_scalarmult_ristretto255_SCALARBYTES];
202    coeff(self, index_len, indexes, lpoly);
203    // kl = k * lpoly
204  
205    uint8_t kl[crypto_core_ristretto255_SCALARBYTES];
206    const TOPRF_Share *k=(TOPRF_Share*) _k;
207    crypto_core_ristretto255_scalar_mul(kl, k->value, lpoly);
208  
209    TOPRF_Part *Z=(TOPRF_Part*) _Z;
210    if(oprf_Evaluate(kl,blinded, Z->value)) return 1;
211  
212    return 0;
213  }
214  
215  int toprf_thresholdcombine(const size_t response_len,
216                              const uint8_t _responses[response_len][TOPRF_Part_BYTES],
217                              uint8_t result[crypto_scalarmult_ristretto255_BYTES]) {
218    if(response_len>255) return 1;
219    const TOPRF_Part *responses=(TOPRF_Part*) _responses;
220    memset(result,0,crypto_scalarmult_ristretto255_BYTES);
221  
222    uint8_t indexed_indexes[response_len];
223    sort_parts((uint8_t) response_len, (TOPRF_Part*) responses, indexed_indexes);
224  
225    for(size_t i=0;i<response_len;i++) {
226      crypto_core_ristretto255_add(result,result,responses[indexed_indexes[i]].value);
227    }
228    return 0;
229  }
230  
231  
232  int toprf_3hashtdh(const uint8_t _k[TOPRF_Share_BYTES],
233                     const uint8_t _z[TOPRF_Share_BYTES],
234                     const uint8_t alpha[crypto_core_ristretto255_BYTES],
235                     const uint8_t *ssid_S, const uint16_t ssid_S_len,
236                     uint8_t _beta[TOPRF_Part_BYTES]) {
237    // essentially calculates the following pythonish value
238    //h2 = evaluate(
239    //    z[1:],
240    //    crypto_core_ristretto255_from_hash(crypto_generichash(ssid_S + alpha, outlen=64)),
241    //    )
242    //beta = evaluate(k[1:], alpha)
243    //return (k[0]+crypto_core_ristretto255_add(beta, h2))
244  
245    const TOPRF_Share *k=(TOPRF_Share*) _k;
246    TOPRF_Part *beta=(TOPRF_Part*) _beta;
247  
248    beta->index=k->index;
249    if(oprf_Evaluate(k->value, alpha, beta->value)) return 1;
250  
251    // hash (ssid_S + alpha, outlen=64)
252    crypto_generichash_state h_state;
253    crypto_generichash_init(&h_state, NULL, 0, crypto_core_ristretto255_HASHBYTES);
254    uint16_t len=htons((uint16_t) ssid_S_len); // we have a guard above restricting to 1KB the proto_name_len
255    crypto_generichash_update(&h_state, (uint8_t*) &len, 2);
256    crypto_generichash_update(&h_state, ssid_S, ssid_S_len);
257    crypto_generichash_update(&h_state, alpha, crypto_core_ristretto255_BYTES);
258    uint8_t hash[crypto_core_ristretto255_HASHBYTES];
259    crypto_generichash_final(&h_state,hash,sizeof hash);
260  
261    // hash-to-curve
262    uint8_t point[crypto_scalarmult_ristretto255_BYTES];
263    if(0!=voprf_hash_to_group(hash, sizeof hash, point)) return -1;
264  
265    TOPRF_Part h2;
266    const TOPRF_Share *z=(TOPRF_Share*) _z;
267    if(oprf_Evaluate(z->value, point, h2.value)) return 1;
268  
269    crypto_core_ristretto255_add(beta->value, beta->value, h2.value);
270  
271    return 0;
272  }