fixed bugs regarding multiple connections, saving of metadata and loading of metadata
Vishwesh4 committed Dec 10, 2021
1 parent 735b26c commit 1528efc
Showing 4 changed files with 43 additions and 24 deletions.
20 changes: 10 additions & 10 deletions imgtools/autopipeline.py
@@ -13,7 +13,7 @@
import ast
import datetime
import numpy as np
-import json
+import pickle
import shutil
###############################################################
# Example usage:
@@ -84,7 +84,7 @@ def process_one_subject(self, subject_id):
The ID of subject to process
"""
#Check if the subject_id has already been processed
-if os.path.exists(os.path.join(self.output_directory,".temp",f'temp_{subject_id}.json')):
+if os.path.exists(os.path.join(self.output_directory,".temp",f'temp_{subject_id}.pkl')):
print(f"{subject_id} already processed")
return

@@ -99,8 +99,8 @@ def process_one_subject(self, subject_id):
for i, colname in enumerate(self.output_streams):
modality = colname.split("_")[0]

-#Taking modality pairs if it exists till _1
-output_stream = ("_").join([item for item in colname.split("_") if item != "1"])
+#Taking modality pairs if it exists till _{num}
+output_stream = ("_").join([item for item in colname.split("_") if item.isnumeric()==False])

#Multiple connections means two modalities connected to one modality; such column names end with _{num}
mult_conn = colname.split("_")[-1].isnumeric()
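
The old filter only dropped the literal token "1", so a third connection (suffix _2) slipped through; isnumeric() catches any numeric suffix. A quick sketch of the new mapping, with illustrative column names:

    def base_stream(colname: str) -> str:
        # Drop purely numeric tokens so every multiple-connection column
        # ("CT_RTDOSE_1", "CT_RTDOSE_2", ...) maps to its base stream.
        return "_".join(t for t in colname.split("_") if not t.isnumeric())

    assert base_stream("CT_RTDOSE_1") == "CT_RTDOSE"
    assert base_stream("CT_RTDOSE_2") == "CT_RTDOSE"  # the old `!= "1"` filter missed this
    assert "CT_RTDOSE_1".split("_")[-1].isnumeric()   # mult_conn is True for suffixed columns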
@@ -109,7 +109,7 @@
print(output_stream)

if read_results[i] is None:
print("The subject id: {} has no {}".format(subject_id, ("_").join(colname.split("_")[1:])))
print("The subject id: {} has no {}".format(subject_id,colname))
pass
elif modality == "CT":
image = read_results[i]
@@ -178,16 +178,16 @@ def process_one_subject(self, subject_id):
metadata[f"metadata_{colname}"] = [read_results[i].get_metadata()]
print(subject_id, " SAVED PET")
#Saving all the metadata in multiple text files
-with open(os.path.join(self.output_directory,".temp",f'temp_{subject_id}.json'),'w') as f:
-    json.dump(metadata,f)
+with open(os.path.join(self.output_directory,".temp",f'temp_{subject_id}.pkl'),'wb') as f:
+    pickle.dump(metadata,f)
return

def save_data(self):
-files = glob.glob(os.path.join(self.output_directory,".temp","*.json"))
+files = glob.glob(os.path.join(self.output_directory,".temp","*.pkl"))
for file in files:
subject_id = ("_").join(file.replace("/","_").replace(".","_").split("_")[-3:-1])
-with open(file) as f:
-    metadata = json.load(f)
+with open(file,"rb") as f:
+    metadata = pickle.load(f)
self.output_df.loc[subject_id, list(metadata.keys())] = list(metadata.values())
self.output_df.to_csv(self.output_df_path)
shutil.rmtree(os.path.join(self.output_directory,".temp"))
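
Two details worth tracing in save_data. First, pickle (unlike json) round-trips the datetime objects that PET.get_metadata now returns (see pet.py below). Second, the subject id is recovered purely from the temp-file path; the [-3:-1] slice implies an id of the form A_B. A sketch with a hypothetical id:

    import datetime
    import pickle

    # json.dump would raise TypeError on datetime values; pickle handles them.
    metadata = {"metadata_PT": [{"scan_time": datetime.datetime(1900, 1, 1, 13, 5)}]}
    with open("temp_HN-CHUM-001_CT.pkl", "wb") as f:  # hypothetical subject id
        pickle.dump(metadata, f)
    with open("temp_HN-CHUM-001_CT.pkl", "rb") as f:
        assert pickle.load(f)["metadata_PT"][0]["scan_time"].hour == 13

    # Subject id recovery, as in save_data:
    file = "/out/.temp/temp_HN-CHUM-001_CT.pkl"
    parts = file.replace("/", "_").replace(".", "_").split("_")
    # ['', 'out', '', 'temp', 'temp', 'HN-CHUM-001', 'CT', 'pkl']
    print("_".join(parts[-3:-1]))  # HN-CHUM-001_CT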
39 changes: 29 additions & 10 deletions imgtools/io/dataset.py
@@ -13,6 +13,7 @@
import SimpleITK as sitk
import warnings
from imgtools.pipeline import Pipeline
+import datetime

class Dataset(tio.SubjectsDataset):
"""
@@ -34,6 +35,7 @@ def load_from_nrrd(
cls,
path:str,
transform: Optional[Callable] = None,
+ignore_multi: bool = True,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
@@ -48,30 +50,43 @@
df_metadata = pd.read_csv(path_metadata,index_col=0)
output_streams = [("_").join(cols.split("_")[1:]) for cols in df_metadata.columns if cols.split("_")[0]=="folder"]
imp_metadata = [cols for cols in df_metadata.columns if cols.split("_")[0] in ("metadata")]
-#Based on the file naming taxonomy
+#Ignores multiple connection to single modality
+if ignore_multi:
+    output_streams = [items for items in output_streams if items.split("_")[-1].isnumeric()==False]
+    imp_metadata = [items for items in imp_metadata if items.split("_")[-1].isnumeric()==False]
+#Based on the file naming convention
file_names = file_name_convention()
subject_id_list = list(df_metadata.index)
subjects = []
for subject_id in tqdm(subject_id_list):
temp = {}
for col in output_streams:
-extension = file_names[col]
mult_conn = col.split("_")[-1].isnumeric()
metadata_name = f"metadata_{col}"
if mult_conn:
extra = col.split("_")[-1]+"_"
+extension = file_names[("_").join(col.split("_")[:-1])]
else:
+extension = file_names[col]
extra = ""
path_mod = os.path.join(path,extension.split(".")[0],f"{subject_id}_{extra}{extension}.nrrd")
#All modalities except RTSTRUCT should be of type torchIO.ScalarImage
if col!="RTSTRUCT":
temp[f"mod_{col}"] = tio.ScalarImage(path_mod)
if os.path.exists(path_mod):
if col!="RTSTRUCT":
temp[f"mod_{col}"] = tio.ScalarImage(path_mod)
else:
temp[f"mod_{col}"] = tio.LabelImage(path_mod)
else:
temp[f"mod_{col}"] = tio.LabelImage(path_mod)
temp[f"mod_{col}"] = None
#For including metadata
if metadata_name in imp_metadata:
#convert string to proper datatype
-temp[metadata_name] = df_metadata.loc[subject_id,metadata_name][0]
+meta = df_metadata.loc[subject_id,metadata_name]
+if pd.notna(meta):
+    temp[metadata_name] = eval(meta)[0]
+else:
+    #torch dataloader doesn't accept None type
+    temp[metadata_name] = {}
subjects.append(tio.Subject(temp))
return cls(subjects,transform,load_getitem)
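
Why eval rather than json.loads: each metadata cell comes out of the CSV as the repr of a one-element list, which may contain datetime objects; that is also the likely reason import datetime was added at the top of this file (an inference, the diff does not say). A minimal sketch:

    import datetime  # lets the datetime.datetime(...) inside the repr resolve
    import pandas as pd

    # A metadata cell as it returns from the CSV round-trip:
    cell = "[{'scan_time': datetime.datetime(1900, 1, 1, 13, 5), 'half_life': 6586.2}]"
    meta = eval(cell)[0] if pd.notna(cell) else {}
    assert meta["half_life"] == 6586.2
    # ast.literal_eval would be safer, but it rejects non-literal nodes
    # such as the datetime constructor call above.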

@@ -83,6 +98,7 @@ def load_directly(
n_jobs: int = -1,
spacing: Tuple = (1., 1., 0.),
transform: Optional[Callable] = None,
+ignore_multi: bool = True,
load_getitem: bool = True
) -> List[tio.Subject]:
"""
@@ -94,6 +110,9 @@
input = ImageAutoInput(path, modalities, n_jobs)
df_metadata = input.df_combined
output_streams = input.output_streams
+#Ignores multiple connection to single modality
+if ignore_multi:
+    output_streams = [items for items in output_streams if items.split("_")[-1].isnumeric()==False]
#Basic operations
subject_id_list = list(df_metadata.index)
# basic image processing ops
@@ -172,11 +191,11 @@ def process_one_subject(

if __name__=="__main__":
from torch.utils.data import DataLoader
-# output_path = "/cluster/projects/radiomics/Temp/vishwesh/HN-CT_RTdose_test2"
-input_path = "/cluster/home/ramanav/imgtools/examples/data_test"
+output_path = "/cluster/projects/radiomics/Temp/vishwesh/HN-ctptdose_test2"
+# input_path = "/cluster/home/ramanav/imgtools/examples/data_test"
transform = tio.Compose([tio.Resize(256)])
-# subjects_dataset = Dataset.load_from_nrrd(output_path,transform=transform)
-subjects_dataset = Dataset.load_directly(input_path,modalities="CT,RTDOSE,PT",n_jobs=4,transform=transform)
+subjects_dataset = Dataset.load_from_nrrd(output_path,transform=transform)
+# subjects_dataset = Dataset.load_directly(input_path,modalities="CT,RTDOSE,PT",n_jobs=4,transform=transform)
print(len(subjects_dataset))
training_loader = DataLoader(subjects_dataset, batch_size=4)
items = next(iter(training_loader))
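
Roughly what the items batch holds (key names and shapes depend on the modalities present and the transform; this assumes a mod_CT stream exists):

    # Each Subject field becomes a key in the collated batch; image fields
    # are dicts whose "data" entry is the stacked tensor.
    batch_ct = items["mod_CT"]["data"]  # (4, 1, 256, 256, 256) after Resize(256), batch_size=4
    print(batch_ct.shape, batch_ct.dtype)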
4 changes: 2 additions & 2 deletions imgtools/modules/pet.py
@@ -76,8 +76,8 @@ def get_metadata(self):
except:
pass
try:
self.metadata["scan_time"] = datetime.datetime.strptime(self.df.AcquisitionTime, '%H%M%S.%f').__str__()
self.metadata["injection_time"] = datetime.datetime.strptime(self.df.RadiopharmaceuticalInformationSequence[0].RadiopharmaceuticalStartTime, '%H%M%S.%f').__str__()
self.metadata["scan_time"] = datetime.datetime.strptime(self.df.AcquisitionTime, '%H%M%S.%f')
self.metadata["injection_time"] = datetime.datetime.strptime(self.df.RadiopharmaceuticalInformationSequence[0].RadiopharmaceuticalStartTime, '%H%M%S.%f')
self.metadata["half_life"] = float(self.df.RadiopharmaceuticalInformationSequence[0].RadionuclideHalfLife)
self.metadata["injected_dose"] = float(self.df.RadiopharmaceuticalInformationSequence[0].RadionuclideTotalDose)
except:
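
Keeping scan_time and injection_time as datetime objects rather than strings makes downstream time arithmetic direct, e.g. the decay factor needed for SUV computation (a sketch, not code from this repo; numbers illustrative for F-18):

    import datetime

    scan_time = datetime.datetime.strptime("130500.00", "%H%M%S.%f")
    injection_time = datetime.datetime.strptime("120000.00", "%H%M%S.%f")
    half_life = 6586.2     # seconds (F-18)
    injected_dose = 3.7e8  # Bq

    elapsed = (scan_time - injection_time).total_seconds()  # works on datetimes, not strings
    decay_factor = 2 ** (-elapsed / half_life)
    print(injected_dose * decay_factor)  # activity remaining at scan time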
4 changes: 2 additions & 2 deletions imgtools/ops/ops.py
@@ -276,8 +276,8 @@ def __init__(self,

self.output = {}
for colname in output_streams:
-# Not considering colnames ending with _1
-colname_process = ("_").join([items for items in colname.split("_") if items!="1"])
+# Not considering colnames ending with a number
+colname_process = ("_").join([item for item in colname.split("_") if item.isnumeric()==False])
extension = self.file_name[colname_process]
self.output[colname_process] = ImageFileOutput(os.path.join(root_directory,extension.split(".")[0]),
filename_format="{subject_id}_"+"{}.nrrd".format(extension))
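
The ops.py change mirrors the autopipeline fix: every numeric-suffixed column now collapses onto its base stream, so multiple connections share a single ImageFileOutput. Tracing the path logic with a hypothetical file_name entry:

    import os

    file_name = {"CT_RTDOSE": "dose"}  # hypothetical self.file_name entry
    root_directory = "/data/out"

    for colname in ["CT_RTDOSE", "CT_RTDOSE_1", "CT_RTDOSE_2"]:
        colname_process = "_".join(t for t in colname.split("_") if not t.isnumeric())
        extension = file_name[colname_process]
        folder = os.path.join(root_directory, extension.split(".")[0])
        print(colname, "->", folder)  # all three land in /data/out/dose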
