codec.nim
  1  # Copyright (c) 2022 Status Research & Development GmbH
  2  # Licensed under either of
  3  #  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
  4  #  * MIT license ([LICENSE-MIT](LICENSE-MIT))
  5  # at your option.
  6  # This file may not be copied, modified, or distributed except according to
  7  # those terms.
  8  
  9  ## This module implements core primitives for the protobuf language as seen in
 10  ## `.proto` files
 11  
 12  {.push raises: [], gcsafe.}
 13  
 14  import
 15    std/[typetraits, unicode],
 16    faststreams,
 17    stew/[leb128, endians2],
 18    ./types
 19  
 20  export types
 21  
 22  type
 23    WireKind* = enum
 24      Varint = 0
 25      Fixed64 = 1
 26      LengthDelim = 2
 27      # StartGroup = 3 # Not used
 28      # EndGroup = 4 # Not used
 29      Fixed32 = 5
 30  
 31    SomePBInt* = int32 | int64 | uint32 | uint64
 32  
 33    FieldHeader* = distinct uint32
 34  
 35    # Scalar types used in `.proto` files
 36    # https://developers.google.com/protocol-buffers/docs/proto3#scalar
 37    pdouble* = distinct float64
 38    pfloat* = distinct float32
 39  
 40    pint32* = distinct int32 ## varint-encoded signed integer
 41    pint64* = distinct int64 ## varint-encoded signed integer
 42  
 43    puint32* = distinct uint32 ## varint-encoded unsigned integer
 44    puint64* = distinct uint64 ## varint-encoded unsigned integer
 45  
 46    sint32* = distinct int32 ## zig-zag-varint-encoded signed integer
 47    sint64* = distinct int64 ## zig-zag-varint-encoded signed integer
 48  
 49    fixed32* = distinct uint32 ## fixed-width unsigned integer
 50    fixed64* = distinct uint64 ## fixed-width unsigned integer
 51  
 52    sfixed32* = distinct int32 ## fixed-width signed integer
 53    sfixed64* = distinct int64 ## fixed-width signed integer
 54  
 55    pbool* = distinct bool
 56    penum* = distinct int32
 57  
 58    pstring* = distinct string ## UTF-8-encoded string
 59    pbytes* = distinct seq[byte] ## byte sequence
 60  
 61    SomeScalar* =
 62      pint32 | pint64 | puint32 | puint64 | sint32 | sint64 | pbool | penum |
 63      fixed64 | sfixed64 | pdouble |
 64      pstring | pbytes |
 65      fixed32 | sfixed32 | pfloat
 66  
 67    # Mappings of proto type to wire type
 68    SomeVarint* =
 69      pint32 | pint64 | puint32 | puint64 | sint32 | sint64 | pbool | penum
 70    SomeFixed64* = fixed64 | sfixed64 | pdouble
 71    SomeLengthDelim* = pstring | pbytes # Also messages and packed repeated fields
 72    SomeFixed32* = fixed32 | sfixed32 | pfloat
 73  
 74    SomePrimitive* = SomeVarint | SomeFixed64 | SomeFixed32
 75      ## Types that may appear packed
 76  
 77  const
 78    SupportedWireKinds* = {
 79      uint8(WireKind.Varint),
 80      uint8(WireKind.Fixed64),
 81      uint8(WireKind.LengthDelim),
 82      uint8(WireKind.Fixed32)
 83    }
 84  
 85  template wireKind*(T: type SomeVarint): WireKind = WireKind.Varint
 86  template wireKind*(T: type SomeFixed64): WireKind = WireKind.Fixed64
 87  template wireKind*(T: type SomeLengthDelim): WireKind = WireKind.LengthDelim
 88  template wireKind*(T: type SomeFixed32): WireKind = WireKind.Fixed32
 89  
 90  template validFieldNumber*(i: int, strict: bool = false): bool =
 91    # https://developers.google.com/protocol-buffers/docs/proto#assigning
 92    # Field numbers in the 19k range are reserved for the protobuf implementation
 93    (i > 0 and i < (1 shl 29)) and (not strict or not(i >= 19000 and i <= 19999))
 94  
 95  template init*(_: type FieldHeader, index: int, wire: WireKind): FieldHeader =
 96    ## Get protobuf's field header integer for ``index`` and ``wire``.
 97    FieldHeader((uint32(index) shl 3) or uint32(wire))
 98  
 99  template number*(p: FieldHeader): int =
