seperate_slices.py
import nibabel as nib
import numpy as np
import re
import os
from pathlib import Path


def determine_middle_slice(input_path):
    """
    Determines the index of the middle slice along the third axis of an image volume.

    Only the image's ``shape`` is queried (instead of ``get_fdata()``), so the
    voxel data does not have to be loaded into memory.

    Parameters:
    input_path (str or Path): The path to the input image volume. It must have
        at least three dimensions; the third axis is treated as the slice axis.

    Returns:
    int: ``n_slices // 2`` -- the exact middle index for an odd slice count and
        the upper of the two central indices for an even one.

    Raises:
    ValueError: If the image has fewer than three dimensions.
    """
    shape = nib.load(input_path).shape
    if len(shape) < 3:
        raise ValueError(f"Expected an image with at least 3 dimensions, got shape {shape}")
    # Integer division rather than np.round(n / 2): np.round rounds half to even,
    # which picked the wrong index for n = 3, 7, 11, ... (e.g. 2 instead of 1 for n = 3).
    return shape[2] // 2


def _resolve_slice(image_path, slice_dict):
    """Return the slice index to extract for ``image_path`` (see get_specific_slice)."""
    if not isinstance(slice_dict, dict):
        return determine_middle_slice(image_path)
    # as_posix() so the "/" in the pattern also matches on Windows-style paths.
    path_str = Path(image_path).as_posix()
    for key, value in slice_dict.items():
        # The key is used as a regular-expression fragment and must be followed
        # immediately by "/" or "_", so e.g. key "sub1" does not match "sub10/".
        if re.search(rf"{key}[/_]", path_str):
            if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
                return int(value)
            raise ValueError(f"Slice index for key {key!r} must be an integer, got {value!r}")
    raise ValueError(f"No slice in dict matches image path {image_path}")


def _save_slices(data, slice_idx, out_root, filename):
    """Save ``data[:, :, slice_idx, t]`` for every time point t as ``out_root/<t>/filename``."""
    for t in range(data.shape[3]):
        # Keep a singleton third axis so each file is an (x, y, 1) volume.
        frame = np.expand_dims(data[:, :, slice_idx, t], axis=2)
        out_dir = Path(out_root) / str(t)
        # Create the per-time-point directory here so the image and mask passes
        # are independent (previously only the image pass created it).
        os.makedirs(out_dir, exist_ok=True)
        # Identity affine: the source volume's affine (voxel spacing/orientation)
        # is not carried over to the extracted slices.
        nib.save(nib.Nifti1Image(frame, np.eye(4)), str(out_dir / filename))


def get_specific_slice(image_path, mask_path, save_loc, slice_dict=False, mask_label=2):
    """
    Extracts one slice from a 4D image and its 4D mask and saves every time point separately.

    For each time point t the files
    ``<save_loc>/<image parent dir name>/<t>/image.nii.gz`` and
    ``<save_loc>/<image parent dir name>/<t>/segmentation.nii.gz`` are written,
    each holding an (x, y, 1) volume with an identity affine.

    Args:
        image_path (str or Path): The path to the 4D image volume.
        mask_path (str or Path): The path to the 4D mask volume.
        save_loc (str or Path): The directory under which the extracted slices are saved.
        slice_dict (dict, optional): Maps keys to slice indices. Each key is used
            as a regular expression that must be followed by "/" or "_" in the
            image path; the first matching key's value is used. If anything other
            than a dict is given (default False), the middle slice along the third
            axis of the image is used.
        mask_label (int or float, optional): Mask value that is kept as foreground
            (written as 1.0); every other value becomes 0.0. Defaults to 2.

    Raises:
        ValueError: If ``slice_dict`` is a dict but no key matches the image path,
            if the matched value is not an integer, or if the image or mask is not 4D.
    """
    image_path = Path(image_path)
    mask_path = Path(mask_path)
    slice_idx = _resolve_slice(image_path, slice_dict)

    image_data = nib.load(image_path).get_fdata()
    # Binarise the mask: voxels equal to mask_label -> 1.0, everything else -> 0.0.
    mask_data = (nib.load(mask_path).get_fdata() == mask_label).astype('<f8')

    # Validate both volumes before writing anything, so a bad mask does not
    # leave a half-written output directory behind.
    for kind, data in (("image", image_data), ("mask", mask_data)):
        if data.ndim != 4:
            raise ValueError(f"Expected a 4D {kind} volume, got shape {data.shape}")

    out_root = Path(save_loc) / image_path.parent.name
    _save_slices(image_data, slice_idx, out_root, "image.nii.gz")
    _save_slices(mask_data, slice_idx, out_root, "segmentation.nii.gz")


# `snakemake` is not defined in this file; it is presumably injected by
# Snakemake's `script:` directive. Guarding on it keeps the module importable.
if "snakemake" in globals():
    get_specific_slice(
        snakemake.input[0],
        snakemake.input[1],
        Path(snakemake.output[0]).parent,
        snakemake.params[0],
    )