/ constantine / math_codegen / fields_nvidia.nim
fields_nvidia.nim
  1  # Constantine
  2  # Copyright (c) 2018-2019    Status Research & Development GmbH
  3  # Copyright (c) 2020-Present Mamy André-Ratsimbazafy
  4  # Licensed and distributed under either of
  5  #   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
  6  #   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
  7  # at your option. This file may not be copied, modified, or distributed except according to those terms.
  8  
  9  import
 10    ../platforms/code_generator/[llvm, nvidia, ir]
 11  
 12  # ############################################################
 13  #
 14  #               Field arithmetic on Nvidia GPU
 15  #
 16  # ############################################################
 17  
 18  # Loads from global (kernel params) take over 100 cycles
 19  # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-costs
 20  
 21  proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
 22    ## If a >= Modulus: r <- a-M
 23    ## else:            r <- a
 24    ##
 25    ## This is constant-time straightline code.
 26    ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
 27    ##
 28    ## To be used when the final substraction can
 29    ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
 30  
 31    let bld = asy.builder
 32    let fieldTy = cm.getFieldType(field)
 33    let scratch = bld.makeArray(fieldTy)
 34    let M = cm.getModulus(field)
 35    let N = M.len
 36  
 37    # Contains 0x0001 (if overflowed limbs) or 0x0000
 38    let overflowedLimbs = bld.add_ci(0'u32, 0'u32)
 39  
 40    # Now substract the modulus, and test a < M with the last borrow
 41    scratch[0] = bld.sub_bo(a[0], M[0])
 42    for i in 1 ..< N:
 43      scratch[i] = bld.sub_bio(a[i], M[i])
 44  
 45    # 1. if `overflowedLimbs`, underflowedModulus >= 0
 46    # 2. if a >= M, underflowedModulus >= 0
 47    # if underflowedModulus >= 0: a-M else: a
 48    let underflowedModulus = bld.sub_bi(overflowedLimbs, 0'u32)
 49  
 50    for i in 0 ..< N:
 51      r[i] = bld.slct(scratch[i], a[i], underflowedModulus)
 52  
 53  proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
 54    ## If a >= Modulus: r <- a-M
 55    ## else:            r <- a
 56    ##
 57    ## This is constant-time straightline code.
 58    ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
 59    ##
 60    ## To be used when the modulus does not use the full bitwidth of the storing words
 61    ## (say using 255 bits for the modulus out of 256 available in words)
 62  
 63    let bld = asy.builder
 64    let fieldTy = cm.getFieldType(field)
 65    let scratch = bld.makeArray(fieldTy)
 66    let M = cm.getModulus(field)
 67    let N = M.len
 68  
 69    # Now substract the modulus, and test a < M with the last borrow
 70    scratch[0] = bld.sub_bo(a[0], M[0])
 71    for i in 1 ..< N:
 72      scratch[i] = bld.sub_bio(a[i], M[i])
 73  
 74    # If it underflows here a was smaller than the modulus, which is what we want
 75    let underflowedModulus = bld.sub_bi(0'u32, 0'u32)
 76  
 77    for i in 0 ..< N:
 78      r[i] = bld.slct(scratch[i], a[i], underflowedModulus)
 79  
 80  proc field_add_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef =
 81    ## Generate an optimized modular addition kernel
 82    ## with parameters `a, b, modulus: Limbs -> Limbs`
 83  
 84    let procName = cm.genSymbol(block:
 85      case field
 86      of fp: opFpAdd
 87      of fr: opFrAdd)
 88    let fieldTy = cm.getFieldType(field)
 89    let pFieldTy = pointer_t(fieldTy)
 90  
 91    let addModTy = function_t(asy.void_t, [pFieldTy, pFieldTy, pFieldTy])
 92    let addModKernel = asy.module.addFunction(cstring procName, addModTy)
 93    let blck = asy.ctx.appendBasicBlock(addModKernel, "addModBody")
 94    asy.builder.positionAtEnd(blck)
 95  
 96    let bld = asy.builder
 97  
 98    let r = bld.asArray(addModKernel.getParam(0), fieldTy)
 99    let a = bld.asArray(addModKernel.getParam(1), fieldTy)
100    let b = bld.asArray(addModKernel.getParam(2), fieldTy)
101  
102    let t = bld.makeArray(fieldTy)
103    let N = cm.getNumWords(field)
104  
105    t[0] = bld.add_co(a[0], b[0])
106    for i in 1 ..< N:
107      t[i] = bld.add_cio(a[i], b[i])
108  
109    if cm.getSpareBits(field) >= 1:
110      asy.finalSubNoOverflow(cm, field, t, t)
111    else:
112      asy.finalSubMayOverflow(cm, field, t, t)
113  
114    bld.store(r, t)
115    bld.retVoid()
116  
117    return (addModTy, addModKernel)