fields_nvidia.nim
# Constantine
# Copyright (c) 2018-2019    Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
  ../platforms/code_generator/[llvm, nvidia, ir]

# ############################################################
#
#               Field arithmetic on Nvidia GPU
#
# ############################################################

# Loads from global (kernel params) take over 100 cycles
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-costs

proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
  ## Emit the final conditional subtraction of a modular operation:
  ##
  ##   If a >= Modulus: r <- a-M
  ##   else:            r <- a
  ##
  ## This is constant-time straightline code.
  ## Due to warp divergence, the overhead of doing comparison with
  ## shortcutting might not be worth it on GPU.
  ##
  ## To be used when the final subtraction can also overflow the limbs
  ## (a 2^256 order of magnitude modulus stored in n words
  ## of total max size 2^256).
  ##
  ## Parameters:
  ## - `asy`:   assembler whose builder receives the emitted instructions
  ## - `cm`:    curve metadata providing the field type and the modulus
  ## - `field`: which field of the curve the modulus is taken from
  ## - `r`:     destination limbs
  ## - `a`:     input limbs; `r` and `a` may be the same array
  ##            (`field_add_gen` calls this with the same array for both)
  ##
  ## NOTE(review): the first emitted operation is `add_ci(0, 0)`, which
  ## presumably reads the carry left by the caller's preceding add-with-carry
  ## chain. If so, nothing that clobbers the carry may be emitted between
  ## that chain and this proc - confirm against the `add_ci` definition.

  let bld = asy.builder
  let fieldTy = cm.getFieldType(field)
  let scratch = bld.makeArray(fieldTy)
  let M = cm.getModulus(field)
  let N = M.len

  # Contains 0x0001 (if overflowed limbs) or 0x0000
  let overflowedLimbs = bld.add_ci(0'u32, 0'u32)

  # Now subtract the modulus, and test a < M with the last borrow
  scratch[0] = bld.sub_bo(a[0], M[0])
  for i in 1 ..< N:
    scratch[i] = bld.sub_bio(a[i], M[i])

  # 1. if `overflowedLimbs`, underflowedModulus >= 0
  # 2. if a >= M, underflowedModulus >= 0
  # if underflowedModulus >= 0: a-M else: a
  let underflowedModulus = bld.sub_bi(overflowedLimbs, 0'u32)

  # Limb-wise select between the subtracted value and the untouched input.
  # Each r[i] only depends on scratch[i] and a[i], so r may alias a.
  for i in 0 ..< N:
    r[i] = bld.slct(scratch[i], a[i], underflowedModulus)

proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
  ## Emit the final conditional subtraction of a modular operation:
  ##
  ##   If a >= Modulus: r <- a-M
  ##   else:            r <- a
  ##
  ## This is constant-time straightline code.
  ## Due to warp divergence, the overhead of doing comparison with
  ## shortcutting might not be worth it on GPU.
  ##
  ## To be used when the modulus does not use the full bitwidth
  ## of the storing words (say using 255 bits for the modulus
  ## out of 256 available in words).
  ##
  ## Parameters:
  ## - `asy`:   assembler whose builder receives the emitted instructions
  ## - `cm`:    curve metadata providing the field type and the modulus
  ## - `field`: which field of the curve the modulus is taken from
  ## - `r`:     destination limbs
  ## - `a`:     input limbs; `r` and `a` may be the same array
  ##            (`field_add_gen` calls this with the same array for both)

  let bld = asy.builder
  let fieldTy = cm.getFieldType(field)
  let scratch = bld.makeArray(fieldTy)
  let M = cm.getModulus(field)
  let N = M.len

  # Now subtract the modulus, and test a < M with the last borrow
  scratch[0] = bld.sub_bo(a[0], M[0])
  for i in 1 ..< N:
    scratch[i] = bld.sub_bio(a[i], M[i])

  # If it underflows here a was smaller than the modulus, which is what we want
  let underflowedModulus = bld.sub_bi(0'u32, 0'u32)

  # Limb-wise select: keep a-M unless the subtraction underflowed.
  # Each r[i] only depends on scratch[i] and a[i], so r may alias a.
  for i in 0 ..< N:
    r[i] = bld.slct(scratch[i], a[i], underflowedModulus)

proc field_add_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef =
  ## Generate an optimized modular addition kernel.
  ##
  ## The generated kernel returns void and takes three pointers
  ## to field elements, in order: `r` (result), `a` and `b` (operands).
  ## The modulus is not a kernel parameter; it is obtained
  ## from `cm` while generating the code.
  ##
  ## Returns the pair (function type, function) of the generated kernel.

  # Kernel symbol depends on which field of the curve is targeted
  let procName = cm.genSymbol(block:
    case field
    of fp: opFpAdd
    of fr: opFrAdd)
  let fieldTy = cm.getFieldType(field)
  let pFieldTy = pointer_t(fieldTy)

  let addModTy = function_t(asy.void_t, [pFieldTy, pFieldTy, pFieldTy])
  let addModKernel = asy.module.addFunction(cstring procName, addModTy)
  let blck = asy.ctx.appendBasicBlock(addModKernel, "addModBody")
  asy.builder.positionAtEnd(blck)

  let bld = asy.builder

  let r = bld.asArray(addModKernel.getParam(0), fieldTy)
  let a = bld.asArray(addModKernel.getParam(1), fieldTy)
  let b = bld.asArray(addModKernel.getParam(2), fieldTy)

  let t = bld.makeArray(fieldTy)
  let N = cm.getNumWords(field)

  # t <- a + b, limb by limb with carry propagation
  t[0] = bld.add_co(a[0], b[0])
  for i in 1 ..< N:
    t[i] = bld.add_cio(a[i], b[i])

  # Reduce t in place. With at least one spare bit the sum
  # cannot overflow the limbs, so the cheaper variant is enough.
  # Keep this directly after the addition chain: the may-overflow
  # variant starts with an `add_ci` on the chain's outgoing carry.
  if cm.getSpareBits(field) >= 1:
    asy.finalSubNoOverflow(cm, field, t, t)
  else:
    asy.finalSubMayOverflow(cm, field, t, t)

  bld.store(r, t)
  bld.retVoid()

  return (addModTy, addModKernel)