protobuf.nim
import libp2p/protobuf/minprotobuf
import std/[endians, options]
import ../src/[message, protobufutil, bloom, reliability_utils]

proc encode*(msg: SdsMessage): ProtoBuffer =
  ## Encodes `msg` into a protobuf buffer.
  ##
  ## Field layout (must stay in sync with `decode`):
  ## 1 = messageId, 2 = lamportTimestamp (written as uint64),
  ## 3 = causalHistory (one field entry per id), 4 = channelId,
  ## 5 = content, 6 = bloomFilter.
  var pb = initProtoBuffer()

  pb.write(1, msg.messageId)
  # The timestamp travels as uint64; `decode` reinterprets it as int64.
  pb.write(2, uint64(msg.lamportTimestamp))

  for hist in msg.causalHistory:
    pb.write(3, hist)

  pb.write(4, msg.channelId)
  pb.write(5, msg.content)
  pb.write(6, msg.bloomFilter)
  pb.finish()

  pb

proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
  ## Decodes an `SdsMessage` from `buffer` (inverse of `encode`).
  ##
  ## Fields 1 (messageId), 2 (lamportTimestamp), 4 (channelId) and
  ## 5 (content) are required; a missing one yields a
  ## `missingRequiredField` error naming it. Field 3 (causal history) is
  ## read best-effort. Field 6 (bloom filter) becomes an empty sequence when
  ## it is absent.
  let pb = initProtoBuffer(buffer)
  var msg = SdsMessage()

  if not ?pb.getField(1, msg.messageId):
    return err(ProtobufError.missingRequiredField("messageId"))

  var timestamp: uint64
  if not ?pb.getField(2, timestamp):
    return err(ProtobufError.missingRequiredField("lamportTimestamp"))
  # `encode` writes the timestamp with an unchecked `uint64(...)` conversion.
  # Reinterpreting the bits here mirrors that exactly: negative timestamps
  # round-trip, and wire values above int64.high cannot trigger a
  # range-check defect as the checked `int64(...)` conversion would.
  msg.lamportTimestamp = cast[int64](timestamp)

  # Best effort: if the repeated field cannot be read, the causal history is
  # left at its `SdsMessage()` default instead of failing the whole decode.
  var causalHistory: seq[SdsMessageID]
  let histResult = pb.getRepeatedField(3, causalHistory)
  if histResult.isOk:
    msg.causalHistory = causalHistory

  if not ?pb.getField(4, msg.channelId):
    return err(ProtobufError.missingRequiredField("channelId"))

  if not ?pb.getField(5, msg.content):
    return err(ProtobufError.missingRequiredField("content"))

  if not ?pb.getField(6, msg.bloomFilter):
    msg.bloomFilter = @[] # Empty if not present

  ok(msg)

proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] =
  ## For extraction of channel ID without full message deserialization.
  ##
  ## Reads only field 4 of an encoded `SdsMessage`. Returns
  ## `reDeserializationError` if that field is missing or cannot be read.
  try:
    let pb = initProtoBuffer(data)
    var channelId: SdsChannelID
    let found = pb.getField(4, channelId).valueOr:
      return err(ReliabilityError.reDeserializationError)
    if not found:
      return err(ReliabilityError.reDeserializationError)
    ok(channelId)
  except CatchableError:
    # Only recoverable errors are mapped; Defects (programming bugs) are not
    # silently swallowed.
    err(ReliabilityError.reDeserializationError)

proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] =
  ## Encodes `msg` with `encode` and returns the raw protobuf bytes.
  ## The current implementation always returns `ok`.
  let pb = encode(msg)
  ok(pb.buffer)
proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] =
  ## Decodes `data` into an `SdsMessage`. Any decoding failure is reported as
  ## `reDeserializationError`.
  let msg = SdsMessage.decode(data).valueOr:
    return err(ReliabilityError.reDeserializationError)
  ok(msg)

proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] =
  ## Encodes `filter` as a protobuf message:
  ## 1 = `intArray` as consecutive 8-byte little-endian words,
  ## 2 = capacity, 3 = errorRate scaled by 1_000_000 (rounded),
  ## 4 = kHashes, 5 = mBits.
  ##
  ## Words are always 8 bytes wide, whatever the platform's `int` size, so
  ## the encoding does not depend on the host word size.
  const wordSize = sizeof(int64)
  var pb = initProtoBuffer()

  # Convert intArray to bytes
  try:
    var bytes = newSeq[byte](filter.intArray.len * wordSize)
    for i, val in filter.intArray:
      # Widen to int64 first: littleEndian64 always copies 8 bytes, which
      # would overrun a 4-byte `int` on 32-bit targets.
      var word = int64(val)
      var leWord: int64
      littleEndian64(addr leWord, addr word)
      copyMem(addr bytes[i * wordSize], addr leWord, wordSize)

    pb.write(1, bytes)
    pb.write(2, uint64(filter.capacity))
    # Round to the nearest millionth rather than truncating, so float
    # representation error (e.g. a product of 999.9999...) cannot store the
    # value one unit low.
    pb.write(3, uint64(filter.errorRate * 1_000_000 + 0.5))
    pb.write(4, uint64(filter.kHashes))
    pb.write(5, uint64(filter.mBits))
  except CatchableError:
    return err(ReliabilityError.reSerializationError)

  pb.finish()
  ok(pb.buffer)

proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] =
  ## Decodes a `BloomFilter` produced by `serializeBloomFilter`.
  ##
  ## All five fields are required. Returns `reDeserializationError` in any of
  ## these cases:
  ## - `data` is empty;
  ## - a field is missing or unreadable;
  ## - field 1 is not a whole number of 8-byte words;
  ## - a numeric value does not fit in this platform's `int`.
  const wordSize = sizeof(int64)
  if data.len == 0:
    return err(ReliabilityError.reDeserializationError)

  let pb = initProtoBuffer(data)
  var bytes: seq[byte]
  var cap, errRate, kHashes, mBits: uint64

  try:
    let
      field1Ok = pb.getField(1, bytes).valueOr:
        return err(ReliabilityError.reDeserializationError)
      field2Ok = pb.getField(2, cap).valueOr:
        return err(ReliabilityError.reDeserializationError)
      field3Ok = pb.getField(3, errRate).valueOr:
        return err(ReliabilityError.reDeserializationError)
      field4Ok = pb.getField(4, kHashes).valueOr:
        return err(ReliabilityError.reDeserializationError)
      field5Ok = pb.getField(5, mBits).valueOr:
        return err(ReliabilityError.reDeserializationError)

    if not (field1Ok and field2Ok and field3Ok and field4Ok and field5Ok):
      return err(ReliabilityError.reDeserializationError)

    # Validate before the `int(...)` conversions below. An out-of-range
    # conversion would raise a RangeDefect (with range checks enabled)
    # instead of producing an error result.
    const maxInt = uint64(high(int))
    if cap > maxInt or kHashes > maxInt or mBits > maxInt:
      return err(ReliabilityError.reDeserializationError)

    # A trailing partial word means the payload is malformed; do not
    # silently drop it.
    if bytes.len mod wordSize != 0:
      return err(ReliabilityError.reDeserializationError)

    # Convert bytes back to intArray
    var intArray = newSeq[int](bytes.len div wordSize)
    for i in 0 ..< intArray.len:
      var leWord, word: int64
      copyMem(addr leWord, addr bytes[i * wordSize], wordSize)
      littleEndian64(addr word, addr leWord)
      when sizeof(int) < sizeof(int64):
        # On 32-bit hosts a stored word may not fit in `int`.
        if word < int64(low(int)) or word > int64(high(int)):
          return err(ReliabilityError.reDeserializationError)
      intArray[i] = int(word)

    ok(
      BloomFilter(
        intArray: intArray,
        capacity: int(cap),
        errorRate: float(errRate) / 1_000_000,
        kHashes: int(kHashes),
        mBits: int(mBits),
      )
    )
  except CatchableError:
    return err(ReliabilityError.reDeserializationError)