Skip to content

Commit

Permalink
changes for roi names as a dict
Browse files Browse the repository at this point in the history
Former-commit-id: 0479cb8
  • Loading branch information
fishingguy456 committed Jun 14, 2022
1 parent 318cc12 commit 3fe34ed
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 49 deletions.
15 changes: 11 additions & 4 deletions examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(self,
overwrite=False,
is_nnunet=False,
train_size=1.0,
random_state=42):
random_state=42,
label_names=None):
"""Initialize the pipeline.
Parameters
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(self,
self.nnunet_info = None
self.train_size = train_size
self.random_state = random_state
self.label_names = label_names

if self.train_size == 1.0:
warnings.warn("Train size is 1, all data will be used for training")
Expand All @@ -106,6 +108,10 @@ def __init__(self,

if self.train_size > 1 or self.train_size < 0 and self.is_nnunet:
raise ValueError("train_size must be between 0 and 1")

if is_nnunet and (label_names is None or label_names == {}):
raise ValueError("label_names must be provided for nnunet")


if self.is_nnunet:
self.nnunet_info["modalities"] = {"CT": "0000"} #modality to 4-digit code
Expand All @@ -121,7 +127,7 @@ def __init__(self,

# image processing ops
self.resample = Resample(spacing=self.spacing)
self.make_binary_mask = StructureSetToSegmentation(roi_names=[], continuous=False) # "GTV-.*"
self.make_binary_mask = StructureSetToSegmentation(roi_names=self.label_names, continuous=False) # "GTV-.*"

# output ops
self.output = ImageAutoOutput(self.output_directory, self.output_streams, self.nnunet_info)
Expand Down Expand Up @@ -249,7 +255,7 @@ def process_one_subject(self, subject_id):
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

if self.is_nnunet:
sparse_mask = mask.generate_sparse_mask().mask_array
sparse_mask = mask.generate_sparse_mask(self.label_names).mask_array
sparse_mask = sitk.GetImageFromArray(sparse_mask) #convert the nparray to sitk image
if subject_id in self.train:
self.output(subject_id, sparse_mask, output_stream, nnunet_info=self.nnunet_info, label_or_image="labels") #rtstruct is label for nnunet
Expand Down Expand Up @@ -373,7 +379,8 @@ def run(self):
visualize=False,
overwrite=True,
is_nnunet=True,
train_size=0.5)
train_size=0.5,
label_names={})

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/dataset/manifest-1598890146597/NSCLC-Radiomics-Interobserver1",
# output_directory="C:/Users/qukev/BHKLAB/autopipelineoutput",
Expand Down
2 changes: 1 addition & 1 deletion imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __getitem__(self, idx):
def __repr__(self):
return f"<Segmentation with ROIs: {self.roi_names!r}>"

def generate_sparse_mask(self, verbose=False) -> SparseMask:
def generate_sparse_mask(self, label_names, verbose=False) -> SparseMask:
"""
Generate a sparse mask from the contours, taking the argmax of all overlaps
Expand Down
49 changes: 6 additions & 43 deletions imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,45 +38,6 @@ def from_dicom_rtstruct(cls, rtstruct_path: str) -> 'StructureSet':
warn(f"Could not get points for ROI {name} (in {rtstruct_path}).")

metadata = {}
# if hasattr(rtstruct, 'StructureSetROISequence'):
# metadata["numROIs"] = str(len(rtstruct.StructureSetROISequence))

# if hasattr(rtstruct, 'BodyPartExamined'):
# metadata["BodyPartExamined"] = str(rtstruct.BodyPartExamined)
# if hasattr(rtstruct, 'DataCollectionDiameter'):
# metadata["DataCollectionDiameter"] = str(rtstruct.DataCollectionDiameter)
# # Number of Slices is avg. number slice?
# if hasattr(rtstruct, 'NumberofSlices'):
# metadata["NumberofSlices"] = str(rtstruct.NumberofSlices)
# # Slice Thickness is avg. slice thickness?
# if hasattr(rtstruct, 'SliceThickness'):
# metadata["SliceThickness"] = str(rtstruct.SliceThickness)
# if hasattr(rtstruct, 'ScanType'):
# metadata["ScanType"] = str(rtstruct.ScanType)
# # Scan Progression Direction is Scan Direction?
# if hasattr(rtstruct, 'ScanProgressionDirection'):
# metadata["ScanProgressionDirection"] = str(rtstruct.ScanProgressionDirection)
# if hasattr(rtstruct, 'PatientPosition'):
# metadata["PatientPosition"] = str(rtstruct.PatientPosition)
# # is this contrast type?
# if hasattr(rtstruct, 'ContrastBolusAgent'):
# metadata["ContrastType"] = str(rtstruct.ContrastBolusAgent)
# if hasattr(rtstruct, 'Manufacturer'):
# metadata["Manufacturer"] = str(rtstruct.Manufacturer)
# # Which field of view?
# # if hasattr(rtstruct, 'FieldOfViewDescription'):
# # metadata["FieldOfViewDescription"] = str(rtstruct.FieldOfViewDescription)
# # Scan Plane?
# if hasattr(rtstruct, 'ScanOptions'):
# metadata["ScanOptions"] = str(rtstruct.ScanOptions)
# if hasattr(rtstruct, 'RescaleType'):
# metadata["RescaleType"] = str(rtstruct.RescaleType)
# if hasattr(rtstruct, 'RescaleSlope'):
# metadata["RescaleSlope"] = str(rtstruct.RescaleSlope)
# if hasattr(rtstruct, 'PixelSpacing') and hasattr(rtstruct, 'SliceThickness'):
# pixel_size = copy.copy(rtstruct.PixelSpacing)
# pixel_size.append(rtstruct.SliceThickness)
# metadata["PixelSize"] = str(tuple(pixel_size))

