Parser.hs
-- | Parsing the graph binary format
--
-- File layout, as read by 'readGraphFile':
--
-- * the ASCII magic header ('magicHeader');
-- * the node section: a little-endian 64-bit node count followed by that
--   many varint-length-prefixed nodes (see 'getNodes');
-- * the metadata section (see 'getMetaData');
-- * a trailing little-endian 64-bit absolute offset of the metadata section.
--
-- Nodes and metadata use a protobuf-style wire encoding (single-byte tags,
-- varints, length-delimited payloads).

{-# LANGUAGE Strict, PackageImports, BangPatterns #-}
module Parser where

--------------------------------------------------------------------------------

import Data.Bits
import Data.Word
import Data.List
import Data.Ord

import Control.Monad
import Control.Applicative
import System.IO

import Data.ByteString.Lazy ( ByteString )
import qualified Data.ByteString.Lazy       as L
import qualified Data.ByteString.Lazy.Char8 as LC

import "binary" Data.Binary.Get
import "binary" Data.Binary.Get.Internal ( lookAhead )
import "binary" Data.Binary.Builder as Builder

import Graph

--------------------------------------------------------------------------------

{-
test :: IO ()
test = do
  Right graph <- parseGraphFile "../tmp/graph.bin"
  print graph
-}

{-
nodeEx1 = 0x06 : nodeEx1' :: [Word8]
nodeEx2 = 0x08 : nodeEx2' :: [Word8]

nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8]
nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8]

duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8]
duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8]
-}

--------------------------------------------------------------------------------

-- | Human-readable error message.
type Msg = String

-- | Open a graph file and parse it.
--
-- The handle is managed by 'withBinaryFile', so it is closed even when
-- 'readGraphFile' throws (e.g. on an I/O error).
parseGraphFile :: FilePath -> IO (Either Msg Graph)
parseGraphFile fname = withBinaryFile fname ReadMode readGraphFile

-- | Read @n@ bytes from the handle.
hGetBytes :: Handle -> Int -> IO ByteString
hGetBytes h n = L.hGet h (fromIntegral n)

-- | Seek to an absolute byte offset.
hSeekInt :: Handle -> Int -> IO ()
hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)

-- | Read and parse a graph from an open, seekable binary handle.
--
-- Returns 'Left' when the magic header does not match or when the trailing
-- metadata offset is out of range.  Malformed content /inside/ the node or
-- metadata sections is not reported through 'Left': the pure parsers
-- ('parseNodes', 'parseMeta') signal problems via 'error' / 'runGet'.
readGraphFile :: Handle -> IO (Either Msg Graph)
readGraphFile h = do
  flen  <- (fromIntegral :: Integer -> Int) <$> hFileSize h
  magic <- hGetBytes h headerLen
  if magic /= LC.pack magicHeader
    then return $ Left "magic header not found or invalid"
    else do
      -- the last 8 bytes hold the absolute offset of the metadata section
      hSeekInt h (flen - 8)
      offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
      -- putStrLn $ "metadata offset = " ++ show offset
      -- The metadata section spans [offset, flen - 8), so any offset beyond
      -- (flen - 8) would give a negative length below; reject it here instead
      -- of letting the read throw.
      -- NOTE(review): the lower bound 18 is kept from the original code; it is
      -- presumably a minimal plausible offset -- confirm against the format.
      if (offset > flen - 8) || (offset <= 18)
        then return $ Left "invalid final `graphMetaData` offset bytes"
        else do
          hSeekInt h headerLen
          part1 <- hGetBytes h (offset - headerLen)
          part2 <- hGetBytes h (flen - offset - 8)
          return (Right $ Graph (parseNodes part1) (parseMeta part2))
  where
    headerLen = length magicHeader

-- | The magic bytes every graph file must start with.
magicHeader :: String
magicHeader = "wtns.graph.001"

-- | Parse the node section (everything between the header and the metadata).
parseNodes :: ByteString -> [Node]
parseNodes = runGet getNodes

-- | Parse the metadata section.
parseMeta :: ByteString -> GraphMetaData
parseMeta = runGet getMetaData

