-- oci-gpu-shelly.hs
{-# LANGUAGE ExtendedDefaultRules #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}

{- |
Module      : OciGpu
Description : Run OCI container images with NVIDIA GPU access

A typed shell script that pulls OCI images, caches them, and runs them
in a bubblewrap namespace with GPU device passthrough.

This is the Shelly version. Compared to Turtle:
  - Thread-safe (maintains its own environment state)
  - Command tracing built-in (great for debugging)
  - More verbose API but more control
  - Runs in Shelly's 'Sh' monad
-}
module Main where

import Data.Aeson (Value (..), decode, (.:?))
import Data.Aeson.Types (parseMaybe)
import qualified Data.Aeson.Types as Aeson
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V
import Shelly
import qualified System.Environment
import System.FilePath (takeFileName)
import System.Posix.Process (executeFile)

default (Text)

-- ============================================================================
-- Types
-- ============================================================================

-- | Everything needed to pull, cache and launch one container image.
data Config = Config
    { cfgImage :: Text
    -- ^ Image reference as given on the command line
    , cfgCommand :: [Text]
    -- ^ Command (and arguments) to run inside the namespace
    , cfgPlatform :: Text
    -- ^ Platform passed to @crane --platform@
    , cfgCacheDir :: FilePath
    -- ^ Directory holding the extracted, cached root filesystems
    , cfgCertFile :: FilePath
    -- ^ CA bundle exported to crane as @SSL_CERT_FILE@
    }
    deriving (Show)

-- | The subset of the image's configured environment that this tool honours.
data ContainerEnv = ContainerEnv
    { envPath :: Maybe Text
    -- ^ @PATH@ from the image config, if present
    , envLdLibPath :: Maybe Text
    -- ^ @LD_LIBRARY_PATH@ from the image config, if present
    }
    deriving (Show)

-- | Environment used when the image config is unavailable or has no @Env@.
emptyEnv :: ContainerEnv
emptyEnv = ContainerEnv Nothing Nothing

-- ============================================================================
-- Main
-- ============================================================================

{- | Entry point: @oci-gpu IMAGE [COMMAND...]@.

When no command is given, @nvidia-smi@ is run inside the container.
-}
main :: IO ()
main = shelly $ verbosely $ do
    args <- liftIO System.Environment.getArgs
    case args of
        [] -> errorExit "Usage: oci-gpu IMAGE [COMMAND...]"
        (image : cmdArgs) -> do
            let cmd' = if null cmdArgs then ["nvidia-smi"] else map T.pack cmdArgs
            cfg <- buildConfig (T.pack image) cmd'
            runWithGpu cfg

{- | Assemble the 'Config' for an image and command.

The cache lives under @$XDG_CACHE_HOME@ when that is set, otherwise under
@$HOME/.cache@. @HOME@ is only required in the latter case; if neither
variable is available the script exits with an error message.
-}
buildConfig :: Text -> [Text] -> Sh Config
buildConfig image cmd' = do
    xdgCache <- get_env "XDG_CACHE_HOME"
    cacheBase <- case xdgCache of
        Just c -> pure (fromText c)
        Nothing -> do
            mHome <- get_env "HOME"
            case mHome of
                Just h -> pure (fromText h </> ".cache")
                Nothing ->
                    errorExit "Neither XDG_CACHE_HOME nor HOME is set; cannot locate the cache directory"

    -- These would be injected by Nix in the real build
    let certFile = "/etc/ssl/certs/ca-bundle.crt"
        platform = "linux/amd64"

    pure
        Config
            { cfgImage = image
            , cfgCommand = cmd'
            , cfgPlatform = platform
            , cfgCacheDir = cacheBase </> "straylight-oci"
            , cfgCertFile = fromText certFile
            }

-- ============================================================================
-- Core Logic
-- ============================================================================

{- | Pull (or reuse) the image's root filesystem, discover the host GPU
bind mounts and exec @bwrap@ with the requested command.

On success this never returns: 'executeFile' replaces the current process
with @bwrap@. The temporary working directory therefore still exists while
@bwrap@ runs, which is required because the namespace root is the @rootfs@
symlink inside it.
-}
runWithGpu :: Config -> Sh ()
runWithGpu cfg = do
    -- Ensure cache directory exists
    mkdir_p (cfgCacheDir cfg)

    -- Compute cache key from image name
    cacheKey <- getCacheKey (cfgImage cfg)
    let cachedRootfs = cfgCacheDir cfg </> fromText cacheKey

    -- Create temp working directory
    withTmpDir $ \workDir -> do
        let rootfsLink = workDir </> "rootfs"

        -- Pull or use cached image
        cached <- test_d cachedRootfs
        if cached
            then echo $ ":: Using cached " <> cfgImage cfg
            else do
                echo $ ":: Pulling " <> cfgImage cfg
                -- pullImage extracts into workDir/rootfs; moving the finished
                -- tree into the cache frees that name for the symlink below.
                pullImage cfg workDir
                mv rootfsLink cachedRootfs
                echo $ ":: Cached to " <> toTextIgnore cachedRootfs

        -- Shelly doesn't have symlink, use command
        run_ "ln" ["-s", toTextIgnore cachedRootfs, toTextIgnore rootfsLink]

        -- Create nvidia mount points (through the symlink, i.e. in the cache)
        mkdir_p (rootfsLink </> "usr/local/nvidia/bin")
        mkdir_p (rootfsLink </> "usr/local/nvidia/lib64")

        -- Discover GPU devices and drivers
        nvBinds <- discoverNvidiaBinds

        -- Extract container environment from image config
        containerEnv <- getContainerEnv cfg

        -- Build final environment
        let combinedPath = buildPath containerEnv
            combinedLdPath = buildLdPath containerEnv

        -- Execute bwrap
        echo ":: Entering namespace with GPU"
        let bwrapArgs = buildBwrapArgs workDir nvBinds combinedPath combinedLdPath (cfgCommand cfg)

        -- Use exec to replace process
        liftIO $ executeFile "bwrap" True (map T.unpack bwrapArgs) Nothing

-- ============================================================================
-- Image Operations
-- ============================================================================

{- | Cache key for an image: the first 16 hex characters of the SHA-256 of
the image reference.

The reference is handed to the shell as the positional parameter @$1@
rather than being spliced into the script text, so quotes or other shell
metacharacters in it cannot alter the command being run.
-}
getCacheKey :: Text -> Sh Text
getCacheKey image = do
    -- The "sh" after the script fills $0; the image becomes $1.
    result <-
        run
            "sh"
            [ "-c"
            , "printf '%s' \"$1\" | sha256sum | cut -c1-16"
            , "sh"
            , image
            ]
    pure $ T.strip result

{- | Export the image's filesystem with @crane@ and unpack it into
@workDir/rootfs@.

Platform, image reference and destination are passed as positional
parameters (@$1@..@$3@) so that none of them is interpreted by the shell.
-}
pullImage :: Config -> FilePath -> Sh ()
pullImage cfg workDir = do
    let rootfs = workDir </> "rootfs"
    mkdir_p rootfs

    -- Set SSL cert for crane
    setenv "SSL_CERT_FILE" (toTextIgnore $ cfgCertFile cfg)

    -- Pull and extract in a pipeline. The shell is kept for the pipe itself:
    -- the tar stream is binary and must not pass through Shelly's Text output.
    run_
        "sh"
        [ "-c"
        , "crane export --platform \"$1\" \"$2\" - | tar -xf - -C \"$3\""
        , "sh"
        , cfgPlatform cfg
        , cfgImage cfg
        , toTextIgnore rootfs
        ]

{- | Read @PATH@ and @LD_LIBRARY_PATH@ from the image's config via
@crane config@.

This is best effort: if @crane@ exits non-zero, or its output is not the
expected JSON, 'emptyEnv' is returned and the caller falls back to
defaults.
-}
getContainerEnv :: Config -> Sh ContainerEnv
getContainerEnv cfg = do
    -- On a cache hit pullImage never runs, so make sure crane sees the
    -- certificate bundle here as well.
    setenv "SSL_CERT_FILE" (toTextIgnore $ cfgCertFile cfg)

    -- errExit False to not fail on non-zero exit. The platform matches the
    -- one used by pullImage so both refer to the same image variant.
    result <-
        errExit False $
            run "crane" ["config", "--platform", cfgPlatform cfg, cfgImage cfg]
    code <- lastExitCode
    pure $
        if code == 0
            then maybe emptyEnv parseEnvFromConfig (decode (BL.fromStrict $ TE.encodeUtf8 result))
            else emptyEnv
  where
    parseEnvFromConfig :: Value -> ContainerEnv
    parseEnvFromConfig val = fromMaybe emptyEnv $ parseMaybe parseConfig val

    -- Walks .config.Env; any missing or unexpectedly shaped level yields emptyEnv.
    parseConfig :: Value -> Aeson.Parser ContainerEnv
    parseConfig (Object obj) = do
        mConfig <- obj .:? "config"
        case mConfig of
            Just (Object cfgObj) -> do
                mEnvList <- cfgObj .:? "Env"
                case mEnvList of
                    Just (Array arr) -> do
                        let envPairs = mapMaybe extractEnvVar (V.toList arr)
                        pure
                            ContainerEnv
                                { envPath = lookup "PATH" envPairs
                                , envLdLibPath = lookup "LD_LIBRARY_PATH" envPairs
                                }
                    _ -> pure emptyEnv
            _ -> pure emptyEnv
    parseConfig _ = pure emptyEnv

    -- Splits "KEY=VALUE" at the first '='; non-string entries are skipped.
    extractEnvVar :: Value -> Maybe (Text, Text)
    extractEnvVar (String str) =
        let (k, v) = T.breakOn "=" str
         in Just (k, T.drop 1 v)
    extractEnvVar _ = Nothing

-- ============================================================================
-- NVIDIA Discovery
-- ============================================================================

{- | All host-dependent @bwrap@ bind arguments: GPU device nodes, the NVIDIA
driver directories, the OpenGL driver directories and @/nix/store@.
-}
discoverNvidiaBinds :: Sh [Text]
discoverNvidiaBinds = do
    devBinds <- discoverDevices
    driverBinds <- discoverDriver
    glBinds <- discoverOpenGL
    let nixBind = ["--ro-bind", "/nix/store", "/nix/store"]
    pure $ devBinds <> driverBinds <> glBinds <> nixBind

{- | @--dev-bind@ arguments for every @/dev/nvidia*@ entry and every entry
of @/dev/dri@ (when that directory exists).
-}
discoverDevices :: Sh [Text]
discoverDevices = do
    -- Find /dev/nvidia* devices by matching on the entry's file name
    allDevs <- ls "/dev"
    let isNvidia = T.isPrefixOf "nvidia" . T.pack . takeFileName . T.unpack . toTextIgnore
        nvDevs = filter isNvidia allDevs

    -- Find /dev/dri/* devices
    driExists <- test_d "/dev/dri"
    driDevs <-
        if driExists
            then ls "/dev/dri"
            else pure []

    pure $ concatMap devBind (nvDevs <> driDevs)
  where
    -- Each device is bound at the same path inside the namespace.
    devBind :: FilePath -> [Text]
    devBind dev = ["--dev-bind", toTextIgnore dev, toTextIgnore dev]

{- | Locate the NVIDIA driver by resolving the system-wide @nvidia-smi@
symlink and bind its @bin@ and @lib@ directories under @/usr/local/nvidia@.

Returns no binds when the link cannot be resolved, when the resolved path
does not end in @/bin/nvidia-smi@, or when the resolved binary or driver
directory does not exist.
-}
discoverDriver :: Sh [Text]
discoverDriver = do
    -- Find nvidia driver path via nvidia-smi symlink
    result <- errExit False $ run "readlink" ["-f", "/run/current-system/sw/bin/nvidia-smi"]
    code <- lastExitCode
    let resolved = T.strip result
    -- Strip only the trailing component; a replace would also rewrite any
    -- earlier occurrence of the same text in the path.
    case (code, T.stripSuffix "/bin/nvidia-smi" resolved) of
        (0, Just driverPath) -> do
            let driverDir = fromText driverPath
            -- readlink -f can succeed even when the final component is
            -- missing, so confirm the binary is really there.
            hasSmi <- test_f (fromText resolved)
            exists <- test_d driverDir
            if hasSmi && exists
                then do
                    echo $ ":: Found nvidia driver at " <> driverPath
                    binBind <- ifDirExists (driverDir </> "bin") "/usr/local/nvidia/bin"
                    libBind <- ifDirExists (driverDir </> "lib") "/usr/local/nvidia/lib64"
                    pure $ binBind <> libBind
                else pure []
        _ -> pure []
  where
    -- Read-only bind of src onto dst, or nothing when src is not a directory.
    ifDirExists :: FilePath -> Text -> Sh [Text]
    ifDirExists src dst = do
        exists <- test_d src
        pure $ if exists then ["--ro-bind", toTextIgnore src, dst] else []

{- | Read-only binds for @/run/opengl-driver@ and @/run/opengl-driver-32@,
each included only if the directory exists on the host.
-}
discoverOpenGL :: Sh [Text]
discoverOpenGL =
    concat <$> mapM glBind ["/run/opengl-driver", "/run/opengl-driver-32"]
  where
    glBind :: FilePath -> Sh [Text]
    glBind glPath = do
        exists <- test_d glPath
        pure $
            if exists
                then ["--ro-bind", toTextIgnore glPath, toTextIgnore glPath]
                else []

-- ============================================================================
-- Environment Building
-- ============================================================================

{- | @PATH@ for the container: the NVIDIA bin mount point followed by the
image's own @PATH@, or by a conventional default when the image has none.
-}
buildPath :: ContainerEnv -> Text
buildPath cenv = base <> ":" <> fromMaybe defaultPath (envPath cenv)
  where
    base = "/usr/local/nvidia/bin"
    defaultPath = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

{- | @LD_LIBRARY_PATH@ for the container: the NVIDIA and OpenGL library
directories, followed by the image's own value when it defines one.
-}
buildLdPath :: ContainerEnv -> Text
buildLdPath cenv = maybe base (\p -> base <> ":" <> p) (envLdLibPath cenv)
  where
    base = "/usr/local/nvidia/lib64:/run/opengl-driver/lib"

-- ============================================================================
-- Bwrap Execution
-- ============================================================================

{- | Full argument list for @bwrap@ (excluding the program name).

Arguments, in order: the working directory containing @rootfs@, the
host-dependent bind arguments, the container @PATH@, the container
@LD_LIBRARY_PATH@, and the command to run after @--@.

The host binds are placed after @--tmpfs /run@ so that binds targeting
paths under @/run@ are applied on top of that tmpfs.
-}
buildBwrapArgs :: FilePath -> [Text] -> Text -> Text -> [Text] -> [Text]
buildBwrapArgs workDir nvBinds pathVar ldPath cmd' =
    [ "--bind"
    , toTextIgnore (workDir </> "rootfs")
    , "/"
    , "--dev"
    , "/dev"
    , "--proc"
    , "/proc"
    , "--ro-bind"
    , "/sys"
    , "/sys"
    , "--tmpfs"
    , "/tmp"
    , "--tmpfs"
    , "/run"
    ]
        <> nvBinds
        <> [ "--ro-bind"
           , "/etc/resolv.conf"
           , "/etc/resolv.conf"
           , "--ro-bind"
           , "/etc/ssl"
           , "/etc/ssl"
           , "--setenv"
           , "PATH"
           , pathVar
           , "--setenv"
           , "HOME"
           , "/root"
           , "--setenv"
           , "LD_LIBRARY_PATH"
           , ldPath
           , "--setenv"
           , "OPAL_PREFIX"
           , "/opt/hpcx/ompi"
           , "--setenv"
           , "OMPI_MCA_btl"
           , "^openib"
           , "--chdir"
           , "/root"
           , "--die-with-parent"
           , "--unshare-pid"
           , "--"
           ]
        <> cmd'