Skip to content

Commit

Permalink
Changed dataset class returns
Browse files Browse the repository at this point in the history
Former-commit-id: 3c6d7f5
  • Loading branch information
Vishwesh4 committed Dec 13, 2021
1 parent 3b3ca59 commit 8a66def
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 88 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,5 @@ examples/process_one.py

tests/temp_folder*
examples/data_test
data
data/
demo.py
57 changes: 0 additions & 57 deletions demo.py

This file was deleted.

42 changes: 12 additions & 30 deletions imgtools/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import List, Sequence, Optional, Callable, Iterable, Dict,Tuple
import torchio as tio
import pandas as pd
# from . import file_name_convention
# from ..ops import StructureSetToSegmentation, ImageAutoInput, Resample, BaseOp
from imgtools.io import file_name_convention
from imgtools.ops import StructureSetToSegmentation, ImageAutoInput, Resample, BaseOp
from tqdm import tqdm
Expand All @@ -15,28 +13,26 @@
from imgtools.pipeline import Pipeline
import datetime

class Dataset(tio.SubjectsDataset):
class Dataset(List):
"""
This class takes in medical dataset in the form of nrrds or directly from the dataset and converts the data into torchio.Subject object, which can be loaded into
torchio.SubjectDataset object.
This class inherits from torchio.SubjectDataset object, which can support transforms and torch.Dataloader.
This class takes in a medical dataset, either in the form of nrrds or directly from the dataset, and converts the data into a list of torchio.Subject objects, which can later be loaded into
a torchio.SubjectsDataset object separately.
Read more about torchio at https://torchio.readthedocs.io/quickstart.html and about torchio.SubjectsDataset at https://github.com/fepegar/torchio/blob/3e07b78da16d6db4da7193325b3f9cb31fc0911a/torchio/data/dataset.py#L101
"""
def __init__(
self,
subjects: Sequence[tio.Subject],
transform: Optional[Callable] = None,
load_getitem: bool = True
) -> tio.SubjectsDataset:
super().__init__(subjects,transform,load_getitem)

path: str,
) -> List[tio.Subject]:
super().__init__(subjects)
self.subjects = subjects
self.path = path

@classmethod
def load_from_nrrd(
cls,
path:str,
transform: Optional[Callable] = None,
ignore_multi: bool = True,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
Based on the given path, parses the processed nrrd files present in the directory and the metadata associated with them, and creates a list of Subject instances
Expand Down Expand Up @@ -88,7 +84,7 @@ def load_from_nrrd(
# torch DataLoader doesn't accept None type
temp[metadata_name] = {}
subjects.append(tio.Subject(temp))
return cls(subjects,transform,load_getitem)
return cls(subjects,path)

@classmethod
def load_directly(
Expand All @@ -97,9 +93,7 @@ def load_directly(
modalities: str,
n_jobs: int = -1,
spacing: Tuple = (1., 1., 0.),
transform: Optional[Callable] = None,
ignore_multi: bool = True,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
Based on the given path, imgtools crawls through the directory, forms a datagraph and picks the user-defined modalities. These paths are processed into sitk.Image.
Expand All @@ -119,7 +113,7 @@ def load_directly(
resample = Resample(spacing=spacing)
make_binary_mask = StructureSetToSegmentation(roi_names=[], continuous=False)
subjects = Parallel(n_jobs=n_jobs)(delayed(cls.process_one_subject)(input,subject_id,output_streams,resample,make_binary_mask) for subject_id in tqdm(subject_id_list))
return cls(subjects,transform,load_getitem)
return cls(subjects,path)

@staticmethod
def process_one_subject(
Expand Down Expand Up @@ -187,16 +181,4 @@ def process_one_subject(
pet = read_results[i]
temp[f"mod_{colname}"] = tio.ScalarImage.from_sitk(pet)
temp[f"metadata_{colname}"] = read_results[i].get_metadata()
return tio.Subject(temp)

if __name__=="__main__":
from torch.utils.data import DataLoader
output_path = "/cluster/projects/radiomics/Temp/vishwesh/HN-ctptdose_test2"
# input_path = "/cluster/home/ramanav/imgtools/examples/data_test"
transform = tio.Compose([tio.Resize(256)])
subjects_dataset = Dataset.load_from_nrrd(output_path,transform=transform)
# subjects_dataset = Dataset.load_directly(input_path,modalities="CT,RTDOSE,PT",n_jobs=4,transform=transform)
print(len(subjects_dataset))
training_loader = DataLoader(subjects_dataset, batch_size=4)
items = next(iter(training_loader))
print(items["mod_RTDOSE_CT"])
return tio.Subject(temp)

0 comments on commit 8a66def

Please sign in to comment.