/ src / examples / haskell-cxx / FFI.hs
FFI.hs
  1  {-# LANGUAGE ForeignFunctionInterface #-}
  2  
  3  {- | FFI bindings to C++ code.
  4  
  5  This module demonstrates calling C++ from Haskell via the FFI.
  6  The C++ code is compiled separately and linked via Buck2's extra_libraries.
  7  -}
  8  module FFI (
  9      -- * Simple arithmetic
 10      add,
 11      multiply,
 12  
 13      -- * Vector operations
 14      dotProduct,
 15      norm,
 16      scaleVector,
 17  
 18      -- * String operations
 19      greet,
 20  
 21      -- * Counter (opaque handle)
 22      Counter,
 23      newCounter,
 24      freeCounter,
 25      withCounter,
 26      getCounter,
 27      incrementCounter,
 28      addCounter,
 29  ) where
 30  
 31  import Control.Exception (bracket)
 32  import Foreign.C.String
 33  import Foreign.C.Types
 34  import Foreign.Marshal.Array
 35  import Foreign.Ptr
 36  import System.IO.Unsafe (unsafePerformIO)
 37  
 38  -- =============================================================================
 39  -- Simple arithmetic
 40  -- =============================================================================
 41  
 42  foreign import ccall unsafe "ffi_add"
 43      c_add :: CInt -> CInt -> CInt
 44  
 45  foreign import ccall unsafe "ffi_multiply"
 46      c_multiply :: CInt -> CInt -> CInt
 47  
 48  -- | Add two integers (calls C++).
 49  add :: Int -> Int -> Int
 50  add a b = fromIntegral $ c_add (fromIntegral a) (fromIntegral b)
 51  
 52  -- | Multiply two integers (calls C++).
 53  multiply :: Int -> Int -> Int
 54  multiply a b = fromIntegral $ c_multiply (fromIntegral a) (fromIntegral b)
 55  
 56  -- =============================================================================
 57  -- Vector operations
 58  -- =============================================================================
 59  
 60  foreign import ccall unsafe "ffi_dot_product"
 61      c_dot_product :: Ptr CDouble -> Ptr CDouble -> CSize -> CDouble
 62  
 63  foreign import ccall unsafe "ffi_norm"
 64      c_norm :: Ptr CDouble -> CSize -> CDouble
 65  
 66  foreign import ccall unsafe "ffi_scale"
 67      c_scale :: Ptr CDouble -> CSize -> CDouble -> IO ()
 68  
 69  -- | Compute dot product of two vectors.
 70  dotProduct :: [Double] -> [Double] -> Double
 71  dotProduct xs ys
 72      | length xs /= length ys = error "dotProduct: vectors must have same length"
 73      | otherwise = realToFrac $ unsafePerformIOWithArrays xs ys $ \pxs pys len ->
 74          return $ c_dot_product pxs pys (fromIntegral len)
 75    where
 76      unsafePerformIOWithArrays :: [Double] -> [Double] -> (Ptr CDouble -> Ptr CDouble -> Int -> IO a) -> a
 77      unsafePerformIOWithArrays as bs f = unsafePerformIO $
 78          withArray (map realToFrac as) $ \pas ->
 79              withArray (map realToFrac bs) $ \pbs ->
 80                  f pas pbs (length as)
 81  
 82  -- | Compute L2 norm of a vector.
 83  norm :: [Double] -> Double
 84  norm xs = unsafePerformIO $
 85      withArray (map realToFrac xs) $ \pxs ->
 86          return $ realToFrac $ c_norm pxs (fromIntegral $ length xs)
 87  
 88  -- | Scale a vector by a scalar (returns new vector).
 89  scaleVector :: Double -> [Double] -> [Double]
 90  scaleVector scalar xs = unsafePerformIO $
 91      withArray (map realToFrac xs) $ \pxs -> do
 92          c_scale pxs (fromIntegral $ length xs) (realToFrac scalar)
 93          map realToFrac <$> peekArray (length xs) pxs
 94  
 95  -- =============================================================================
 96  -- String operations
 97  -- =============================================================================
 98  
 99  foreign import ccall unsafe "ffi_greet"
100      c_greet :: CString -> IO CString
101  
102  foreign import ccall unsafe "ffi_free_string"
103      c_free_string :: CString -> IO ()
104  
105  -- | Generate a greeting (calls C++).
106  greet :: String -> IO String
107  greet name = withCString name $ \cname -> do
108      cresult <- c_greet cname
109      result <- peekCString cresult
110      c_free_string cresult
111      return result
112  
113  -- =============================================================================
114  -- Counter (opaque handle pattern)
115  -- =============================================================================
116  
-- | Opaque handle to a C++ Counter object.
--
-- The handle is a plain 'Ptr' with no finalizer, so nothing releases the
-- C++ object automatically: call 'freeCounter' when done, or scope the
-- handle with 'withCounter'. Do not use a handle after it has been freed.
newtype Counter = Counter (Ptr Counter)

foreign import ccall unsafe "ffi_counter_new"
    c_counter_new :: CInt -> IO (Ptr Counter)

foreign import ccall unsafe "ffi_counter_free"
    c_counter_free :: Ptr Counter -> IO ()

-- NOTE(review): imported as a pure function although it reads mutable C++
-- state (see the warning on 'getCounter'). It stays pure because switching
-- it to 'IO' would change the exported type of 'getCounter'.
foreign import ccall unsafe "ffi_counter_get"
    c_counter_get :: Ptr Counter -> CInt

foreign import ccall unsafe "ffi_counter_increment"
    c_counter_increment :: Ptr Counter -> IO CInt

foreign import ccall unsafe "ffi_counter_add"
    c_counter_add :: Ptr Counter -> CInt -> IO CInt

-- | Create a new counter with initial value.
--
-- Throws an 'IOError' if @ffi_counter_new@ yields a null pointer, so every
-- 'Counter' handed out wraps a non-null pointer. The initial value is
-- narrowed to 'CInt' with 'fromIntegral' (out-of-range values wrap).
newCounter :: Int -> IO Counter
newCounter initial = do
    ptr <- c_counter_new (fromIntegral initial)
    if ptr == nullPtr
        then ioError (userError "newCounter: ffi_counter_new returned NULL")
        else return (Counter ptr)

-- | Free a counter.
--
-- Passes the pointer to @ffi_counter_free@; presumably that deletes the C++
-- object, so call this at most once per handle and do not use the handle
-- afterwards.
freeCounter :: Counter -> IO ()
freeCounter (Counter ptr) = c_counter_free ptr

-- | Use a counter with automatic cleanup.
--
-- The counter is freed when the action returns or throws ('bracket'). Do not
-- let the handle, or an unforced 'getCounter' result, escape the callback.
withCounter :: Int -> (Counter -> IO a) -> IO a
withCounter initial = bracket (newCounter initial) freeCounter

-- | Get the current value.
--
-- WARNING: this is a pure function over mutable C++ state. The underlying
-- call only happens when the result is forced, so a value bound before
-- 'incrementCounter' / 'addCounter' may observe the later state, GHC may
-- share one result between repeated calls on the same handle, and forcing it
-- after 'freeCounter' reads through a released pointer. Force the result
-- while the counter is live and before mutating it, e.g.
-- @v <- Control.Exception.evaluate (getCounter c)@.
getCounter :: Counter -> Int
getCounter (Counter ptr) = fromIntegral $ c_counter_get ptr

-- | Increment and return new value.
incrementCounter :: Counter -> IO Int
incrementCounter (Counter ptr) = fromIntegral <$> c_counter_increment ptr

-- | Add n and return new value.
--
-- @n@ is narrowed to 'CInt' with 'fromIntegral' (out-of-range values wrap).
addCounter :: Counter -> Int -> IO Int
addCounter (Counter ptr) n = fromIntegral <$> c_counter_add ptr (fromIntegral n)