return cls(roi_points, metadata)
# return cls(roi_points)
Expand Down Expand Up @@ -125,7 +86,7 @@ def _assign_labels(self, names, force_missing=False):
return labels

def to_segmentation(self, reference_image: sitk.Image,
roi_names: Optional[List[Union[str, List[str]]]] = None,
roi_names: Dict[str: str] = None,
force_missing: bool = False,
continuous: bool = True) -> Segmentation:
"""Convert the structure set to a Segmentation object.
Expand Down Expand Up @@ -165,12 +126,14 @@ def to_segmentation(self, reference_image: sitk.Image,
guaranteed (unless all patterns in `roi_names` can only match
a single name or are lists of strings).
"""
if not roi_names:
if not roi_names or roi_names == {}:
roi_names = self.roi_names
if isinstance(roi_names, str):
roi_names = [roi_names]

labels = self._assign_labels(roi_names, force_missing)
if isinstance(roi_names, list):
labels = self._assign_labels(roi_names, force_missing)
else:
labels = self._assign_labels(list(roi_names.values()), force_missing)
print("labels:", labels)
if not labels:
raise ValueError(f"No ROIs matching {roi_names} found in {self.roi_names}.")
Expand Down
2 changes: 1 addition & 1 deletion imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ class StructureSetToSegmentation(BaseOp):
"""

def __init__(self,
roi_names: Union[str,List[str]],
roi_names: Dict[str: str],
force_missing: bool = False,
continuous: bool = True):
"""Initialize the op.
Expand Down
10 changes: 10 additions & 0 deletions imgtools/utils/dicomutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def all_modalities_metadata(dicom_data: Union[pydicom.dataset.FileDataset, pydic
pixel_size = copy.copy(dicom_data.PixelSpacing)
pixel_size.append(dicom_data.SliceThickness)
metadata["PixelSize"] = str(tuple(pixel_size))
if hasattr(dicom_data, 'ManufacturerModelName'):
metadata["ManufacturerModelName"] = str(dicom_data.ManufacturerModelName)
return metadata


Expand All @@ -59,6 +61,12 @@ def ct_metadata(dicom_data: Union[pydicom.dataset.FileDataset, pydicom.dicomdir.
# is this contrast type?
if hasattr(dicom_data, 'ContrastBolusAgent'):
metadata["ContrastType"] = str(dicom_data.ContrastBolusAgent)
if hasattr(dicom_data, 'ReconstructionMethod'):
metadata["ReconstructionMethod"] = str(dicom_data.ReconstructionMethod)
if hasattr(dicom_data, 'ReconstructionDiameter'):
metadata["ReconstructionDiameter"] = str(dicom_data.ReconstructionDiameter)
if hasattr(dicom_data, 'ConvolutionKernel'):
metadata["ConvolutionKernel"] = str(dicom_data.ConvolutionKernel)
return metadata


Expand All @@ -79,6 +87,8 @@ def mr_metadata(dicom_data: Union[pydicom.dataset.FileDataset, pydicom.dicomdir.
metadata["ImagingFrequency"] = str(dicom_data.ImagingFrequency)
if hasattr(dicom_data, 'MagneticFieldStrength'):
metadata["MagneticFieldStrength"] = str(dicom_data.MagneticFieldStrength)
if hasattr(dicom_data, 'SequenceName'):
metadata["SequenceName"] = str(dicom_data.SequenceName)
return metadata


Expand Down

0 comments on commit 3fe34ed

Please sign in to comment.