/ haskell / src / Parser.hs
Parser.hs
  1  
  2  -- | Parsing the graph binary format
  3  
  4  {-# LANGUAGE Strict, PackageImports, BangPatterns #-}
  5  module Parser where
  6  
  7  --------------------------------------------------------------------------------
  8  
  9  import Data.Bits
 10  import Data.Word
 11  import Data.List
 12  import Data.Ord
 13  
 14  import Control.Monad
 15  import Control.Applicative
 16  import System.IO
 17  
 18  import Data.ByteString.Lazy ( ByteString ) 
 19  import qualified Data.ByteString.Lazy       as L
 20  import qualified Data.ByteString.Lazy.Char8 as LC
 21  
 22  import "binary" Data.Binary.Get
 23  import "binary" Data.Binary.Get.Internal ( lookAhead )
 24  import "binary" Data.Binary.Builder as Builder
 25  
 26  import Graph
 27  
 28  --------------------------------------------------------------------------------
 29  
 30  {-
 31  test :: IO ()
 32  test = do
 33    Right graph <- parseGraphFile "../tmp/graph.bin"
 34    print graph
 35  -}
 36  
 37  {-
 38  nodeEx1  = 0x06 : nodeEx1'            :: [Word8]
 39  nodeEx2  = 0x08 : nodeEx2'            :: [Word8]
 40  
 41  nodeEx1' = 0x22 : 0x04 : duoNodeEx1   :: [Word8]
 42  nodeEx2' = 0x22 : 0x06 : duoNodeEx2   :: [Word8]
 43  
 44  duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ]                 :: [Word8]
 45  duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ]   :: [Word8]
 46  -}
 47  
 48  --------------------------------------------------------------------------------
 49  
 50  type Msg = String
 51  
 52  parseGraphFile :: FilePath -> IO (Either Msg Graph)
 53  parseGraphFile fname = do
 54    h    <- openBinaryFile fname ReadMode 
 55    ei   <- readGraphFile h
 56    hClose h
 57    return ei
 58  
 59  hGetBytes :: Handle -> Int -> IO ByteString
 60  hGetBytes h n = L.hGet h (fromIntegral n)
 61  
 62  hSeekInt :: Handle -> Int -> IO ()
 63  hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)
 64  
 65  readGraphFile :: Handle -> IO (Either Msg Graph)
 66  readGraphFile h = do
 67    flen  <- (fromIntegral :: Integer -> Int) <$> hFileSize h 
 68    magic <- hGetBytes h (length magicHeader)
 69    if magic /= LC.pack magicHeader
 70      then return $ Left "magic header not found or invalid"
 71      else do
 72        hSeekInt h (flen - 8)
 73        offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
 74        -- putStrLn $ "metadata offset = " ++ show offset
 75        if (offset >= flen) || (offset <= 18)
 76          then return $ Left "invalid final `graphMetaData` offset bytes"
 77          else do
 78            hSeekInt h (length magicHeader)
 79            part1 <- hGetBytes h (offset - length magicHeader)
 80            part2 <- hGetBytes h (flen - offset - 8)
 81            return (Right $ Graph (parseNodes part1) (parseMeta part2))
 82  
 83  magicHeader :: String
 84  magicHeader = "wtns.graph.001"
 85  
 86  parseNodes :: ByteString -> [Node]
 87  parseNodes = runGet getNodes
 88  
 89  parseMeta :: ByteString -> GraphMetaData
 90  parseMeta = runGet getMetaData
 91  
 92  --------------------------------------------------------------------------------
 93  
 94  varInt' :: Get Word64
 95  varInt' = go 0 where
 96    go !cnt = if cnt >= 8 then return 0 else do
 97      w <- getWord8
 98      if (w < 128) 
 99        then return (fromIntegral w)
100        else do
101          let x = fromIntegral (w .&. 127) 
102          y <- go (cnt+1)
103          return (x + 128*y)
104  
-- | A varint converted to 'Int' (used for byte lengths in this module).
varInt :: Get Int
varInt = fmap fromIntegral varInt'

