/ constantine / math / polynomials / polynomials_parallel.nim
polynomials_parallel.nim
  1  # Constantine
  2  # Copyright (c) 2018-2019    Status Research & Development GmbH
  3  # Copyright (c) 2020-Present Mamy André-Ratsimbazafy
  4  # Licensed and distributed under either of
  5  #   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
  6  #   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
  7  # at your option. This file may not be copied, modified, or distributed except according to those terms.
  8  
  9  import ./polynomials {.all.}
 10  export polynomials
 11  
 12  import
 13    ../config/curves,
 14    ../arithmetic,
 15    ../../platforms/bithacks,
 16    ../../threadpool/threadpool
 17  
 18  ## ############################################################
 19  ##
 20  ##                    Polynomials
 21  ##                 Parallel Edition
 22  ##
 23  ## ############################################################
 24  
 25  proc evalPolyAt_parallel*[N: static int, Field](
 26         tp: Threadpool,
 27         r: var Field,
 28         poly: ptr PolynomialEval[N, Field],
 29         z: ptr Field,
 30         invRootsMinusZ: ptr array[N, Field],
 31         domain: ptr PolyDomainEval[N, Field]) =
 32    ## Evaluate a polynomial in evaluation form
 33    ## at the point z
 34    ## z MUST NOT be one of the roots of unity
 35    ##
 36    ## Parallelism: This only returns when computation is fully done
 37  
 38    # p(z) = (1-zⁿ)/n ∑ ωⁱ/(ωⁱ-z) . p(ωⁱ)
 39  
 40    mixin globalSum
 41    static: doAssert N.isPowerOf2_vartime()
 42  
 43    tp.parallelFor i in 0 ..< N:
 44      captures: {poly, domain, invRootsMinusZ}
 45      reduceInto(globalSum: Field):
 46        prologue:
 47          var workerSum {.noInit.}: Field
 48          workerSum.setZero()
 49        forLoop:
 50          var iterSummand {.noInit.}: Field
 51          iterSummand.prod(domain.rootsOfUnity[i], invRootsMinusZ[i])
 52          iterSummand *= poly.evals[i]
 53          workerSum += iterSummand
 54        merge(remoteSum: Flowvar[Field]):
 55          workerSum += sync(remoteSum)
 56        epilogue:
 57          return workerSum
 58  
 59    var t {.noInit.}: Field
 60    t = z[]
 61    const numDoublings = log2_vartime(uint32 N) # N is a power of 2
 62    t.square_repeated(int numDoublings)         # exponentiation by a power of 2
 63    t.diff(Field(mres: Field.getMontyOne()), t) # TODO: refactor getMontyOne to getOne and return a field element.
 64  
 65    r.prod(t, domain.invMaxDegree)
 66    r *= sync(globalSum)
 67  
 68  proc differenceQuotientEvalOffDomain_parallel*[N: static int, Field](
 69         tp: Threadpool,
 70         r: ptr PolynomialEval[N, Field],
 71         poly: ptr PolynomialEval[N, Field],
 72         pZ: ptr Field,
 73         invRootsMinusZ: ptr array[N, Field]) =
 74    ## Compute r(x) = (p(x) - p(z)) / (x - z)
 75    ##
 76    ## for z != ωⁱ a power of a root of unity
 77    ##
 78    ## Input:
 79    ##   - invRootsMinusZ:  1/(ωⁱ-z)
 80    ##   - poly:            p(x) a polynomial in evaluation form as an array of p(ωⁱ)
 81    ##   - rootsOfUnity:    ωⁱ
 82    ##   - p(z)
 83    ##
 84    ## Parallelism: This only returns when computation is fully done
 85    # TODO: we might want either awaitable for-loops
 86    #       or awaitable individual iterations
 87    #       for latency-hiding techniques
 88  
 89    syncScope:
 90      tp.parallelFor i in 0 ..< N:
 91        captures: {r, poly, pZ, invRootsMinusZ}
 92        # qᵢ = (p(ωⁱ) - p(z))/(ωⁱ-z)
 93        var qi {.noinit.}: Field
 94        qi.diff(poly.evals[i], pZ[])
 95        r.evals[i].prod(qi, invRootsMinusZ[i])
 96  
 97  proc differenceQuotientEvalInDomain_parallel*[N: static int, Field](
 98         tp: Threadpool,
 99         r: ptr PolynomialEval[N, Field],
100         poly: ptr PolynomialEval[N, Field],
101         zIndex: uint32,
102         invRootsMinusZ: ptr array[N, Field],
103         domain: ptr PolyDomainEval[N, Field],
104         isBitReversedDomain: static bool) =
105    ## Compute r(x) = (p(x) - p(z)) / (x - z)
106    ##
107    ## for z = ωⁱ a power of a root of unity
108    ##
109    ## Input:
110    ##   - poly:            p(x) a polynomial in evaluation form as an array of p(ωⁱ)
111    ##   - rootsOfUnity:    ωⁱ
112    ##   - invRootsMinusZ:  1/(ωⁱ-z)
113    ##   - zIndex:          the index of the root of unity power that matches z = ωⁱᵈˣ
114    ##
115    ## Parallelism: This only returns when computation is fully done
116  
117    static:
118      # For powers of 2: x mod N == x and (N-1)
119      doAssert N.isPowerOf2_vartime()
120  
121    mixin evalsZindex
122  
123    tp.parallelFor i in 0 ..< N:
124      captures: {r, poly, domain, invRootsMinusZ, zIndex}
125      reduceInto(evalsZindex: Field):
126        prologue:
127          var worker_ri {.noInit.}: Field
128          worker_ri.setZero()
129        forLoop:
130          var iter_ri {.noInit.}: Field
131          if i == int(zIndex):
132            iter_ri.setZero()
133          else:
134            # qᵢ = (p(ωⁱ) - p(z))/(ωⁱ-z)
135            var qi {.noinit.}: Field
136            qi.diff(poly.evals[i], poly.evals[zIndex])
137            r.evals[i].prod(qi, invRootsMinusZ[i])
138  
139            # q'ᵢ = -qᵢ * ωⁱ/z
140            # q'idx = ∑ q'ᵢ
141            iter_ri.neg(r.evals[i])                                  # -qᵢ
142            when isBitReversedDomain:
143              const logN = log2_vartime(uint32 N)
144              let invZidx = N - reverseBits(uint32 zIndex, logN)
145              let canonI = reverseBits(uint32 i, logN)
146              let idx = reverseBits((canonI + invZidx) and (N-1), logN)
147              iter_ri *= domain.rootsOfUnity[idx]                    # -qᵢ * ωⁱ/z  (explanation at the bottom of serial impl)
148            else:
149              iter_ri *= domain.rootsOfUnity[(i+N-zIndex) and (N-1)] # -qᵢ * ωⁱ/z  (explanation at the bottom of serial impl)
150            worker_ri += iter_ri
151        merge(remote_ri: Flowvar[Field]):
152          worker_ri += sync(remote_ri)
153        epilogue:
154          return worker_ri
155  
156    r.evals[zIndex] = sync(evalsZindex)