Skip to content

Commit

Permalink
Added dataset class which can load from nrrds or directly from the da…
Browse files Browse the repository at this point in the history
…taset and convert to pytorch dataset

Former-commit-id: 1cd8984
  • Loading branch information
Vishwesh4 committed Dec 9, 2021
1 parent f3ba678 commit 21a8546
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 20 deletions.
6 changes: 3 additions & 3 deletions imgtools/autopipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def process_one_subject(self, subject_id):
counter[modality] = counter[modality]+1
self.output(f"{subject_id}_{counter[modality]}", doses, output_stream)
metadata[f"size_{output_stream}"] = str(doses.GetSize())
metadata[f"metadata_{output_stream}"] = str(read_results[i].get_metadata())
metadata[f"metadata_{colname}"] = [read_results[i].get_metadata()]
print(subject_id, " SAVED DOSE")
elif modality == "RTSTRUCT":
#For RTSTRUCT, you need image or PT
Expand All @@ -152,7 +152,7 @@ def process_one_subject(self, subject_id):
else:
counter[modality] = counter[modality] + 1
self.output(f"{subject_id}_{counter[modality]}", mask, output_stream)
metadata[f"roi_names_{output_stream}"] = str(structure_set.roi_names)
metadata[f"metadata_{colname}"] = [structure_set.roi_names]

print(subject_id, "SAVED MASK ON", conn_to)
elif modality == "PT":
Expand All @@ -169,7 +169,7 @@ def process_one_subject(self, subject_id):
counter[modality] = counter[modality] + 1
self.output(f"{subject_id}_{counter[modality]}", pet, output_stream)
metadata[f"size_{output_stream}"] = str(pet.GetSize())
metadata[f"metadata_{output_stream}"] = str(read_results[i].get_metadata())
metadata[f"metadata_{colname}"] = [read_results[i].get_metadata()]
print(subject_id, " SAVED PET")
#Saving all the metadata in multiple text files
with open(os.path.join(self.output_directory,f'temp_{subject_id}.txt'),'w') as f:
Expand Down
1 change: 1 addition & 0 deletions imgtools/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .common import *
from .loaders import *
from .writers import *
from .dataset import *
15 changes: 15 additions & 0 deletions imgtools/io/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Dict

from pydicom.misc import is_dicom

Expand Down Expand Up @@ -34,3 +35,17 @@ def find_dicom_paths(root_path: str, yield_directories: bool = False) -> str:
fpath = os.path.join(root, f)
if is_dicom(fpath):
yield fpath

def file_name_convention() -> Dict:
"""
This function returns the file name taxonomy which is used by ImageAutoOutput and Dataset class
"""
file_name_convention = {"CT": "image",
"RTDOSE_CT": "dose",
"RTSTRUCT_CT": "mask_ct.seg",
"RTSTRUCT_PT": "mask_pt.seg",
"PT_CT": "pet",
"PT": "pet",
"RTDOSE": "dose",
"RTSTRUCT": "mask.seg"}
return file_name_convention
183 changes: 183 additions & 0 deletions imgtools/io/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from genericpath import exists
import os
import numpy as np
from typing import List, Sequence, Optional, Callable, Iterable, Dict,Tuple
import torchio as tio
import pandas as pd
from . import file_name_convention
from ..ops import StructureSetToSegmentation, ImageAutoInput, Resample, BaseOp
# from imgtools.io import file_name_convention
# from imgtools.ops import StructureSetToSegmentation, ImageAutoInput, Resample, BaseOp
from tqdm import tqdm
from joblib import Parallel, delayed
import SimpleITK as sitk
import warnings
from imgtools.pipeline import Pipeline

class Dataset(tio.SubjectsDataset):
"""
This class takes in medical dataset in the form of nrrds or directly from the dataset and converts the data into torchio.Subject object, which can be loaded into
torchio.SubjectDataset object.
This class inherits from torchio.SubjectDataset object, which can support transforms and torch.Dataloader.
Read more about torchio from https://torchio.readthedocs.io/quickstart.html and torchio.SubjectDataset from https://github.com/fepegar/torchio/blob/3e07b78da16d6db4da7193325b3f9cb31fc0911a/torchio/data/dataset.py#L101
"""
def __init__(
self,
subjects: Sequence[tio.Subject],
transform: Optional[Callable] = None,
load_getitem: bool = True
) -> tio.SubjectsDataset:
super().__init__(subjects,transform,load_getitem)

