/ src / protobuf.nim
protobuf.nim
  1  import libp2p/protobuf/minprotobuf
  2  import std/options
  3  import endians
  4  import ../src/[message, protobufutil, bloom, reliability_utils]
  5  
  6  proc encode*(msg: SdsMessage): ProtoBuffer =
  7    var pb = initProtoBuffer()
  8  
  9    pb.write(1, msg.messageId)
 10    pb.write(2, uint64(msg.lamportTimestamp))
 11  
 12    for hist in msg.causalHistory:
 13      pb.write(3, hist)
 14  
 15    pb.write(4, msg.channelId)
 16    pb.write(5, msg.content)
 17    pb.write(6, msg.bloomFilter)
 18    pb.finish()
 19  
 20    pb
 21  
 22  proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
 23    let pb = initProtoBuffer(buffer)
 24    var msg = SdsMessage()
 25  
 26    if not ?pb.getField(1, msg.messageId):
 27      return err(ProtobufError.missingRequiredField("messageId"))
 28  
 29    var timestamp: uint64
 30    if not ?pb.getField(2, timestamp):
 31      return err(ProtobufError.missingRequiredField("lamportTimestamp"))
 32    msg.lamportTimestamp = int64(timestamp)
 33  
 34    var causalHistory: seq[SdsMessageID]
 35    let histResult = pb.getRepeatedField(3, causalHistory)
 36    if histResult.isOk:
 37      msg.causalHistory = causalHistory
 38  
 39    if not ?pb.getField(4, msg.channelId):
 40      return err(ProtobufError.missingRequiredField("channelId"))
 41  
 42    if not ?pb.getField(5, msg.content):
 43      return err(ProtobufError.missingRequiredField("content"))
 44  
 45    if not ?pb.getField(6, msg.bloomFilter):
 46      msg.bloomFilter = @[] # Empty if not present
 47  
 48    ok(msg)
 49  
 50  proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] =
 51    ## For extraction of channel ID without full message deserialization
 52    try:
 53      let pb = initProtoBuffer(data)
 54      var channelId: SdsChannelID
 55      let fieldOk = pb.getField(4, channelId).valueOr:
 56        return err(ReliabilityError.reDeserializationError)
 57      if not fieldOk:
 58        return err(ReliabilityError.reDeserializationError)
 59      ok(channelId)
 60    except:
 61      err(ReliabilityError.reDeserializationError)
 62  
 63  proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
 64    let pb = encode(msg)
 65    ok(pb.buffer)
 66  
 67  proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] =
 68    let msg = SdsMessage.decode(data).valueOr:
 69      return err(ReliabilityError.reDeserializationError)
 70    ok(msg)
 71  
 72  proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
 73    var pb = initProtoBuffer()
 74  
 75    # Convert intArray to bytes
 76    try:
 77      var bytes = newSeq[byte](filter.intArray.len * sizeof(int))
 78      for i, val in filter.intArray:
 79        var leVal: int
 80        littleEndian64(addr leVal, unsafeAddr val)
 81        let start = i * sizeof(int)
 82        copyMem(addr bytes[start], addr leVal, sizeof(int))
 83  
 84      pb.write(1, bytes)
 85      pb.write(2, uint64(filter.capacity))
 86      pb.write(3, uint64(filter.errorRate * 1_000_000))
 87      pb.write(4, uint64(filter.kHashes))
 88      pb.write(5, uint64(filter.mBits))
 89    except:
 90      return err(ReliabilityError.reSerializationError)
 91  
 92    pb.finish()
 93    ok(pb.buffer)
 94  
 95  proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
 96    if data.len == 0:
 97      return err(ReliabilityError.reDeserializationError)
 98  
 99    let pb = initProtoBuffer(data)
100    var bytes: seq[byte]
101    var cap, errRate, kHashes, mBits: uint64
102  
103    try:
104      let
105        field1_Ok = pb.getField(1, bytes).valueOr:
106          return err(ReliabilityError.reDeserializationError)
107        field2_Ok = pb.getField(2, cap).valueOr:
108          return err(ReliabilityError.reDeserializationError)
109        field3_Ok = pb.getField(3, errRate).valueOr:
110          return err(ReliabilityError.reDeserializationError)
111        field4_Ok = pb.getField(4, kHashes).valueOr:
112          return err(ReliabilityError.reDeserializationError)
113        field5_Ok = pb.getField(5, mBits).valueOr:
114          return err(ReliabilityError.reDeserializationError)
115  
116      if not field1_Ok or not field2_Ok or not field3_Ok or not field4_Ok or not field5_Ok:
117        return err(ReliabilityError.reDeserializationError)
118  
119      # Convert bytes back to intArray
120      var intArray = newSeq[int](bytes.len div sizeof(int))
121      for i in 0 ..< intArray.len:
122        var leVal: int
123        let start = i * sizeof(int)
124        copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int))
125        littleEndian64(addr intArray[i], addr leVal)
126  
127      ok(
128        BloomFilter(
129          intArray: intArray,
130          capacity: int(cap),
131          errorRate: float(errRate) / 1_000_000,
132          kHashes: int(kHashes),
133          mBits: int(mBits),
134        )
135      )
136    except:
137      return err(ReliabilityError.reDeserializationError)