/ src / examples / haskell-cxx / Main.hs
Main.hs
 1  -- | Test Haskell calling C++ via FFI.
 2  module Main where
 3  
import Control.Exception (ErrorCall (..), throwIO)
import Control.Monad (unless)
import Data.List (isInfixOf)

import FFI
 6  
 7  main :: IO ()
 8  main = do
 9      putStrLn "Haskell calling C++ via FFI:"
10  
11      -- Test arithmetic
12      testArithmetic
13  
14      -- Test vector operations
15      testVectors
16  
17      -- Test strings
18      testStrings
19  
20      -- Test Counter
21      testCounter
22  
23      putStrLn "all tests passed"
24  
-- | Checks the scalar arithmetic bindings 'add' and 'multiply' against
-- hand-computed results.
testArithmetic :: IO ()
testArithmetic = do
    check "add 2 3 == 5" (add 2 3 == 5)
    check "multiply 4 5 == 20" (multiply 4 5 == 20)
    putStrLn "  arithmetic: pass"
34  
-- | Checks the vector bindings 'dotProduct', 'norm' and 'scaleVector'
-- against hand-computed results. Floating-point values are compared with an
-- absolute tolerance of @1e-10@.
testVectors :: IO ()
testVectors = do
    let a = [1.0, 2.0, 3.0]
        b = [4.0, 5.0, 6.0]
        -- Approximate equality: exact (==) is unreliable for computed floats.
        close x y = abs (x - y) < 1e-10

    -- dot product: 1*4 + 2*5 + 3*6 = 32
    check "dotProduct [1,2,3] [4,5,6] == 32" (close (dotProduct a b) 32.0)

    -- norm: sqrt(1 + 4 + 9) = sqrt(14)
    check "norm [1,2,3] == sqrt(14)" (close (norm a) (sqrt 14))

    -- scale: zipWith stops at the shorter list, so the lengths must be
    -- compared explicitly; otherwise a short or empty result would pass.
    let scaled = scaleVector 2.0 a
        expected = [2.0, 4.0, 6.0]
    check
        "scaleVector 2 [1,2,3] == [2,4,6]"
        (length scaled == length expected && and (zipWith close scaled expected))

    putStrLn "  vectors: pass"
55  
-- | Checks that the greeting returned by 'greet' mentions both the name it
-- was given and "C++". Substring search uses 'Data.List.isInfixOf' instead
-- of a hand-rolled copy of it.
testStrings :: IO ()
testStrings = do
    greeting <- greet "Haskell"
    check "greet contains 'Haskell'" ("Haskell" `isInfixOf` greeting)
    check "greet contains 'C++'" ("C++" `isInfixOf` greeting)
    putStrLn "  strings: pass"
70  
-- | Exercises the Counter bindings. It reads the initial value, increments
-- it, then adds 5, checking the value reported after each step. The counter
-- handle is only used inside the 'withCounter' callback.
testCounter :: IO ()
testCounter = do
    withCounter 10 $ \c -> do
        -- 'getCounter' is a pure function (its result is not an IO action),
        -- so a plain 'let' is the idiomatic binding; 'x <- pure y' is noise.
        -- NOTE(review): a pure reader over mutable C++ state can be shared or
        -- floated by GHC; confirm the FFI binding should not be in IO.
        let v1 = getCounter c
        check "initial value == 10" (v1 == 10)

        v2 <- incrementCounter c
        check "after increment == 11" (v2 == 11)

        v3 <- addCounter c 5
        check "after add 5 == 16" (v3 == 16)

    putStrLn "  Counter: pass"
84  
-- | Abort the test run when a condition does not hold.
--
-- When @ok@ is 'False', this throws an 'ErrorCall' carrying
-- @"FAILED: " ++ msg@. It uses 'throwIO' rather than 'error', for two
-- reasons:
--
-- * The exception is raised exactly when the action runs.
-- * No call stack is attached. An attached stack would always point at this
--   helper instead of the failing test.
check :: String -> Bool -> IO ()
check msg ok = unless ok $ throwIO (ErrorCall ("FAILED: " ++ msg))