-- | A varint narrowed to 'Word32'; bits above 32 are silently dropped.
varUInt :: Get Word32
varUInt = fmap fromIntegral varInt'
110  
111  --------------------------------------------------------------------------------
112  
-- | Abort the current 'Get' with a message naming the parser (@what@), the
-- field number it expected (@shouldbe@) and the one it found (@actual@).
--
-- Uses 'fail', so the error travels through the 'Get' monad: 'runGet' still
-- throws (with the byte position attached), while 'runGetOrFail' callers
-- receive it as a 'Left' instead of an uncatchable-in-pure-code 'error'.
expectingError :: Int -> String -> Int -> Get a
expectingError actual what shouldbe =
  fail $ what ++ ": expecting field " ++ show shouldbe ++ "; got " ++ show actual ++ " instead"
116  
-- | Node section: a little-endian 64-bit node count followed by that many
-- length-prefixed node messages.
getNodes :: Get [Node]
getNodes = getWord64le >>= \count -> replicateM (fromIntegral count) getNode
121  
-- | A single node framed by a varint byte-length prefix; the framed bytes
-- are decoded separately with 'getNode''.
getNode :: Get Node
getNode = runGet getNode' <$> (varInt >>= getLazyByteString . fromIntegral)
128  
-- | A node message without its length prefix: a LEN-typed tag whose field
-- number selects the node kind (1 input, 2 constant, 3/4/5 operator nodes
-- with one/two/three arguments), followed by that kind's payload.
getNode' :: Get Node
getNode' = getFieldId LEN >>= dispatch
  where
    dispatch :: FieldId -> Get Node
    dispatch 1 = AnInputNode   <$> getInputNode
    dispatch 2 = AConstantNode <$> getConstantNode
    dispatch 3 = AnUnoOpNode   <$> getUnoOpNode
    dispatch 4 = ADuoOpNode    <$> getDuoOpNode
    dispatch 5 = ATresOpNode   <$> getTresOpNode
    dispatch _ = error "unexpected node type"
140  
-- | An input node keeps only field 1 of its generic record.
getInputNode :: Get InputNode
getInputNode =
  getSomeNode >>= \(SomeNode idx _ _ _) -> return (InputNode idx)
145  
-- | A constant node: a varint-length-prefixed sub-message whose field 1
-- holds the constant as a 'BigUInt'.
getConstantNode :: Get ConstantNode
getConstantNode = do
  payloadLen <- varInt
  ConstantNode . runGet getBigUInt <$> getLazyByteString (fromIntegral payloadLen)
151  
-- | A 'BigUInt' stored in the LEN field @expectedId@. The field's payload is
-- itself a message whose field 1 carries the raw bytes (see 'getByteList').
getBigUInt' :: FieldId -> Get BigUInt
getBigUInt' expectedId = do
  fld <- getFieldId LEN
  if fld == expectedId
    then do
      payloadLen <- varInt
      payload    <- getLazyByteString (fromIntegral payloadLen)
      return (BigUInt (runGet getByteList payload))
    else expectingError fld "getBigUInt" expectedId

-- | A 'BigUInt' stored in field 1.
getBigUInt :: Get BigUInt
getBigUInt = getBigUInt' 1
164  
-- | Field 1 (LEN): a varint length followed by that many raw bytes.
getByteList :: Get [Word8]
getByteList = do
  fld <- getFieldId LEN
  unless (fld == 1) $ expectingError fld "getByteList" 1
  n <- varInt
  L.unpack <$> getLazyByteString (fromIntegral n)
174  
-- | A string stored in the LEN field @expectedId@.
--
-- NOTE(review): 'LC.unpack' turns each byte into one 'Char', so any
-- non-ASCII UTF-8 text would come out mangled — confirm names are ASCII.
getString' :: FieldId -> Get String
getString' expectedId = getFieldId LEN >>= \fld ->
  if fld /= expectedId
    then expectingError fld "getString" expectedId
    else LC.unpack <$> (varInt >>= getLazyByteString . fromIntegral)
184  
-- | One-argument operator node: field 1 is the operator code, field 2 the
-- argument.
getUnoOpNode :: Get UnoOpNode
getUnoOpNode =
  getSomeNode >>= \(SomeNode opCode arg _ _) ->
    return (UnoOpNode (wordToEnum opCode) arg)
189  
-- | Two-argument operator node: field 1 is the operator code, fields 2 and 3
-- the arguments.
getDuoOpNode :: Get DuoOpNode
getDuoOpNode = do
  node <- getSomeNode
  case node of
    SomeNode opCode x y _ -> return (DuoOpNode (wordToEnum opCode) x y)
