/ nix / script / exe / oci-gpu-shelly.hs
oci-gpu-shelly.hs
  1  {-# LANGUAGE ExtendedDefaultRules #-}
  2  {-# LANGUAGE OverloadedStrings #-}
  3  {-# OPTIONS_GHC -fno-warn-type-defaults #-}
  4  
{- |
Module      : Main
Description : Run OCI container images with NVIDIA GPU access

A typed shell script that pulls OCI images, caches their extracted root
filesystems, and runs them in a bubblewrap namespace with GPU device
passthrough.

This is the Shelly version. Compared to Turtle:
  - Thread-safe (maintains its own environment state)
  - Command tracing built in (great for debugging)
  - More verbose API, but more control
  - Runs in Shelly's own 'Sh' monad
-}
 18  module Main where
 19  
 20  import Data.Aeson (Value (..), decode, (.:?))
 21  import Data.Aeson.Types (parseMaybe)
 22  import qualified Data.Aeson.Types as Aeson
 23  import qualified Data.ByteString.Lazy as BL
 24  import Data.Maybe (fromMaybe)
 25  import Data.Text (Text)
 26  import qualified Data.Text as T
 27  import qualified Data.Text.Encoding as TE
 28  import qualified Data.Vector as V
 29  import Shelly
 30  import qualified System.Environment
 31  import System.FilePath (takeFileName)
 32  import System.Posix.Process (executeFile)
 33  
 34  default (Text)
 35  
 36  -- ============================================================================
 37  -- Types
 38  -- ============================================================================
 39  
 40  data Config = Config
 41      { cfgImage :: Text
 42      , cfgCommand :: [Text]
 43      , cfgPlatform :: Text
 44      , cfgCacheDir :: FilePath
 45      , cfgCertFile :: FilePath
 46      }
 47      deriving (Show)
 48  
 49  data ContainerEnv = ContainerEnv
 50      { envPath :: Maybe Text
 51      , envLdLibPath :: Maybe Text
 52      }
 53      deriving (Show)
 54  
 55  emptyEnv :: ContainerEnv
 56  emptyEnv = ContainerEnv Nothing Nothing
 57  
 58  -- ============================================================================
 59  -- Main
 60  -- ============================================================================
 61  
 62  main :: IO ()
 63  main = shelly $ verbosely $ do
 64      args <- liftIO System.Environment.getArgs
 65      case args of
 66          [] -> errorExit "Usage: oci-gpu IMAGE [COMMAND...]"
 67          (image : cmdArgs) -> do
 68              let cmd' = if null cmdArgs then ["nvidia-smi"] else map T.pack cmdArgs
 69              cfg <- buildConfig (T.pack image) cmd'
 70              runWithGpu cfg
 71  
 72  buildConfig :: Text -> [Text] -> Sh Config
 73  buildConfig image cmd' = do
 74      homeDir <- fromMaybe (error "HOME not set") <$> get_env "HOME"
 75      xdgCache <- get_env "XDG_CACHE_HOME"
 76      let cacheBase = case xdgCache of
 77              Just c -> fromText c
 78              Nothing -> fromText homeDir </> ".cache"
 79  
 80      -- These would be injected by Nix in the real build
 81      let certFile = "/etc/ssl/certs/ca-bundle.crt"
 82          platform = "linux/amd64"
 83  
 84      pure
 85          Config
 86              { cfgImage = image
 87              , cfgCommand = cmd'
 88              , cfgPlatform = platform
 89              , cfgCacheDir = cacheBase </> "straylight-oci"
 90              , cfgCertFile = fromText certFile
 91              }
 92  
 93  -- ============================================================================
 94  -- Core Logic
 95  -- ============================================================================
 96  
 97  runWithGpu :: Config -> Sh ()
 98  runWithGpu cfg = do
 99      -- Ensure cache directory exists
100      mkdir_p (cfgCacheDir cfg)
101  
102      -- Compute cache key from image name
103      cacheKey <- getCacheKey (cfgImage cfg)
104      let cachedRootfs = cfgCacheDir cfg </> fromText cacheKey
105  
106      -- Create temp working directory
107      -- Shelly has withTmpDir for automatic cleanup
108      withTmpDir $ \workDir -> do
109          let rootfsLink = workDir </> "rootfs"
110  
111          -- Pull or use cached image
112          cached <- test_d cachedRootfs
113          if cached
114              then do
115                  echo $ ":: Using cached " <> cfgImage cfg
116                  -- Shelly doesn't have symlink, use command
117                  run_ "ln" ["-s", toTextIgnore cachedRootfs, toTextIgnore rootfsLink]
118              else do
119                  echo $ ":: Pulling " <> cfgImage cfg
120                  pullImage cfg workDir
121                  mv (workDir </> "rootfs") cachedRootfs
122                  run_ "ln" ["-s", toTextIgnore cachedRootfs, toTextIgnore rootfsLink]
123                  echo $ ":: Cached to " <> toTextIgnore cachedRootfs
124  
125          -- Create nvidia mount points
126          mkdir_p (rootfsLink </> "usr/local/nvidia/bin")
127          mkdir_p (rootfsLink </> "usr/local/nvidia/lib64")
128  
129          -- Discover GPU devices and drivers
130          nvBinds <- discoverNvidiaBinds
131  
132          -- Extract container environment from image config
133          containerEnv <- getContainerEnv cfg
134  
135          -- Build final environment
136          let combinedPath = buildPath containerEnv
137              combinedLdPath = buildLdPath containerEnv
138  
139          -- Execute bwrap
140          echo ":: Entering namespace with GPU"
141          let bwrapArgs = buildBwrapArgs workDir nvBinds combinedPath combinedLdPath (cfgCommand cfg)
142  
143          -- Use exec to replace process
144          liftIO $ executeFile "bwrap" True (map T.unpack bwrapArgs) Nothing
145  
146  -- ============================================================================
147  -- Image Operations
148  -- ============================================================================
149  
-- | Derive the cache directory name for an image: the first 16 hex characters
-- of the SHA-256 of the image reference.
--
-- The reference is handed to the shell as positional parameter @$1@ rather
-- than spliced into the script text, so quotes or other shell metacharacters
-- in it cannot alter the command. @printf '%s'@ emits the reference without a
-- trailing newline, matching what @echo -n@ produced for ordinary image
-- names, so existing cache keys remain valid.
getCacheKey :: Text -> Sh Text
getCacheKey image = do
    -- The argument after the script is $0; the image becomes $1.
    digest <- run "sh" ["-c", "printf '%s' \"$1\" | sha256sum | cut -c1-16", "sh", image]
    pure $ T.strip digest
156  
-- | Export the image's filesystem with @crane@ and unpack it into
-- @workDir/rootfs@ (created if missing).
--
-- @SSL_CERT_FILE@ is set from the configuration before @crane@ runs.
--
-- Platform, image reference and destination are passed to the shell as
-- positional parameters (@$1@, @$2@, @$3@) and expanded inside double quotes,
-- so spaces or shell metacharacters in any of them cannot change the command.
pullImage :: Config -> FilePath -> Sh ()
pullImage cfg workDir = do
    let rootfs = workDir </> "rootfs"
    mkdir_p rootfs

    -- Set SSL cert for crane
    setenv "SSL_CERT_FILE" (toTextIgnore $ cfgCertFile cfg)

    -- Pull and extract in a pipeline.
    -- NOTE(review): a plain sh pipeline reports tar's exit status, so a crane
    -- failure is only noticed when tar rejects the stream it received.
    run_
        "sh"
        [ "-c"
        , "crane export --platform \"$1\" \"$2\" - | tar -xf - -C \"$3\""
        , "sh" -- becomes $0
        , cfgPlatform cfg
        , cfgImage cfg
        , toTextIgnore rootfs
        ]
171  
-- | Read @PATH@ and @LD_LIBRARY_PATH@ from the image's configuration as
-- reported by @crane config@.
--
-- This is best effort: if @crane@ exits non-zero, or its output is not JSON of
-- the expected shape, 'emptyEnv' is returned.
--
-- @crane@ is invoked with the same @SSL_CERT_FILE@ and @--platform@ that
-- 'pullImage' uses. Without this, a run that finds the image already cached
-- (and therefore never calls 'pullImage') would query the registry with a
-- different TLS setup and without the configured platform.
getContainerEnv :: Config -> Sh ContainerEnv
getContainerEnv cfg = do
    setenv "SSL_CERT_FILE" (toTextIgnore $ cfgCertFile cfg)

    -- errExit False to not fail on non-zero exit
    result <- errExit False $ run "crane" ["config", "--platform", cfgPlatform cfg, cfgImage cfg]
    code <- lastExitCode
    pure $
        if code == 0
            then fromMaybe emptyEnv (decode (BL.fromStrict $ TE.encodeUtf8 result) >>= parseMaybe parseConfig)
            else emptyEnv
  where
    -- Walk @.config.Env@; anything missing or of the wrong JSON type yields
    -- 'emptyEnv' rather than a parse failure.
    parseConfig :: Value -> Aeson.Parser ContainerEnv
    parseConfig (Object obj) = do
        mConfig <- obj .:? "config"
        case mConfig of
            Just (Object cfgObj) -> do
                mEnvList <- cfgObj .:? "Env"
                case mEnvList of
                    Just (Array arr) -> do
                        let envPairs = map extractEnvVar (V.toList arr)
                        pure
                            ContainerEnv
                                { envPath = lookup "PATH" envPairs
                                , envLdLibPath = lookup "LD_LIBRARY_PATH" envPairs
                                }
                    _ -> pure emptyEnv
            _ -> pure emptyEnv
    parseConfig _ = pure emptyEnv

    -- Split a @KEY=value@ entry at the first @=@. Non-string entries become
    -- an empty pair, which never matches the lookups above.
    extractEnvVar :: Value -> (Text, Text)
    extractEnvVar (String str) =
        let (k, v) = T.breakOn "=" str
         in (k, T.drop 1 v)
    extractEnvVar _ = ("", "")
210  
211  -- ============================================================================
212  -- NVIDIA Discovery
213  -- ============================================================================
214  
-- | Collect every bwrap bind argument needed for GPU access: device nodes,
-- the driver directories, the OpenGL driver directories, and finally a
-- read-only bind of @/nix/store@.
discoverNvidiaBinds :: Sh [Text]
discoverNvidiaBinds = do
    -- Order matters for the resulting argument list: devices, driver, OpenGL.
    discovered <- sequence [discoverDevices, discoverDriver, discoverOpenGL]
    pure $ concat discovered <> ["--ro-bind", "/nix/store", "/nix/store"]
222  
-- | Produce a @--dev-bind@ triple for every entry of @/dev@ whose name starts
-- with @nvidia@, followed by every entry of @/dev/dri@ when that directory
-- exists. Each device is bound to the same path inside the container.
discoverDevices :: Sh [Text]
discoverDevices = do
    devEntries <- ls "/dev"
    hasDri <- test_d "/dev/dri"
    driEntries <- if hasDri then ls "/dev/dri" else pure []
    pure
        [ arg
        | dev <- filter isNvidia devEntries <> driEntries
        , arg <- ["--dev-bind", toTextIgnore dev, toTextIgnore dev]
        ]
  where
    -- True when the final path component begins with "nvidia".
    isNvidia :: FilePath -> Bool
    isNvidia = T.isPrefixOf "nvidia" . T.pack . takeFileName . T.unpack . toTextIgnore
240  
-- | Locate the host's NVIDIA driver package by resolving the
-- @/run/current-system/sw/bin/nvidia-smi@ symlink, and return read-only binds
-- of its @bin@ and @lib@ directories onto @/usr/local/nvidia/bin@ and
-- @/usr/local/nvidia/lib64@. Returns no binds when no driver is found.
--
-- @readlink -f@ succeeds even when the last path component does not exist, so
-- on a host without @nvidia-smi@ it simply echoes the input path back. The
-- resolved file is therefore required to exist; otherwise the system profile
-- itself would be mistaken for the driver and its @bin@ and @lib@ bound into
-- the container. The driver root is obtained by stripping the
-- @/bin/nvidia-smi@ suffix only, not every occurrence of that text.
discoverDriver :: Sh [Text]
discoverDriver = do
    -- Find nvidia driver path via nvidia-smi symlink
    result <- errExit False $ run "readlink" ["-f", "/run/current-system/sw/bin/nvidia-smi"]
    code <- lastExitCode
    let resolved = T.strip result
    smiPresent <-
        if code == 0
            then test_f (fromText resolved)
            else pure False
    case (smiPresent, T.stripSuffix "/bin/nvidia-smi" resolved) of
        (True, Just driverPath) -> do
            let driverDir = fromText driverPath
            exists <- test_d driverDir
            if exists
                then do
                    echo $ ":: Found nvidia driver at " <> driverPath
                    binBind <- ifDirExists (driverDir </> "bin") "/usr/local/nvidia/bin"
                    libBind <- ifDirExists (driverDir </> "lib") "/usr/local/nvidia/lib64"
                    pure $ binBind <> libBind
                else pure []
        _ -> pure []
  where
    -- A read-only bind of src onto dst, or nothing when src is not a directory.
    ifDirExists :: FilePath -> Text -> Sh [Text]
    ifDirExists src dst = do
        exists <- test_d src
        if exists
            then pure ["--ro-bind", toTextIgnore src, dst]
            else pure []
267  
-- | Read-only binds for @/run/opengl-driver@ and @/run/opengl-driver-32@,
-- each included only when the directory exists on the host and each bound to
-- the same path inside the container.
discoverOpenGL :: Sh [Text]
discoverOpenGL = concat <$> mapM glBind ["/run/opengl-driver", "/run/opengl-driver-32"]
  where
    glBind :: FilePath -> Sh [Text]
    glBind dir = do
        present <- test_d dir
        pure $
            if present
                then ["--ro-bind", toTextIgnore dir, toTextIgnore dir]
                else []
281  
282  -- ============================================================================
283  -- Environment Building
284  -- ============================================================================
285  
-- | The container's @PATH@: the NVIDIA bin directory first, followed by the
-- image's own @PATH@ or, when the image declares none, a conventional default.
buildPath :: ContainerEnv -> Text
buildPath cenv = T.intercalate ":" [nvidiaBin, fromMaybe defaultPath (envPath cenv)]
  where
    nvidiaBin, defaultPath :: Text
    nvidiaBin = "/usr/local/nvidia/bin"
    defaultPath = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
293  
-- | The container's @LD_LIBRARY_PATH@: the NVIDIA and OpenGL driver library
-- directories, with the image's own value appended when it declares one.
buildLdPath :: ContainerEnv -> Text
buildLdPath cenv = maybe driverLibs (\imageLibs -> driverLibs <> ":" <> imageLibs) (envLdLibPath cenv)
  where
    driverLibs :: Text
    driverLibs = "/usr/local/nvidia/lib64:/run/opengl-driver/lib"
300  
301  -- ============================================================================
302  -- Bwrap Execution
303  -- ============================================================================
304  
-- | Build the full @bwrap@ argument list.
--
-- Arguments, in order: the working directory containing @rootfs@, the
-- GPU-related bind arguments, the container's @PATH@, the container's
-- @LD_LIBRARY_PATH@, and the command to run after @--@.
--
-- The groups below are emitted in exactly this order, so the GPU binds are
-- applied after the @/run@ tmpfs has been set up.
buildBwrapArgs :: FilePath -> [Text] -> Text -> Text -> [Text] -> [Text]
buildBwrapArgs workDir nvBinds pathVar ldPath cmd' =
    concat
        [ ["--bind", toTextIgnore (workDir </> "rootfs"), "/"]
        , ["--dev", "/dev"]
        , ["--proc", "/proc"]
        , ["--ro-bind", "/sys", "/sys"]
        , ["--tmpfs", "/tmp"]
        , ["--tmpfs", "/run"]
        , nvBinds
        , ["--ro-bind", "/etc/resolv.conf", "/etc/resolv.conf"]
        , ["--ro-bind", "/etc/ssl", "/etc/ssl"]
        , envVar "PATH" pathVar
        , envVar "HOME" "/root"
        , envVar "LD_LIBRARY_PATH" ldPath
        , envVar "OPAL_PREFIX" "/opt/hpcx/ompi"
        , envVar "OMPI_MCA_btl" "^openib"
        , ["--chdir", "/root"]
        , ["--die-with-parent", "--unshare-pid", "--"]
        , cmd'
        ]
  where
    -- One @--setenv NAME VALUE@ triple.
    envVar :: Text -> Text -> [Text]
    envVar name value = ["--setenv", name, value]