/ code / seperate_slices.py
seperate_slices.py
 1  import nibabel as nib
 2  import numpy as np
 3  import re
 4  import os
 5  from pathlib import Path
 6  
 7  def determine_middle_slice(input_path):
 8      """
 9      Determines the middle slice of a 3D image volume.
10  
11      Parameters:
12      input_path (str): The path to the input image volume.
13  
14      Returns:
15      int: The index of the middle slice.
16  
17      """
18      _, _, slice, _ = nib.load(input_path).get_fdata().shape
19      slice = int(np.round(slice/2))
20      return slice
21  
def _resolve_slice_index(image_path, slice_dict):
    """Return the slice index for ``image_path``: looked up in ``slice_dict`` if it is a dict, else the middle slice."""
    if not isinstance(slice_dict, dict):
        return determine_middle_slice(image_path)
    for key, value in slice_dict.items():
        # NOTE(review): ``key`` is interpolated into the regex unescaped, so keys
        # containing regex metacharacters act as patterns — confirm intended.
        if re.search(rf"{key}[/_]", str(image_path)):
            break
    else:
        # Previously this case fell through with slice=False, which numpy treats
        # as an empty boolean mask and silently produced empty output volumes.
        raise ValueError(f"No slice in dict matches {image_path}")
    # bool is a subclass of int but is never a meaningful slice index.
    if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
        raise TypeError(
            f"Slice index for key {key!r} must be an integer, got {value!r}"
        )
    return int(value)


def _save_frames(data, slice_idx, out_root, filename):
    """Save ``data[:, :, slice_idx, t]`` for every frame ``t`` as ``out_root/t/filename``."""
    for t in range(data.shape[3]):
        # Keep a singleton third axis so each saved file is still a 3D volume.
        frame = np.expand_dims(data[:, :, slice_idx, t], axis=2)
        frame_dir = out_root / str(t)
        # exist_ok: the image and mask passes share these directories, and the
        # mask may have more frames than the image.
        frame_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the identity affine discards the source voxel spacing and
        # orientation; kept unchanged because downstream steps may rely on it.
        nib.save(nib.Nifti1Image(frame, np.eye(4)), str(frame_dir / filename))


def get_specific_slice(image_path, mask_path, save_loc, slice_dict=False, mask_label=2):
    """
    Extracts one slice (along axis 2) from every frame of a 4D image and its mask
    and saves each as a single-slice 3D NIfTI file.

    For each frame ``t`` along axis 3 the following files are written:
    ``{save_loc}/{image_path.parent.name}/{t}/image.nii.gz`` and
    ``{save_loc}/{image_path.parent.name}/{t}/segmentation.nii.gz``.
    The mask is binarised: voxels equal to ``mask_label`` become 1.0, all others 0.0.

    Args:
        image_path (str or Path): The path to the 4D image volume.
        mask_path (str or Path): The path to the 4D mask volume.
        save_loc (str or Path): The directory under which per-frame folders are created.
        slice_dict (dict, optional): A dictionary mapping keywords to slice indices.
            The first key that, followed by ``/`` or ``_``, matches (via ``re.search``)
            the image path selects the slice. If not a dict, the middle slice of the
            image volume is used.
        mask_label (int or float, optional): The mask value kept when binarising.
            Defaults to 2.

    Raises:
        ValueError: If ``slice_dict`` is a dict but no key matches ``image_path``,
            or if the image or mask is not 4D.
        TypeError: If the slice index found in ``slice_dict`` is not an integer.
    """
    image_path = Path(image_path)
    mask_path = Path(mask_path)
    slice_idx = _resolve_slice_index(image_path, slice_dict)

    image_data = nib.load(image_path).get_fdata()
    mask_data = nib.load(mask_path).get_fdata()
    mask_data = (mask_data == mask_label).astype('<f8')

    # Validate both volumes before writing anything.
    for name, data in (("image", image_data), ("mask", mask_data)):
        if data.ndim != 4:
            raise ValueError(f"Expected a 4D {name} volume, got shape {data.shape}")

    out_root = Path(save_loc) / image_path.parent.name
    _save_frames(image_data, slice_idx, out_root, "image.nii.gz")
    _save_frames(mask_data, slice_idx, out_root, "segmentation.nii.gz")
69  
# Script entry point. `snakemake` is presumably the object injected by
# Snakemake's `script:` directive (it is not defined in this file):
# input[0] is the image, input[1] the mask, params[0] the slice lookup, and
# results go into the directory containing output[0].
image_file = snakemake.input[0]
mask_file = snakemake.input[1]
output_dir = Path(snakemake.output[0]).parent
slice_lookup = snakemake.params[0]
get_specific_slice(image_file, mask_file, output_dir, slice_lookup)