194  
-- | Three-argument operator node: field 1 is the operator code, fields 2–4
-- the arguments.
getTresOpNode :: Get TresOpNode
getTresOpNode = do
  node <- getSomeNode
  case node of
    SomeNode opCode x y z -> return (TresOpNode (wordToEnum opCode) x y z)
199  
-- | Convert a wire-level 'Word32' code to an enumeration via 'toEnum'.
-- Partial: 'toEnum' may throw for codes outside the target type's range.
wordToEnum :: Enum a => Word32 -> a
wordToEnum = toEnum . fromIntegral
202  
203  --------------------------------------------------------------------------------
204  
-- | Generic decoding of a node record: the values of fields 1–4.
-- 'getSomeNode' starts from 'defaultSomeNode', so absent fields read as 0.
data SomeNode = SomeNode
  { field1 :: Word32
  , field2 :: Word32
  , field3 :: Word32
  , field4 :: Word32
  }
  deriving Show

-- | All four fields zero.
defaultSomeNode :: SomeNode
defaultSomeNode = SomeNode 0 0 0 0

-- | Store one decoded @(field number, value)@ pair.
--
-- Field numbers other than 1–4 are ignored: protobuf readers are expected
-- to skip unknown fields, and an unmatched number must not crash the parse.
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
insert1 (idx,val) old = case idx of
  1 -> old { field1 = val }
  2 -> old { field2 = val }
  3 -> old { field3 = val }
  4 -> old { field4 = val }
  _ -> old

-- | Apply the pairs in list order; when a field number repeats, the last
-- pair wins.
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
insertMany list old = foldl' (flip insert1) old list
225  
-- | A varint-length-prefixed record of VARINT fields, folded into a
-- 'SomeNode' on top of 'defaultSomeNode'.
getSomeNode :: Get SomeNode
getSomeNode = do
  size    <- varInt
  payload <- getLazyByteString (fromIntegral size)
  let entries = runGet getRecord payload
  -- decode the record eagerly, exactly as the module's Strict pragma would
  entries `seq` pure (insertMany entries defaultSomeNode)
232  
233  --------------------------------------------------------------------------------
234  -- TODO: refactor this mess
235  
-- | Metadata section: a varint length prefix (read and skipped; the value is
-- not needed), then the witness mapping (field 1), the circuit inputs
-- (consecutive field-2 entries) and the prime (fields 3 and 4), in order.
getMetaData :: Get GraphMetaData
getMetaData = do
  _ <- varInt  -- length prefix: consumed but otherwise unused
  GraphMetaData <$> getWitnessMapping <*> getCircuitInputs <*> getPrime
243  
-- | The prime: its value as a 'BigUInt' in field 3, followed by its name as
-- a string in field 4.
getPrime :: Get Prime
getPrime = do
  value <- getBigUInt' 3
  label <- getString' 4
  pure (Prime { primeName = label, primeNumber = value })
252  
-- | Field 1 (LEN) of the metadata: a packed sequence of varints, read until
-- the field's bytes are exhausted.
getWitnessMapping :: Get WitnessMapping
getWitnessMapping = do
  fld <- getFieldId LEN
  if fld == 1
    then WitnessMapping . runGet packedVarints <$> (varInt >>= getLazyByteString . fromIntegral)
    else expectingError fld "getWitnessMapping" 1
  where
    packedVarints :: Get [Word32]
    packedVarints = do
      done <- isEmpty
      if done then return [] else liftA2 (:) varUInt packedVarints