--------------------------------------------------------------------------------

-- | Decode a little-endian base-128 varint into a 'Word64'.
--
-- A 64-bit value needs at most 10 bytes; a longer encoding makes the parser
-- fail instead of silently returning a truncated value and leaving the
-- remaining continuation bytes unread.
varInt' :: Get Word64
varInt' = go (0 :: Int) where
  go !cnt
    | cnt >= 10 = fail "varInt: encoding is longer than 10 bytes"
    | otherwise = do
        w <- getWord8
        if (w < 128)
          then return (fromIntegral w)
          else do
            -- low 7 bits are payload, the high bit means "more bytes follow"
            let x = fromIntegral (w .&. 127)
            y <- go (cnt+1)
            return (x + 128*y)

-- | A varint converted to 'Int' (via 'fromIntegral').
varInt :: Get Int
varInt = fromIntegral <$> varInt'

-- | A varint converted to 'Word32' (via 'fromIntegral', i.e. truncating).
varUInt :: Get Word32
varUInt = fromIntegral <$> varInt'

--------------------------------------------------------------------------------

-- | Abort with a message saying which field id was expected and which one
-- was actually found.  Arguments: actual id, name of the parser, expected id.
expectingError :: Int -> String -> Int -> Get a
expectingError actual what shouldbe = do
  error $ what ++ ": expecting field " ++ show shouldbe ++ "; got " ++ show actual ++ " instead"

-- | Read a varint length followed by that many bytes, and return the bytes.
getLenPrefixed :: Get ByteString
getLenPrefixed = do
  len <- varInt
  getLazyByteString (fromIntegral len)

-- | Read a length-delimited (wire type 'LEN') field and return its payload.
--
-- The field id found in the tag must equal the expected one; otherwise
-- 'expectingError' is raised, labelled with the given parser name.
getLenField :: String -> FieldId -> Get ByteString
getLenField what expectedId = do
  fld <- getFieldId LEN
  if fld /= expectedId
    then expectingError fld what expectedId
    else getLenPrefixed

-- | The node section: a little-endian 64-bit count, then that many nodes.
getNodes :: Get [Node]
getNodes = do
  n <- getWord64le
  replicateM (fromIntegral n) getNode

