/ nix / script / exe / nvidia-extract.hs
nvidia-extract.hs
  1  {-# LANGUAGE OverloadedStrings #-}
  2  {-# LANGUAGE RecordWildCards #-}
  3  
  4  -- \|
  5  -- nvidia-extract - Extract NVIDIA SDK components from NGC container
  6  --
  7  -- Usage:
  8  --   nvidia-extract <image-ref> <output-dir>
  9  --
 10  -- Example:
 11  --   nvidia-extract nvcr.io/nvidia/tritonserver:25.11-py3 ./nvidia-sdk
 12  --
 13  -- Extracts:
 14  --   - CUDA toolkit (nvcc, libraries)
 15  --   - cuDNN
 16  --   - NCCL
 17  --   - TensorRT
 18  --   - cuTensor
 19  --
-- Version info is parsed from container environment variables:
--   CUDA_VERSION, CUDNN_VERSION, NCCL_VERSION, TENSORRT_VERSION, CUTENSOR_VERSION
 22  --
 23  -- The NGC containers have blessed, tested configurations.
 24  -- No more fighting nvidia's download auth.
 25  
 26  import Aleph.Script
 27  import qualified Aleph.Script.Tools.Crane as Crane
 28  import Control.Monad (forM_, when)
 29  import qualified Control.Monad as M
 30  import Data.Aeson (Object, Value (..), decode)
 31  import qualified Data.Aeson.KeyMap as KM
 32  import qualified Data.ByteString.Lazy as BL
 33  import Data.Maybe (fromMaybe, mapMaybe)
 34  import qualified Data.Text as T
 35  import qualified Data.Text.Encoding as TE
 36  import qualified Data.Text.IO as TIO
 37  import qualified Data.Vector as V
 38  import System.Environment (getArgs)
 39  import Prelude hiding (FilePath)
 40  
 41  -- | Version info extracted from container
 42  data NvidiaVersions = NvidiaVersions
 43      { nvCuda :: Text
 44      , nvCudnn :: Text
 45      , nvNccl :: Text
 46      , nvTensorrt :: Text
 47      , nvCutensor :: Text
 48      }
 49      deriving (Show)
 50  
 51  -- | Libraries we care about
 52  targetLibs :: [Text]
 53  targetLibs =
 54      [ "libcudart"
 55      , "libcublas"
 56      , "libcufft"
 57      , "libcurand"
 58      , "libcusolver"
 59      , "libcusparse"
 60      , "libnvrtc"
 61      , "libcudnn"
 62      , "libnccl"
 63      , "libnvinfer"
 64      , "libcutensor"
 65      ]
 66  
 67  main :: IO ()
 68  main = script $ verbosely $ do
 69      args <- liftIO getArgs
 70      case args of
 71          [imageRef, outputDir] -> do
 72              extractNvidiaSdk (pack imageRef) (fromText $ pack outputDir)
 73          _ -> do
 74              echoErr "Usage: nvidia-extract <image-ref> <output-dir>"
 75              echoErr ""
 76              echoErr "Example:"
 77              echoErr "  nvidia-extract nvcr.io/nvidia/tritonserver:25.11-py3 ./nvidia-sdk"
 78              exit 1
 79  
 80  extractNvidiaSdk :: Text -> FilePath -> Sh ()
 81  extractNvidiaSdk imageRef outputDir = do
 82      echoErr $ ":: Extracting NVIDIA SDK from " <> imageRef
 83  
 84      -- Get version info from container config before extraction
 85      echoErr ":: Reading container config..."
 86      versions <- getContainerVersions imageRef
 87  
 88      -- Create temp dir for full container
 89      withTmpDir $ \tmpDir -> do
 90          let containerRoot = tmpDir </> "rootfs"
 91  
 92          -- Pull and extract container
 93          echoErr ":: Pulling container..."
 94          Crane.exportToDir Crane.defaults imageRef containerRoot
 95  
 96          -- Create output structure
 97          echoErr ":: Creating SDK layout..."
 98          mkdirP (outputDir </> "bin")
 99          mkdirP (outputDir </> "lib64")
100          mkdirP (outputDir </> "include")
101          mkdirP (outputDir </> "nvvm")
102  
103          -- Extract CUDA toolkit
104          -- Look for versioned cuda dir (e.g., cuda-13.0) since symlinks may not resolve
105          cudaDir <- findCudaDir containerRoot (nvCuda versions)
106          case cudaDir of
107              Nothing -> echoErr ":: Warning: CUDA directory not found"
108              Just dir -> do
109                  echoErr $ ":: Extracting CUDA toolkit " <> nvCuda versions <> " from " <> toTextIgnore dir <> "..."
110                  -- bin (nvcc, etc)
111                  copyDir (dir </> "bin") (outputDir </> "bin")
112                  -- lib64 - try both direct and targets structure
113                  let libDir = dir </> "lib64"
114                      targetsLibDir = dir </> "targets/x86_64-linux/lib"
115                  hasLib64 <- test_d libDir
116                  hasTargetsLib <- test_d targetsLibDir
117                  when hasLib64 $ copyDir libDir (outputDir </> "lib64")
118                  when hasTargetsLib $ copyDir targetsLibDir (outputDir </> "lib64")
119                  -- include - try both direct and targets structure
120                  let incDir = dir </> "include"
121                      targetsIncDir = dir </> "targets/x86_64-linux/include"
122                  hasInc <- test_d incDir
123                  hasTargetsInc <- test_d targetsIncDir
124                  when hasInc $ copyDir incDir (outputDir </> "include")
125                  when hasTargetsInc $ copyDir targetsIncDir (outputDir </> "include")
126                  -- nvvm (for nvcc)
127                  copyDir (dir </> "nvvm") (outputDir </> "nvvm")
128  
129          -- Extract system libraries (cuDNN, NCCL, TensorRT)
130          let sysLibDir = containerRoot </> "usr/lib/x86_64-linux-gnu"
131          hasSysLib <- test_d sysLibDir
132          when hasSysLib $ do
133              echoErr ":: Extracting cuDNN, NCCL, TensorRT..."
134              forM_ targetLibs $ \lib -> do
135                  -- Find and copy matching libraries
136                  libs <- findLibs sysLibDir lib
137                  forM_ libs $ \libPath -> do
138                      cp libPath (outputDir </> "lib64" </> filename libPath)
139  
140          -- Extract TensorRT-LLM backend if present
141          let trtLlmDir = containerRoot </> "opt/tritonserver/backends/tensorrtllm"
142          hasTrtLlm <- test_d trtLlmDir
143          when hasTrtLlm $ do
144              echoErr ":: Extracting TensorRT-LLM backend..."
145              mkdirP (outputDir </> "backends")
146              copyDir trtLlmDir (outputDir </> "backends/tensorrtllm")
147  
148          -- Write version info
149          echoErr ":: Writing version info..."
150          writeVersionInfo outputDir versions
151  
152          -- Fix ELF binaries (RPATHs and interpreter)
153          echoErr ":: Patching ELF binaries..."
154          patchElf outputDir
155  
156          echoErr $ ":: Done! SDK extracted to " <> toTextIgnore outputDir
157          echoErr $
158              ":: CUDA "
159                  <> nvCuda versions
160                  <> ", cuDNN "
161                  <> nvCudnn versions
162                  <> ", NCCL "
163                  <> nvNccl versions
164  
165  -- | Get version info from container environment variables
166  getContainerVersions :: Text -> Sh NvidiaVersions
167  getContainerVersions imageRef = do
168      -- Get container config JSON
169      configJson <- Crane.config imageRef
170  
171      -- Parse JSON and extract Env array
172      let envPairs = case decode (BL.fromStrict $ TE.encodeUtf8 configJson) of
173              Just (Object obj) -> extractEnv obj
174              _ -> []
175          lookup' k = fromMaybe "unknown" $ Prelude.lookup k envPairs
176  
177      pure
178          NvidiaVersions
179              { nvCuda = cleanVersion $ lookup' "CUDA_VERSION"
180              , nvCudnn = cleanVersion $ lookup' "CUDNN_VERSION"
181              , nvNccl = cleanVersion $ lookup' "NCCL_VERSION"
182              , nvTensorrt = cleanVersion $ lookup' "TENSORRT_VERSION"
183              , nvCutensor = cleanVersion $ lookup' "CUTENSOR_VERSION"
184              }
185    where
186      -- Extract .config.Env array from JSON
187      extractEnv :: Object -> [(Text, Text)]
188      extractEnv obj = case KM.lookup "config" obj of
189          Just (Object cfg) -> case KM.lookup "Env" cfg of
190              Just (Array arr) -> mapMaybe parseEnvVar (V.toList arr)
191              _ -> []
192          _ -> []
193  
194      parseEnvVar :: Value -> Maybe (Text, Text)
195      parseEnvVar (String t) =
196          let (k, rest) = T.breakOn "=" t
197           in if T.null rest
198                  then Nothing
199                  else Just (k, T.drop 1 rest)
200      parseEnvVar _ = Nothing
201  
202      -- Remove trailing -1, quotes, etc
203      cleanVersion :: Text -> Text
204      cleanVersion v =
205          let v' = T.replace "\"" "" v
206              v'' = if "-1" `T.isSuffixOf` v' then T.dropEnd 2 v' else v'
207           in if T.null v'' then "unknown" else v''
208  
209  {- | Find CUDA directory - looks for versioned dirs like cuda-13.0
210  Falls back to /usr/local/cuda symlink if available
211  -}
212  findCudaDir :: FilePath -> Text -> Sh (Maybe FilePath)
213  findCudaDir containerRoot cudaVersion = do
214      let localDir = containerRoot </> "usr/local"
215          -- Extract major.minor from version (e.g., "13.0.1" -> "13.0")
216          majorMinor = T.intercalate "." $ Prelude.take 2 $ T.splitOn "." cudaVersion
217          versionedDir = localDir </> fromText ("cuda-" <> majorMinor)
218          symlinkedDir = localDir </> "cuda"
219  
220      hasVersioned <- test_d versionedDir
221      hasSymlinked <- test_d symlinkedDir
222  
223      pure $
224          if hasVersioned
225              then Just versionedDir
226              else
227                  if hasSymlinked
228                      then Just symlinkedDir
229                      else Nothing
230  
231  -- | Find libraries matching a prefix
232  findLibs :: FilePath -> Text -> Sh [FilePath]
233  findLibs dir prefix = do
234      exists <- test_d dir
235      if exists
236          then do
237              files <- ls dir
238              pure $ filter (hasPrefix prefix . toTextIgnore . filename) files
239          else pure []
240    where
241      hasPrefix p t = T.isPrefixOf p t
242  
243  -- | Copy directory contents recursively
244  copyDir :: FilePath -> FilePath -> Sh ()
245  copyDir src dst = do
246      exists <- test_d src
247      when exists $ do
248          mkdirP dst
249          run_ "cp" ["-rL", toTextIgnore src <> "/.", toTextIgnore dst <> "/"]
250  
251  -- | Write version.json with component versions
252  writeVersionInfo :: FilePath -> NvidiaVersions -> Sh ()
253  writeVersionInfo outputDir NvidiaVersions{..} = do
254      let versionFile = outputDir </> "version.json"
255          json =
256              T.unlines
257                  [ "{"
258                  , "  \"cuda\": \"" <> nvCuda <> "\","
259                  , "  \"cudnn\": \"" <> nvCudnn <> "\","
260                  , "  \"nccl\": \"" <> nvNccl <> "\","
261                  , "  \"tensorrt\": \"" <> nvTensorrt <> "\","
262                  , "  \"cutensor\": \"" <> nvCutensor <> "\""
263                  , "}"
264                  ]
265  
266      liftIO $ TIO.writeFile (T.unpack $ toTextIgnore versionFile) json
267  
268  {- | Patch all ELF binaries in SDK directory
269  Uses find to recursively locate all ELF files, then:
270    - Sets interpreter for executables
271    - Sets RPATH relative to file location
272  -}
273  patchElf :: FilePath -> Sh ()
274  patchElf sdkDir = do
275      -- Get absolute SDK root for reliable path comparison
276      cwd <- pwd
277      let sdkRoot =
278              if "/" `T.isPrefixOf` toTextIgnore sdkDir
279                  then sdkDir
280                  else cwd </> sdkDir
281          sdkRootText = toTextIgnore sdkRoot
282  
283      echoErr $ "   Scanning " <> sdkRootText <> " for ELF files..."
284  
285      -- Find all regular files (not symlinks)
286      allFiles <- findAllFiles sdkRoot
287  
288      -- Filter to ELF files and patch each
289      patchCount <- M.foldM (patchIfElf sdkRootText) 0 allFiles
290  
291      echoErr $ "   Patched " <> pack (show patchCount) <> " ELF files"
292  
293  -- | Find all regular files recursively (not symlinks)
294  findAllFiles :: FilePath -> Sh [FilePath]
295  findAllFiles dir = do
296      exists <- test_d dir
297      if not exists
298          then pure []
299          else do
300              -- Use find command for recursive search
301              output <-
302                  errExit False $
303                      run
304                          "find"
305                          [ toTextIgnore dir
306                          , "-type"
307                          , "f"
308                          , "-not"
309                          , "-name"
310                          , "*.py"
311                          , "-not"
312                          , "-name"
313                          , "*.pyc"
314                          , "-not"
315                          , "-name"
316                          , "*.h"
317                          , "-not"
318                          , "-name"
319                          , "*.hpp"
320                          , "-not"
321                          , "-name"
322                          , "*.cuh"
323                          , "-not"
324                          , "-name"
325                          , "*.txt"
326                          , "-not"
327                          , "-name"
328                          , "*.md"
329                          , "-not"
330                          , "-name"
331                          , "*.json"
332                          ]
333              pure $ map fromText $ filter (not . T.null) $ T.lines output
334  
335  -- | Patch a file if it's ELF, return 1 if patched, 0 otherwise
336  patchIfElf :: Text -> Int -> FilePath -> Sh Int
337  patchIfElf sdkRoot count path = do
338      -- Skip symlinks
339      isLink <- test_s path
340      if isLink
341          then pure count
342          else do
343              -- Check if ELF
344              output <- errExit False $ run "file" ["-b", toTextIgnore path]
345              if not ("ELF" `T.isInfixOf` output)
346                  then pure count
347                  else do
348                      -- Determine if executable or shared library
349                      let isExe = "executable" `T.isInfixOf` output
350  
351                      -- Get the directory containing this file, relative to SDK root
352                      let pathText = toTextIgnore path
353                          relToRoot = calculateRelPath sdkRoot pathText
354  
355                      -- Build RPATH with multiple search paths
356                      let rpath =
357                              T.intercalate ":" $
358                                  [ "$ORIGIN"
359                                  , "$ORIGIN/" <> relToRoot <> "/lib64"
360                                  , "$ORIGIN/" <> relToRoot <> "/lib"
361                                  , "$ORIGIN/" <> relToRoot <> "/nvvm/lib64"
362                                  ]
363  
364                      -- Patch RPATH
365                      errExit False $ run_ "patchelf" ["--set-rpath", rpath, pathText]
366  
367                      -- Set interpreter for executables (not shared objects)
368                      when isExe $ do
369                          -- Get the dynamic linker path
370                          -- For now, use a standard path that works on most Linux
371                          errExit False $
372                              run_
373                                  "patchelf"
374                                  [ "--set-interpreter"
375                                  , "/lib64/ld-linux-x86-64.so.2"
376                                  , pathText
377                                  ]
378  
379                      pure (count + 1)
380  
381  {- | Calculate relative path from file's directory back to SDK root
382  Takes the SDK root and the absolute file path
383  e.g., sdkRoot="/tmp/sdk", path="/tmp/sdk/bin/nvcc" -> ".."
384        sdkRoot="/tmp/sdk", path="/tmp/sdk/lib64/libfoo.so" -> ".."
385        sdkRoot="/tmp/sdk", path="/tmp/sdk/nvvm/lib64/libnvvm.so" -> "../.."
386  -}
387  calculateRelPath :: Text -> Text -> Text
388  calculateRelPath sdkRoot filePath =
389      -- Get the file's directory by removing the filename
390      let fileDir = T.dropWhileEnd (/= '/') filePath
391          -- Remove SDK root prefix to get relative path within SDK
392          relPath = fromMaybe fileDir $ T.stripPrefix (sdkRoot <> "/") fileDir
393          -- Count directory depth (number of path separators)
394          depth = Prelude.length $ filter (== '/') $ T.unpack relPath
395       in if depth == 0
396              then "."
397              else T.intercalate "/" $ Prelude.replicate depth ".."