267  
-- | The circuit inputs: a run of consecutive field-2 (LEN) entries, each a
-- sub-message holding a name (field 1) and a 'SignalDescription' (field 2).
-- The run ends, without consuming anything, at the first tag whose field
-- number is not 2 (that tag must still be LEN-typed, or 'getFieldId' errors).
getCircuitInputs :: Get CircuitInputs
getCircuitInputs = inputs
  where
    -- payload following a LEN tag: varint length, then that many bytes
    framed :: Get ByteString
    framed = do
      n <- varInt
      getLazyByteString (fromIntegral n)

    inputs :: Get [(String, SignalDescription)]
    inputs = do
      next <- getSingleInput
      case next of
        Just entry -> (entry :) <$> inputs
        Nothing    -> return []

    -- peek at the next tag and only consume it when it is field 2
    getSingleInput :: Get (Maybe (String, SignalDescription))
    getSingleInput = do
      fld <- lookAhead (getFieldId LEN)
      if fld == 2
        then do
          _ <- getFieldId LEN
          Just . runGet inputHelper <$> framed
        else return Nothing

    inputHelper :: Get (String, SignalDescription)
    inputHelper = do
      inputName <- getName
      signal    <- getSignal
      return (inputName, signal)

    getName :: Get String
    getName = do
      fld <- getFieldId LEN
      if fld == 1
        then LC.unpack <$> framed
        else expectingError fld "getCircuitInputs/getName" 1

    getSignal :: Get SignalDescription
    getSignal = do
      fld <- getFieldId LEN
      if fld == 2
        then runGet signalHelper <$> framed
        else expectingError fld "getCircuitInputs/getSignal" 2

    signalHelper :: Get SignalDescription
    signalHelper = do
      ofs <- getSignalOffset
      len <- getSignalLength
      return SignalDescription { signalOffset = ofs, signalLength = len }

    getSignalOffset :: Get Word32
    getSignalOffset = varintField 1 "getCircuitInputs/getSignalOffset"

    getSignalLength :: Get Word32
    getSignalLength = varintField 2 "getCircuitInputs/getSignalLength"

    -- a VARINT field that must carry the given field number
    varintField :: FieldId -> String -> Get Word32
    varintField expected what = do
      fld <- getFieldId VARINT
      if fld == expected then varUInt else expectingError fld what expected
337  
338  --------------------------------------------------------------------------------
339  -- * protobuf stuff
340  
-- | The six protobuf wire types, in encoding order: 'fromEnum' of a
-- constructor is the 3-bit wire-type value stored in the low bits of a tag.
data WireType
  = VARINT  -- ^ 0: int32, int64, uint32, uint64, sint32, sint64, bool, enum
  | I64     -- ^ 1: fixed64, sfixed64, double
  | LEN     -- ^ 2: string, bytes, embedded messages, packed repeated fields
  | SGROUP  -- ^ 3: group start (deprecated)
  | EGROUP  -- ^ 4: group end (deprecated)
  | I32     -- ^ 5: fixed32, sfixed32, float
  deriving (Bounded, Enum, Eq, Show)

-- | A protobuf field number.
type FieldId = Int
352  
-- | Read a field tag and return its field number, calling 'error' if the
-- tag's wire type differs from the expected one.
--
-- Protobuf tags are varints (@field number << 3 | wire type@), so field
-- numbers of 16 and above take more than one byte; reading the tag with
-- 'varInt'' handles them, and single-byte tags decode as a plain byte would.
getFieldId :: WireType -> Get FieldId
getFieldId wty = do
  tag <- varInt'
  let fld  = fromIntegral (shiftR tag 3)
      wty' = snd (decodeTag (fromIntegral (tag .&. 7)))
  if wty == wty'
    then return fld
    else error "getFieldId: unexpected protobuf wire type"
360  
-- | Field number of a single-byte tag.
decodeTag_ :: Word8 -> FieldId
decodeTag_ w = fst (decodeTag w)

-- | Split a one-byte tag into its field number (upper 5 bits) and wire type
-- (lower 3 bits). Values 6 and 7 have no 'WireType' constructor, so 'toEnum'
-- throws for them.
decodeTag :: Word8 -> (FieldId, WireType)
decodeTag w =
  let fieldNo  = fromIntegral (w `shiftR` 3)
      wireType = toEnum (fromIntegral (w .&. 0x07))
  in (fieldNo, wireType)
368  
-- | One VARINT field as a @(field number, value)@ pair.
getEntry :: Get (Int,Word32)
getEntry = liftA2 (,) (getFieldId VARINT) varUInt
375  
-- | Read 'getEntry' pairs until the input is exhausted, then sort them by
-- field number (ties by value).
--
-- NOTE(review): since 'insertMany' keeps the last pair for a repeated field,
-- this sort makes the largest value win — confirm that is intended.
getRecord :: Get [(Int,Word32)]
getRecord = fmap sort entries
  where
    entries = do
      atEnd <- isEmpty
      if atEnd
        then return []
        else do
          e  <- getEntry
          es <- entries
          return (e : es)
382  
383  --------------------------------------------------------------------------------