-- | with varint length prefix
getNode :: Get Node
getNode = do
  bs <- getLenPrefixed
  return (runGet getNode' bs)

-- | without varint length prefix
--
-- The field id of the leading 'LEN' tag selects the node constructor.
getNode' :: Get Node
getNode' = do
  nodetype <- getFieldId LEN
  case nodetype of
    1 -> AnInputNode   <$> getInputNode
    2 -> AConstantNode <$> getConstantNode
    3 -> AnUnoOpNode   <$> getUnoOpNode
    4 -> ADuoOpNode    <$> getDuoOpNode
    5 -> ATresOpNode   <$> getTresOpNode
    _ -> error "unexpected node type"

-- | An input node; only the first field of the record is used.
getInputNode :: Get InputNode
getInputNode = do
  SomeNode idx _ _ _ <- getSomeNode
  return (InputNode idx)

-- | A constant node: a length-prefixed message containing a 'BigUInt'.
getConstantNode :: Get ConstantNode
getConstantNode = do
  bs <- getLenPrefixed
  return $ ConstantNode (runGet getBigUInt bs)

-- | A 'BigUInt' stored as a 'LEN' field with the given field id; its payload
-- is in turn parsed with 'getByteList'.
getBigUInt' :: FieldId -> Get BigUInt
getBigUInt' expectedId = do
  bs <- getLenField "getBigUInt" expectedId
  return $ BigUInt (runGet getByteList bs)

-- | A 'BigUInt' stored under field id 1.
getBigUInt ::Get BigUInt
getBigUInt = getBigUInt' 1

-- | The raw bytes of a 'LEN' field with field id 1.
getByteList :: Get [Word8]
getByteList = do
  bs <- getLenField "getByteList" 1
  return (L.unpack bs)

-- | A string stored as a 'LEN' field with the given field id; the bytes are
-- unpacked one 'Char' per byte.
getString' :: FieldId -> Get String
getString' expectedId = do
  bs <- getLenField "getString" expectedId
  return (LC.unpack bs)

-- | A unary operation node: fields 1 and 2 are the opcode and the argument.
getUnoOpNode :: Get UnoOpNode
getUnoOpNode = do
  SomeNode op arg1 _ _ <- getSomeNode
  return (UnoOpNode (wordToEnum op) arg1)

-- | A binary operation node: fields 1..3 are the opcode and two arguments.
getDuoOpNode :: Get DuoOpNode
getDuoOpNode = do
  SomeNode op arg1 arg2 _ <- getSomeNode
  return (DuoOpNode (wordToEnum op) arg1 arg2)

-- | A ternary operation node: fields 1..4 are the opcode and three arguments.
getTresOpNode :: Get TresOpNode
getTresOpNode = do
  SomeNode op arg1 arg2 arg3 <- getSomeNode
  return (TresOpNode (wordToEnum op) arg1 arg2 arg3)

-- | Convert a 'Word32' to an enumeration via 'toEnum'.
wordToEnum :: Enum a => Word32 -> a
wordToEnum w = toEnum (fromIntegral w)

--------------------------------------------------------------------------------

-- | Generic record of up to four 'Word32' fields, shared by the input and
-- operation node parsers.
data SomeNode = SomeNode
  { field1 :: Word32
  , field2 :: Word32
  , field3 :: Word32
  , field4 :: Word32
  }
  deriving Show

-- | All fields zero; fields absent from the encoded record keep this value.
defaultSomeNode :: SomeNode
defaultSomeNode = SomeNode 0 0 0 0

-- | Store one @(field index, value)@ pair into the record.
--
-- Only indices 1..4 exist; any other index raises an error naming the
-- offending index (instead of an anonymous pattern-match failure).
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
insert1 (idx,val) old = case idx of
  1 -> old { field1 = val }
  2 -> old { field2 = val }
  3 -> old { field3 = val }
  4 -> old { field4 = val }
  _ -> error $ "insert1: unexpected field index " ++ show idx ++ " (expecting 1..4)"

-- | Store a list of pairs, left to right; later entries overwrite earlier
-- ones with the same index.
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
insertMany list old = foldl' (flip insert1) old list

-- | A length-prefixed record of varint fields, decoded into a 'SomeNode'.
getSomeNode :: Get SomeNode
getSomeNode = do
  bs <- getLenPrefixed
  let list = runGet getRecord bs
  return $ insertMany list defaultSomeNode

--------------------------------------------------------------------------------
-- TODO: refactor this mess

-- | The metadata section: a varint length prefix, then the witness mapping,
-- the circuit inputs and the prime, in this order.
getMetaData :: Get GraphMetaData
getMetaData = do
  -- the length prefix is consumed but not used to delimit the message
  _len    <- varInt
  mapping <- getWitnessMapping
  inputs  <- getCircuitInputs
  prime   <- getPrime
  return $ GraphMetaData mapping inputs prime

-- | The prime: its value (field 3) followed by its name (field 4).
getPrime :: Get Prime
getPrime = do
  number <- getBigUInt' 3
  name   <- getString'  4
  return $ Prime
    { primeNumber = number
    , primeName   = name
    }

-- | The witness mapping: a 'LEN' field with id 1 whose payload is a sequence
-- of varints, read until the payload is exhausted.
getWitnessMapping :: Get WitnessMapping
getWitnessMapping = do
  bs <- getLenField "getWitnessMapping" 1
  return $ WitnessMapping (runGet worker bs)
  where
    worker :: Get [Word32]
    worker = isEmpty >>= \b -> if b
      then return []
      else (:) <$> varUInt <*> worker

-- | The circuit inputs: zero or more consecutive 'LEN' fields with id 2,
-- each holding a name (field 1) and a signal description (field 2).
getCircuitInputs :: Get CircuitInputs
getCircuitInputs = worker where

{-
  worker :: Get [(String, SignalDescription)]
  worker = isEmpty >>= \b -> if b
    then return []
    else (:) <$> getSingleInput <*> worker
-}

  -- collect inputs until the next tag is not field 2
  worker :: Get [(String, SignalDescription)]
  worker = do
    mb <- getSingleInput
    case mb of
      Nothing   -> return []
      Just this -> (this:) <$> worker

  -- Peek at the next tag without consuming it; only when it is field 2 is
  -- the tag actually consumed and the entry parsed.
  getSingleInput :: Get (Maybe (String, SignalDescription))
  getSingleInput = do
    fld <- lookAhead (getFieldId LEN)
    if fld /= 2
      then return Nothing     -- expectingError fld "getSingleInput" 2
      else do
        _fld <- getFieldId LEN
        bs   <- getLenPrefixed
        return $ Just $ runGet inputHelper bs

  inputHelper :: Get (String, SignalDescription)
  inputHelper = do
    name   <- getName
    signal <- getSignal
    return (name,signal)

  getName :: Get String
  getName = do
    bs <- getLenField "getCircuitInputs/getName" 1
    return (LC.unpack bs)

  getSignal :: Get SignalDescription
  getSignal = do
    bs <- getLenField "getCircuitInputs/getSignal" 2
    return $ runGet signalHelper bs

  signalHelper :: Get SignalDescription
  signalHelper = do
    ofs <- getSignalOffset
    len <- getSignalLength
    return $ SignalDescription { signalOffset = ofs , signalLength = len }

  getSignalOffset :: Get Word32
  getSignalOffset = do
    fld <- getFieldId VARINT
    if fld /= 1
      then expectingError fld "getCircuitInputs/getSignalOffset" 1
      else varUInt

  getSignalLength :: Get Word32
  getSignalLength = do
    fld <- getFieldId VARINT
    if fld /= 2
      then expectingError fld "getCircuitInputs/getSignalLength" 2
      else varUInt

--------------------------------------------------------------------------------
-- * protobuf stuff

-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32
--
-- The constructor order matters: 'decodeTag' maps the numeric wire type
-- to a constructor through the derived 'Enum' instance.
data WireType
  = VARINT     -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum
  | I64        -- ^ used for: fixed64, sfixed64, double
  | LEN        -- ^ used for: string, bytes, embedded messages, packed repeated fields
  | SGROUP     -- ^ used for: group start (deprecated)
  | EGROUP     -- ^ used for: group end (deprecated)
  | I32        -- ^ fixed32, sfixed32, float
  deriving (Eq,Show,Enum,Bounded)

-- | Field number taken from a tag byte.
type FieldId = Int

-- | Read one tag byte, check that its wire type is the expected one and
-- return the field id.
--
-- Only a single tag byte is read; multi-byte (varint) tags are not decoded.
getFieldId :: WireType -> Get FieldId
getFieldId wty = do
  tag <- getWord8
  let (fld,wty') = decodeTag tag
  if wty == wty'
    then return fld
    else error "getFieldId: unexpected protobuf wire type"

-- | The field id part of a tag byte.
decodeTag_ :: Word8 -> FieldId
decodeTag_ = fst . decodeTag

-- | Split a tag byte into the field id (upper 5 bits) and the wire type
-- (lower 3 bits).
--
-- The wire type is converted with the derived 'toEnum', so the values 6 and
-- 7, which have no constructor, are a runtime error.
decodeTag :: Word8 -> (FieldId, WireType)
decodeTag w = (fld , wty) where
  fld = fromIntegral (shiftR w 3)
  wty = toEnum (fromIntegral (w .&. 7))

-- (index, value) pair
getEntry :: Get (Int,Word32)
getEntry = do
  idx <- getFieldId VARINT
  val <- varUInt
  return (idx,val)

-- list of (index, value) pairs, read until the input is exhausted and
-- returned ordered by index.
--
-- The sort is stable and compares the index only, so entries sharing an
-- index keep their wire order; combined with 'insertMany' this means the
-- last occurrence of a field wins (rather than the numerically largest one).
getRecord :: Get [(Int,Word32)]
getRecord = sortBy (comparing fst) <$> go where
  go = isEmpty >>= \b -> if b
    then return []
    else (:) <$> getEntry <*> go

--------------------------------------------------------------------------------