# polynomials_parallel.nim
# Constantine
# Copyright (c) 2018-2019    Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import ./polynomials {.all.}
export polynomials

import
  ../config/curves,
  ../arithmetic,
  ../../platforms/bithacks,
  ../../threadpool/threadpool

## ############################################################
##
##                    Polynomials
##                 Parallel Edition
##
## ############################################################

proc evalPolyAt_parallel*[N: static int, Field](
       tp: Threadpool,
       r: var Field,
       poly: ptr PolynomialEval[N, Field],
       z: ptr Field,
       invRootsMinusZ: ptr array[N, Field],
       domain: ptr PolyDomainEval[N, Field]) =
  ## Evaluate a polynomial in evaluation form
  ## at the point z
  ## z MUST NOT be one of the roots of unity
  ##
  ## Input:
  ##   - tp:             threadpool on which the summation is scheduled
  ##   - poly:           p(x) in evaluation form, `poly.evals[i]` = p(ωⁱ)
  ##   - z:              the evaluation point
  ##   - invRootsMinusZ: 1/(ωⁱ-z)
  ##   - domain:         provides `rootsOfUnity` (ωⁱ) and `invMaxDegree`
  ##                     (used as the 1/n factor of the formula below)
  ##
  ## Output:
  ##   - r: p(z)
  ##
  ## N MUST be a power of 2 (checked at compile-time).
  ##
  ## Parallelism: This only returns when computation is fully done

  # p(z) = (1-zⁿ)/n ∑ ωⁱ/(ωⁱ-z) . p(ωⁱ)

  # `globalSum` is introduced by the `reduceInto` clause below;
  # `mixin` keeps the symbol open in this generic proc.
  mixin globalSum
  static: doAssert N.isPowerOf2_vartime()

  # ∑ ωⁱ/(ωⁱ-z) . p(ωⁱ)
  # Each worker accumulates a partial sum, partial sums are then merged.
  tp.parallelFor i in 0 ..< N:
    captures: {poly, domain, invRootsMinusZ}
    reduceInto(globalSum: Field):
      prologue:
        var workerSum {.noInit.}: Field
        workerSum.setZero()
      forLoop:
        var iterSummand {.noInit.}: Field
        iterSummand.prod(domain.rootsOfUnity[i], invRootsMinusZ[i])
        iterSummand *= poly.evals[i]
        workerSum += iterSummand
      merge(remoteSum: Flowvar[Field]):
        workerSum += sync(remoteSum)
      epilogue:
        return workerSum

  # 1-zⁿ, computed on the calling thread before waiting on the sum.
  var t {.noInit.}: Field
  t = z[]
  const numDoublings = log2_vartime(uint32 N) # N is a power of 2
  t.square_repeated(int numDoublings)         # exponentiation by a power of 2
  t.diff(Field(mres: Field.getMontyOne()), t) # TODO: refactor getMontyOne to getOne and return a field element.

  r.prod(t, domain.invMaxDegree)
  r *= sync(globalSum)

proc differenceQuotientEvalOffDomain_parallel*[N: static int, Field](
       tp: Threadpool,
       r: ptr PolynomialEval[N, Field],
       poly: ptr PolynomialEval[N, Field],
       pZ: ptr Field,
       invRootsMinusZ: ptr array[N, Field]) =
  ## Compute r(x) = (p(x) - p(z)) / (x - z)
  ##
  ## for z != ωⁱ a power of a root of unity
  ##
  ## Input:
  ##   - tp:             threadpool on which the loop is scheduled
  ##   - poly:           p(x) a polynomial in evaluation form as an array of p(ωⁱ)
  ##   - pZ:             p(z)
  ##   - invRootsMinusZ: 1/(ωⁱ-z)
  ##
  ## Output:
  ##   - r: the quotient in evaluation form,
  ##        `r.evals[i]` = (p(ωⁱ) - p(z))/(ωⁱ-z)
  ##
  ## Parallelism: This only returns when computation is fully done
  # TODO: we might want either awaitable for-loops
  #       or awaitable individual iterations
  #       for latency-hiding techniques

  syncScope:
    tp.parallelFor i in 0 ..< N:
      captures: {r, poly, pZ, invRootsMinusZ}
      # qᵢ = (p(ωⁱ) - p(z))/(ωⁱ-z)
      var qi {.noinit.}: Field
      qi.diff(poly.evals[i], pZ[])
      r.evals[i].prod(qi, invRootsMinusZ[i])