@classmethod
def load_from_nrrd(
cls,
path:str,
transform: Optional[Callable] = None,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
Based on the given path, passess the processed nrrd files present in the directory and the metadata associated with it and creates a list of Subject instances
Parameters
path: Path to the output directory passed to the autopipeline script. The output directory should have all the user mentioned modalities processed and present in their folder. The directory
should additionally have dataset.csv which stores all the metadata
"""
path_metadata = os.path.join(path,"dataset.csv")
if not os.path.exists(path_metadata):
raise ValueError("The specified path has no file name {}".format(path_metadata))
df_metadata = pd.read_csv(path_metadata,index_col=0)
output_streams = [("_").join(cols.split("_")[1:]) for cols in df_metadata.columns if cols.split("_")[0]=="folder"]
imp_metadata = [cols for cols in df_metadata.columns if cols.split("_")[0] in ("metadata")]
#Based on the file naming taxonomy
file_names = file_name_convention()
subject_id_list = list(df_metadata.index)
subjects = []
for subject_id in tqdm(subject_id_list):
temp = {}
for col in output_streams:
extension = file_names[col]
mult_conn = col.split("_")[-1] == "1"
metadata_name = f"metadata_{col}"
if mult_conn:
extra = str(col.split("_").count("1"))+"_"
else:
extra = ""
path_mod = os.path.join(path,extension.split(".")[0],f"{subject_id}_{extra}{extension}.nrrd")
#All modalities except RTSTRUCT should be of type torchIO.ScalarImage
if col!="RTSTRUCT":
temp[f"mod_{col}"] = tio.ScalarImage(path_mod)
else:
temp[f"mod_{col}"] = tio.LabelImage(path_mod)
#For including metadata
if metadata_name in imp_metadata:
#convert string to proper datatype
temp[metadata_name] = df_metadata.loc[subject_id,metadata_name][0]
subjects.append(tio.Subject(temp))
return cls(subjects,transform,load_getitem)

@classmethod
def load_directly(
cls,
path:str,
modalities: str,
n_jobs: int = -1,
spacing: Tuple = (1., 1., 0.),
transform: Optional[Callable] = None,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
Based on the given path, imgtools crawls through the directory, forms datagraph and picks the user defined modalities. These paths are processed into sitk.Image.
This image and the metadata associated with it, creates a list of Subject instances
Parameters
path: Path to the directory of the dataset
"""
input = ImageAutoInput(path, modalities, n_jobs)
df_metadata = input.df_combined
output_streams = input.output_streams
#Basic operations
subject_id_list = list(df_metadata.index)
# basic image processing ops
resample = Resample(spacing=spacing)
make_binary_mask = StructureSetToSegmentation(roi_names=[], continuous=False)
subjects = Parallel(n_jobs=n_jobs)(delayed(cls.process_one_subject)(input,subject_id,output_streams,resample,make_binary_mask) for subject_id in tqdm(subject_id_list))
return cls(subjects,transform,load_getitem)

@staticmethod
def process_one_subject(
input: Pipeline,
subject_id: str,
output_streams: List[str],
resample: BaseOp,
make_binary_mask: BaseOp,
) -> tio.Subject:
"""
Process all modalities for one subject
Parameters:
input: ImageAutoInput class which helps in loading the respective DICOMs
subject_id: subject id of the data
output_streams: the modalities that are being considered, Note that there can be multiple items of same modality based on their relations with different modalities
resample: transformation which resamples sitk.Image
make_binary_mask: transformation useful in making binary mask for rtstructs
Returns tio.Subject instance for a particular subject id
"""
temp = {}
read_results = input(subject_id)
for i,colname in enumerate(output_streams):
modality = colname.split("_")[0]
output_stream = ("_").join([item for item in colname.split("_") if item != "1"])

if read_results[i] is None:
pass
elif modality == "CT":
image = read_results[i]
if len(image.GetSize()) == 4:
assert image.GetSize()[-1] == 1, f"There is more than one volume in this CT file for {subject_id}."
extractor = sitk.ExtractImageFilter()
extractor.SetSize([*image.GetSize()[:3], 0])
extractor.SetIndex([0, 0, 0, 0])
image = extractor.Execute(image)
image = resample(image)
temp[f"mod_{colname}"] = tio.ScalarImage.from_sitk(image)
elif modality == "RTDOSE":
try: #For cases with no image present
doses = read_results[i].resample_dose(image)
except:
Warning("No CT image present. Returning dose image without resampling")
doses = read_results[i]
temp[f"mod_{colname}"] = tio.ScalarImage.from_sitk(doses)
temp[f"metadata_{colname}"] = str(read_results[i].get_metadata())
elif modality == "RTSTRUCT":
#For RTSTRUCT, you need image or PT
structure_set = read_results[i]
conn_to = output_stream.split("_")[-1]
# make_binary_mask relative to ct/pet
if conn_to == "CT":
mask = make_binary_mask(structure_set, image)
elif conn_to == "PT":
mask = make_binary_mask(structure_set, pet)
else:
raise ValueError("You need to pass a reference CT or PT/PET image to map contours to.")
temp[f"mod_{colname}"] = tio.LabelMap.from_sitk(mask)
temp[f"metadata_{colname}"] = structure_set.roi_names
elif modality == "PT":
try:
#For cases with no image present
pet = read_results[i].resample_pet(image)
except:
Warning("No CT image present. Returning PT/PET image without resampling.")
pet = read_results[i]
temp[f"mod_{colname}"] = tio.ScalarImage.from_sitk(pet)
temp[f"metadata_{colname}"] = read_results[i].get_metadata()
return tio.Subject(temp)

if __name__=="__main__":
from torch.utils.data import DataLoader
# output_path = "/cluster/projects/radiomics/Temp/vishwesh/HN-CT_RTdose_test2"
input_path = "/cluster/home/ramanav/imgtools/examples/data_test"
transform = tio.Compose([tio.Resize(256)])
# subjects_dataset = Dataset.load_from_nrrd(output_path,transform=transform)
subjects_dataset = Dataset.load_directly(input_path,modalities="CT,RTDOSE,PT",n_jobs=4,transform=transform)
print(len(subjects_dataset))
training_loader = DataLoader(subjects_dataset, batch_size=4)
items = next(iter(training_loader))
print(items["mod_RTDOSE_CT"])
26 changes: 17 additions & 9 deletions imgtools/modules/datagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,10 @@ def _get_df(self,
temp[temp_dfconn.iloc[k,0]]["modality"] = temp_dfconn.iloc[k,1]
temp[temp_dfconn.iloc[k,0]]["folder"] = temp_dfconn.iloc[k,2]
temp[temp_dfconn.iloc[k,0]]["conn_to"] = "CT"
folder_save[f"series_{temp_dfconn.iloc[k,1]}_CT"] = temp_dfconn.iloc[k,0]
folder_save[f"folder_{temp_dfconn.iloc[k,1]}_CT"] = temp_dfconn.iloc[k,2]
#Checks if there is already existing connection
key,key_series = self._check_save(folder_save,temp_dfconn.iloc[k,1],"CT")
folder_save[key_series] = temp_dfconn.iloc[k,0]
folder_save[key] = temp_dfconn.iloc[k,2]
A.append(temp)
save_folder_comp.append(folder_save)
#For rest of the edges left out, the connections are formed by going through the dictionary. For cases such as RTstruct-RTDose and PET-RTstruct
Expand All @@ -377,15 +379,12 @@ def _get_df(self,
A[k][rest_locs.iloc[j,3]]["conn_to"] = rest_locs.iloc[j,1]
if rest_locs.iloc[j,4]=="RTDOSE":
#RTDOSE is connected via either RTstruct or/and CT, but we usually don't care, so naming it commonly
save_folder_comp[k][f"series_{rest_locs.iloc[j,4]}_CT"] = rest_locs.iloc[j,3]
save_folder_comp[k][f"folder_{rest_locs.iloc[j,4]}_CT"] = rest_locs.iloc[j,5]
key,key_series = self._check_save(save_folder_comp[k],rest_locs.iloc[j,4],"CT")
save_folder_comp[k][key_series] = rest_locs.iloc[j,3]
save_folder_comp[k][key] = rest_locs.iloc[j,5]
else: #Cases such as RTSTRUCT-PT
key = "folder_{}_{}".format(rest_locs.iloc[j,4], rest_locs.iloc[j,1])
key_series = "series_{}_{}".format(rest_locs.iloc[j,4], rest_locs.iloc[j,1])
#if there is already a connection and one more same category modality wants to connect
if key in save_folder_comp[k].keys():
key = key + "_1"
key_series = key_series + "_1"
key,key_series = self._check_save(save_folder_comp[k],rest_locs.iloc[j,4],rest_locs.iloc[j,1])
save_folder_comp[k][key_series] = rest_locs.iloc[j,3]
save_folder_comp[k][key] = rest_locs.iloc[j,5]
flag = 0
Expand All @@ -410,6 +409,15 @@ def _get_df(self,
final_df = pd.DataFrame(final_df)
return final_df

@staticmethod
def _check_save(save_dict,node,dest):
key = f"folder_{node}_{dest}"
key_series = f"series_{node}_{dest}"
if key in save_dict.keys():
key = key + "_1"
key_series = key_series + "_1"
return key,key_series

@staticmethod
def list_edges(series):
return reduce(lambda x, y:str(x) + str(y), series)
Expand Down
9 changes: 1 addition & 8 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,7 @@ def __init__(self,
output_streams: List[str]):

# File types
self.file_name = {"CT": "image",
"RTDOSE_CT": "dose",
"RTSTRUCT_CT": "mask_ct.seg",
"RTSTRUCT_PT": "mask_pt.seg",
"PT_CT": "pet",
"PT": "pet",
"RTDOSE": "dose",
"RTSTRUCT": "mask.seg"}
self.file_name = file_name_convention()

self.output = {}
for colname in output_streams:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pytest
scikit-image
SimpleITK
tqdm
torchio

0 comments on commit 21a8546

Please sign in to comment.