-- FFI.hs
{-# LANGUAGE ForeignFunctionInterface #-}

{- | FFI bindings to C++ code.

Demonstrates calling into C++ from Haskell through the foreign function
interface. The C++ side is compiled separately and linked in through Buck2's
@extra_libraries@.
-}
module FFI
    ( -- * Simple arithmetic
      add
    , multiply

      -- * Vector operations
    , dotProduct
    , norm
    , scaleVector

      -- * String operations
    , greet

      -- * Counter (opaque handle)
    , Counter
    , newCounter
    , freeCounter
    , withCounter
    , getCounter
    , incrementCounter
    , addCounter
    ) where

import Control.Exception (bracket)
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign.Ptr
import System.IO.Unsafe (unsafePerformIO)

-- =============================================================================
-- Simple arithmetic
-- =============================================================================

foreign import ccall unsafe "ffi_add"
    c_add :: CInt -> CInt -> CInt

foreign import ccall unsafe "ffi_multiply"
    c_multiply :: CInt -> CInt -> CInt

-- Lift a binary C integer function to Haskell 'Int's.
--
-- NOTE(review): each argument is narrowed from 'Int' to 'CInt' with
-- 'fromIntegral', so values outside the 'CInt' range wrap silently. How the
-- C++ side treats overflow is not visible from here.
viaCInt :: (CInt -> CInt -> CInt) -> Int -> Int -> Int
viaCInt op lhs rhs = fromIntegral (op (fromIntegral lhs) (fromIntegral rhs))

-- | Add two integers using the C++ @ffi_add@ function.
add :: Int -> Int -> Int
add = viaCInt c_add

-- | Multiply two integers using the C++ @ffi_multiply@ function.
multiply :: Int -> Int -> Int
multiply = viaCInt c_multiply

-- =============================================================================
-- Vector operations
-- =============================================================================

-- The two imports below are pure, yet they read through raw pointers. Any
-- caller must force their results while the pointed-to buffers are still
-- alive.
foreign import ccall unsafe "ffi_dot_product"
    c_dot_product :: Ptr CDouble -> Ptr CDouble -> CSize -> CDouble

foreign import ccall unsafe "ffi_norm"
    c_norm :: Ptr CDouble -> CSize -> CDouble

-- Scales the buffer in place (an 'IO' action, so it is ordered correctly).
foreign import ccall unsafe "ffi_scale"
    c_scale :: Ptr CDouble -> CSize -> CDouble -> IO ()

-- | Compute the dot product of two vectors.
--
-- Both lists must have the same length; otherwise this calls 'error'.
-- The elements are copied into temporary C arrays that exist only inside
-- the 'withArray' scopes, so the foreign call is forced before those
-- scopes end.
dotProduct :: [Double] -> [Double] -> Double
dotProduct xs ys
    | length xs /= length ys = error "dotProduct: vectors must have same length"
    | otherwise = realToFrac $ unsafePerformIOWithArrays xs ys $ \pxs pys len ->
        -- 'c_dot_product' is imported as a pure function. A plain
        -- @return $ c_dot_product ...@ would therefore yield an unevaluated
        -- thunk. That thunk reads @pxs@/@pys@ only after 'withArray' has
        -- released them (use-after-free). '$!' runs the call while the
        -- buffers are live.
        return $! c_dot_product pxs pys (fromIntegral len)
  where
    -- Marshal both lists to temporary C arrays and run @f@ on them, passing
    -- the length of the first list.
    unsafePerformIOWithArrays
        :: [Double] -> [Double] -> (Ptr CDouble -> Ptr CDouble -> Int -> IO a) -> a
    unsafePerformIOWithArrays us vs f = unsafePerformIO $
        withArray (map realToFrac us) $ \pus ->
            withArray (map realToFrac vs) $ \pvs ->
                f pus pvs (length us)

-- | Compute the L2 norm of a vector.
norm :: [Double] -> Double
norm xs = unsafePerformIO $
    withArray (map realToFrac xs) $ \pxs ->
        -- Force the pure foreign call before 'withArray' frees @pxs@.
        -- A lazy 'return' here would let the C call run on freed memory.
        return $! realToFrac (c_norm pxs (fromIntegral (length xs)))

-- | Scale a vector by a scalar, returning a new list.
--
-- The input is copied to a temporary C array and scaled in place there by
-- @ffi_scale@. The result is read back before the array is released.
scaleVector :: Double -> [Double] -> [Double]
scaleVector scalar xs = unsafePerformIO $
    withArray (map realToFrac xs) $ \pxs -> do
        let len = length xs
        c_scale pxs (fromIntegral len) (realToFrac scalar)
        map realToFrac <$> peekArray len pxs

-- =============================================================================
-- String operations
-- =============================================================================

foreign import ccall unsafe "ffi_greet"
    c_greet :: CString -> IO CString

foreign import ccall unsafe "ffi_free_string"
    c_free_string :: CString -> IO ()

-- | Generate a greeting (calls C++).
--
-- The string returned by @ffi_greet@ is released with @ffi_free_string@.
-- 'bracket' ensures that release happens even if copying it into a Haskell
-- 'String' throws.
greet :: String -> IO String
greet name = withCString name $ \cname ->
    bracket (c_greet cname) releaseResult peekResult
  where
    -- NOTE(review): whether ffi_greet can return NULL is not visible here.
    -- Peeking NULL would crash, so fail with an IOError instead.
    peekResult cresult
        | cresult == nullPtr = ioError (userError "greet: ffi_greet returned NULL")
        | otherwise = peekCString cresult
    -- Only hand non-NULL pointers back to the C++ deallocator.
    releaseResult cresult
        | cresult == nullPtr = return ()
        | otherwise = c_free_string cresult

-- =============================================================================
-- Counter (opaque handle pattern)
-- =============================================================================

-- | Opaque handle to a C++ Counter object.
newtype Counter = Counter (Ptr Counter)

foreign import ccall unsafe "ffi_counter_new"
    c_counter_new :: CInt -> IO (Ptr Counter)

foreign import ccall unsafe "ffi_counter_free"
    c_counter_free :: Ptr Counter -> IO ()

-- NOTE(review): imported as a pure function even though it reads mutable
-- C++ state; see the caveat on 'getCounter'.
foreign import ccall unsafe "ffi_counter_get"
    c_counter_get :: Ptr Counter -> CInt

foreign import ccall unsafe "ffi_counter_increment"
    c_counter_increment :: Ptr Counter -> IO CInt

foreign import ccall unsafe "ffi_counter_add"
    c_counter_add :: Ptr Counter -> CInt -> IO CInt

-- | Create a counter on the C++ side, starting at @initial@.
--
-- Release it with 'freeCounter', or use 'withCounter' to scope it.
newCounter :: Int -> IO Counter
newCounter initial = fmap Counter (c_counter_new (fromIntegral initial))

-- | Free a counter obtained from 'newCounter'.
--
-- The handle is a bare 'Ptr' with no guard. Nothing here prevents using or
-- freeing it again afterwards, which is presumably undefined behaviour on
-- the C++ side.
freeCounter :: Counter -> IO ()
freeCounter (Counter handle) = c_counter_free handle

-- | Run an action with a fresh counter that is freed afterwards, even if
-- the action throws.
withCounter :: Int -> (Counter -> IO a) -> IO a
withCounter initial action = bracket (newCounter initial) freeCounter action

-- | Get the counter's current value.
--
-- NOTE(review): this is a pure function over mutable C++ state. Evaluation
-- is lazy, and GHC may share or float pure calls. As a result, the value
-- observed is not tied to a point in an 'IO' sequence (for example,
-- relative to 'incrementCounter'). An 'IO'-returning variant would be the
-- proper fix, but it changes this exported type.
getCounter :: Counter -> Int
getCounter (Counter handle) = fromIntegral (c_counter_get handle)

-- | Increment the counter and return its new value.
incrementCounter :: Counter -> IO Int
incrementCounter (Counter handle) = fmap fromIntegral (c_counter_increment handle)

-- | Add @n@ to the counter and return its new value.
addCounter :: Counter -> Int -> IO Int
addCounter (Counter handle) n = fmap fromIntegral (c_counter_add handle (fromIntegral n))