proc differenceQuotientEvalInDomain_parallel*[N: static int, Field](
       tp: Threadpool,
       r: ptr PolynomialEval[N, Field],
       poly: ptr PolynomialEval[N, Field],
       zIndex: uint32,
       invRootsMinusZ: ptr array[N, Field],
       domain: ptr PolyDomainEval[N, Field],
       isBitReversedDomain: static bool) =
  ## Compute r(x) = (p(x) - p(z)) / (x - z)
  ##
  ## for z = ωⁱ a power of a root of unity
  ##
  ## Input:
  ##   - tp:             threadpool on which the loop is scheduled
  ##   - poly:           p(x) a polynomial in evaluation form as an array of p(ωⁱ)
  ##   - domain:         provides rootsOfUnity: ωⁱ
  ##   - invRootsMinusZ: 1/(ωⁱ-z)
  ##   - zIndex:         the index of the root of unity power that matches z = ωⁱᵈˣ
  ##   - isBitReversedDomain: whether indices into `domain.rootsOfUnity`
  ##                     are bit-reversed (selected at compile-time)
  ##
  ## Output:
  ##   - r: for i != zIndex, `r.evals[i]` = (p(ωⁱ) - p(z))/(ωⁱ-z)
  ##        `r.evals[zIndex]` = ∑ -qᵢ * ωⁱ/z over all i != zIndex
  ##
  ## N MUST be a power of 2 (checked at compile-time).
  ##
  ## Parallelism: This only returns when computation is fully done

  static:
    # For powers of 2: x mod N == x and (N-1)
    doAssert N.isPowerOf2_vartime()

  # `evalsZindex` is introduced by the `reduceInto` clause below;
  # `mixin` keeps the symbol open in this generic proc.
  mixin evalsZindex

  tp.parallelFor i in 0 ..< N:
    captures: {r, poly, domain, invRootsMinusZ, zIndex}
    reduceInto(evalsZindex: Field):
      prologue:
        var worker_ri {.noInit.}: Field
        worker_ri.setZero()
      forLoop:
        var iter_ri {.noInit.}: Field
        if i == int(zIndex):
          # The slot at zIndex is filled after the reduction,
          # it contributes nothing to the sum.
          iter_ri.setZero()
        else:
          # qᵢ = (p(ωⁱ) - p(z))/(ωⁱ-z)
          var qi {.noinit.}: Field
          qi.diff(poly.evals[i], poly.evals[zIndex])
          r.evals[i].prod(qi, invRootsMinusZ[i])

          # q'ᵢ = -qᵢ * ωⁱ/z
          # q'idx = ∑ q'ᵢ
          iter_ri.neg(r.evals[i])                                  # -qᵢ
          when isBitReversedDomain:
            # Un-reverse both indices, offset by N - idx modulo N,
            # then bit-reverse again to index the roots of unity.
            const logN = log2_vartime(uint32 N)
            let invZidx = N - reverseBits(uint32 zIndex, logN)
            let canonI = reverseBits(uint32 i, logN)
            let idx = reverseBits((canonI + invZidx) and (N-1), logN)
            iter_ri *= domain.rootsOfUnity[idx]                    # -qᵢ * ωⁱ/z  (explanation at the bottom of serial impl)
          else:
            # `i` is an int while `zIndex` is an uint32:
            # Nim has no implicit uint32 -> int conversion,
            # so the index arithmetic must be done on a single type.
            iter_ri *= domain.rootsOfUnity[(i + N - int(zIndex)) and (N-1)] # -qᵢ * ωⁱ/z  (explanation at the bottom of serial impl)
        worker_ri += iter_ri
      merge(remote_ri: Flowvar[Field]):
        worker_ri += sync(remote_ri)
      epilogue:
        return worker_ri

  r.evals[zIndex] = sync(evalsZindex)