100    int(uint32(p) shr 3)
101  
template kind*(p: FieldHeader): WireKind =
  ## Wire type stored in the three lowest bits of header ``p``.
  ## The bits are reinterpreted with `cast`, so no check is made here that
  ## they name a declared `WireKind` member.
  cast[WireKind](uint8(p) and 0x07'u8)
104  
# Conversions from the varint scalar types to the unsigned integer that is
# handed to the LEB128 encoder.

template toUleb(x: puint64): uint64 =
  uint64(x)

template toUleb(x: puint32): uint32 =
  uint32(x)

func toUleb(x: sint64): uint64 =
  # Zig-zag: shift the value left by one and flip all bits when the input
  # is negative, which moves the sign into bit 0.
  let
    bits = cast[uint64](x)
    signMask = 0'u64 - (bits shr 63) # all ones for negative input, else zero
  (bits shl 1) xor signMask

func toUleb(x: sint32): uint32 =
  # Same zig-zag transform on 32 bits.
  let
    bits = cast[uint32](x)
    signMask = 0'u32 - (bits shr 31) # all ones for negative input, else zero
  (bits shl 1) xor signMask

# NOTE(review): `pint32` and `penum` are converted from their 32-bit pattern,
# so a negative value takes at most five varint bytes. The protobuf encoding
# guide describes negative int32 values as sign-extended to ten bytes -
# confirm whether the shorter form is intentional.
template toUleb(x: pint64): uint64 =
  cast[uint64](x)

template toUleb(x: pint32): uint32 =
  cast[uint32](x)

template toUleb(x: pbool): uint8 =
  cast[uint8](x)

template toUleb(x: penum): uint64 =
  cast[uint32](x)
120  
# Conversions from a decoded varint payload back to the protobuf scalar `T`.
# (The `puint64` overload used to be declared twice with identical bodies;
# only one definition is kept.)

template fromUleb(x: uint64, T: type pbool): T =
  # Any non-zero payload is read as true.
  pbool(x != 0)

template fromUleb(x: uint64, T: type puint64): T = puint64(x)
template fromUleb(x: uint64, T: type puint32): T = puint32(x)

template fromUleb(x: uint64, T: type sint64): T =
  # Undo zig-zag: bit 0 selects whether the remaining bits are inverted.
  cast[T]((x shr 1) xor (0 - (x and 1)))
template fromUleb(x: uint64, T: type sint32): T =
  # Same as above, computed on the low 32 bits of the payload.
  cast[T]((uint32(x) shr 1) xor (0 - (uint32(x) and 1)))

template fromUleb(x: uint64, T: type pint64): T = cast[T](x)
template fromUleb(x: uint64, T: type pint32): T = cast[T](x)
template fromUleb(x: uint64, T: type penum): T = cast[T](x)
135  
template toBytes*(x: SomeVarint): openArray[byte] =
  ## LEB128 (varint) encoding of ``x``.
  toBytes(toUleb(x), Leb128).toOpenArray()

template toBytes*(x: fixed32 | fixed64): openArray[byte] =
  ## Little-endian bytes of the unsigned integer underlying ``x``.
  type Base = distinctBase(typeof(x))
  toBytesLE(Base(x))

template toBytes*(x: sfixed32): openArray[byte] =
  ## Encoded through the ``fixed32`` overload.
  toBytes(fixed32(x))
template toBytes*(x: sfixed64): openArray[byte] =
  ## Encoded through the ``fixed64`` overload.
  toBytes(fixed64(x))

template toBytes*(x: pdouble): openArray[byte] =
  ## Little-endian bytes of the 64-bit pattern of ``x``.
  # Reinterpret the float as an integer and serialize it with `toBytesLE`,
  # like the fixed-width integer overloads do. Casting the value straight to
  # a byte array would emit the bytes in host order instead.
  toBytesLE(cast[uint64](x))
template toBytes*(x: pfloat): openArray[byte] =
  ## Little-endian bytes of the 32-bit pattern of ``x``.
  toBytesLE(cast[uint32](x))

template toBytes*(header: FieldHeader): openArray[byte] =
  ## LEB128 (varint) encoding of the header integer.
  toBytes(uint32(header), Leb128).toOpenArray()
155  
func computeSize*(x: SomeVarint): int =
  ## Returns number of bytes required to encode integer ``x`` as varint.
  Leb128.len(toUleb(x))

func computeSize*(x: SomeFixed64 | SomeFixed32): int =
  ## Returns the number of bytes taken by the fixed-width encoding of ``x``,
  ## which is the in-memory size of the value (no varint is involved).
  sizeof(x)

func computeSize*(x: pstring | pbytes): int =
  ## Returns the number of bytes required to encode ``x``: the varint holding
  ## its length plus the payload itself.
  let len = distinctBase(x).len()
  computeSize(puint64(len)) + len

func computeSize*(x: FieldHeader): int =
  ## Returns number of bytes required to encode header ``x`` as varint.
  computeSize(puint32(x))

func computeSize*(field: int, x: SomeScalar): int =
  ## Returns the number of bytes required to encode ``x`` as field number
  ## ``field``: the header followed by the value.
  computeSize(FieldHeader.init(field, wireKind(typeof(x)))) +
    computeSize(x)
175  
proc writeValue*(output: OutputStream, value: SomeVarint) {.raises: [IOError].} =
  ## Writes the varint encoding of ``value``.
  output.write(value.toBytes())

proc writeValue*(output: OutputStream, value: SomeFixed64) {.raises: [IOError].} =
  ## Writes the eight-byte encoding of ``value``.
  output.write(value.toBytes())

proc writeValue*(output: OutputStream, value: pstring) {.raises: [IOError].} =
  ## Writes the byte length of ``value`` as a varint, then its bytes.
  template text(): string = string(value)
  output.write(toBytes(puint64(text().len())))
  output.write(text().toOpenArrayByte(0, text().high()))

proc writeValue*(output: OutputStream, value: pbytes) {.raises: [IOError].} =
  ## Writes the length of ``value`` as a varint, then its bytes.
  template payload(): seq[byte] = seq[byte](value)
  output.write(toBytes(puint64(payload().len())))
  output.write(payload())

proc writeValue*(output: OutputStream, value: SomeFixed32) {.raises: [IOError].} =
  ## Writes the four-byte encoding of ``value``.
  output.write(value.toBytes())

proc writeValue*(output: OutputStream, value: FieldHeader) {.raises: [IOError].} =
  ## Writes the varint encoding of the header.
  output.write(value.toBytes())
195  
proc writeField*(output: OutputStream, field: int, value: SomeScalar) {.raises: [IOError].} =
  ## Writes a complete field: the header built from ``field`` and the wire
  ## type of ``value``, followed by the encoded ``value``.
  let header = FieldHeader.init(field, wireKind(typeof(value)))
  output.writeValue(header)
  output.writeValue(value)
199  
proc readValue*[T: SomeVarint](input: InputStream, _: type T): T {.raises: [SerializationError, IOError].} =
  ## Reads one varint from ``input`` and converts it to ``T``.
  ## Raises ``ProtobufValueError`` when no bytes are available or when the
  ## collected bytes do not decode as exactly one varint.
  # TODO This is not entirely correct: we should truncate value if it doesn't
  #      fit, according to the docs:
  #      https://developers.google.com/protocol-buffers/docs/proto#updating
  var raw: Leb128Buf[uint64]

  # Collect bytes up to and including the first one without the continuation
  # bit, never going past the capacity of the buffer.
  while raw.len < raw.data.len:
    if not input.readable():
      break
    let octet = input.read()
    raw.data[raw.len] = octet
    raw.len += 1
    if octet < 0x80'u8:
      break

  let (decoded, consumed) = uint64.fromBytes(raw)
  if raw.len == 0 or consumed != raw.len:
    raise (ref ProtobufValueError)(msg: "Cannot read varint from stream")

  fromUleb(decoded, T)
217  
proc readValue*[T: SomeFixed32 | SomeFixed64](input: InputStream, _: type T): T {.raises: [SerializationError, IOError].} =
  ## Reads ``sizeof(T)`` bytes from ``input`` and decodes them as a
  ## little-endian value of type ``T``.
  ## Raises ``ProtobufValueError`` when the stream cannot supply that many
  ## bytes.
  var tmp {.noinit.}: array[sizeof(T), byte]
  if not input.readInto(tmp):
    raise (ref ProtobufValueError)(msg: "Not enough bytes")
  # Every member of `T`, floats included, is decoded through the unsigned
  # integer of the same size, so the result does not depend on host byte
  # order. `cast` reinterprets the bits, which also keeps signed targets clear
  # of range checks.
  when sizeof(T) == 8:
    cast[T](uint64.fromBytesLE(tmp))
  else:
    cast[T](uint32.fromBytesLE(tmp))
228  
proc readLength*(input: InputStream): int {.raises: [SerializationError, IOError].} =
  ## Reads the varint length prefix of a length-delimited value.
  ## Raises ``ProtobufValueError`` ("Invalid length") when the prefix does not
  ## fit in 32 bits or in an ``int``.
  # The largest accepted length is unchanged: 32 bits, or `int.high` where
  # `int` is narrower than that.
  const maxLen = min(uint64(uint32.high()), uint64(int.high()))
  # Read the full 64-bit payload and validate it before narrowing. Reading
  # it as `puint32` would narrow the value first, so the range check would
  # never see a prefix of 2^32 or more.
  let len = uint64(input.readValue(puint64))
  if len > maxLen:
    raise (ref ProtobufValueError)(msg: "Invalid length")
  int(len)
234  
proc readValue*[T: SomeLengthDelim](input: InputStream, _: type T): T {.raises: [SerializationError, IOError].} =
  ## Reads a length prefix followed by that many bytes and returns them as
  ## ``T``. A zero length yields the default (empty) value.
  ## Raises ``ProtobufValueError`` when bytes are missing or, for ``pstring``,
  ## when the payload is not valid UTF-8.
  let size = input.readLength()
  if size <= 0:
    return

  type Base = typetraits.distinctBase(T)

  # When the stream reports a length, reject a prefix that exceeds it before
  # the result buffer is resized.
  let available = input.len()
  if available.isSome() and size > available.get():
    raise (ref ProtobufValueError)(msg: "Missing bytes: " & $size)

  Base(result).setLen(size)

  # View of the freshly sized result as raw bytes, for both base types.
  template payload(): openArray[byte] =
    when Base is seq[byte]:
      Base(result).toOpenArray(0, size - 1)
    else:
      Base(result).toOpenArrayByte(0, size - 1)

  if not input.readInto(payload()):
    raise (ref ProtobufValueError)(msg: "Missing bytes: " & $size)

  when T is pstring:
    if validateUtf8(string(result)) != -1:
      raise (ref ProtobufValueError)(msg: "String not valid UTF-8")
255  
proc readHeader*(input: InputStream): FieldHeader {.raises: [SerializationError, IOError].} =
  ## Reads a field header varint from ``input`` and returns it once its wire
  ## type has been checked against ``SupportedWireKinds``.
  ## Raises ``ProtobufValueError`` for any other wire type.
  # NOTE(review): the header is read as `puint32`; what happens to a varint
  # wider than 32 bits is decided by the varint reader - confirm.
  let hdr = uint32(input.readValue(puint32))

  # The wire type occupies the three lowest bits of the header.
  let wire = uint8(hdr and 0x07)
  if wire in SupportedWireKinds:
    return FieldHeader(hdr)

  raise (ref ProtobufValueError)(msg: "Invalid wire type " & $wire)
264    FieldHeader(hdr)