Skip to content

Commit

Permalink
yaml thing
Browse files Browse the repository at this point in the history
  • Loading branch information
fishingguy456 committed Jun 16, 2022
1 parent effac09 commit 23b76e4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
20 changes: 14 additions & 6 deletions examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings

from argparse import ArgumentParser
import yaml
import SimpleITK as sitk

from imgtools.ops import StructureSetToSegmentation, ImageAutoInput, ImageAutoOutput, Resample
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self,
is_nnunet=False,
train_size=1.0,
random_state=42,
label_names=None,
read_yaml_label_names=False,
ignore_missing_regex=False):
"""Initialize the pipeline.
Expand Down Expand Up @@ -77,8 +78,8 @@ def __init__(self,
Proportion of the dataset to use for training, as a decimal
random_state: int, default=42
Random state for train_test_split
label_names: dict of str:str, default=None
Dictionary representing the label that regexes are mapped to. For example, "GTV": "GTV.*" will combine all regexes that match "GTV.*" into "GTV"
read_yaml_label_names: bool, default=False
Whether to read dictionary representing the label that regexes are mapped to from YAML. For example, "GTV": "GTV.*" will combine all regexes that match "GTV.*" into "GTV"
ignore_missing_regex: bool, default=False
Whether to ignore missing regexes. Will raise an error if none of the regexes in label_names are found for a patient.
"""
Expand All @@ -100,9 +101,15 @@ def __init__(self,
self.nnunet_info = None
self.train_size = train_size
self.random_state = random_state
self.label_names = label_names
self.label_names = {}
self.ignore_missing_regex = ignore_missing_regex

with open(pathlib.Path(self.input_directory, "roi_names.yaml").as_posix(), "r") as f:
try:
self.label_names = yaml.safe_load(f)
except yaml.YAMLError as exc:
print(exc)

if self.train_size == 1.0:
warnings.warn("Train size is 1, all data will be used for training")

Expand All @@ -115,8 +122,8 @@ def __init__(self,
if self.train_size > 1 or self.train_size < 0 and self.is_nnunet:
raise ValueError("train_size must be between 0 and 1")

if is_nnunet and (label_names is None or label_names == {}):
raise ValueError("label_names must be provided for nnunet")
if is_nnunet and (not read_yaml_label_names or self.label_names == {}):
raise ValueError("YAML label names must be provided for nnunet")


if self.is_nnunet:
Expand Down Expand Up @@ -402,6 +409,7 @@ def run(self):
# is_nnunet=True,
train_size=0.5,
# label_names={"GTV":"GTV.*", "Brainstem": "Brainstem.*"},
read_yaml_label_names=True, # "GTV.*",
ignore_missing_regex=True)

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/dataset/manifest-1598890146597/NSCLC-Radiomics-Interobserver1",
Expand Down
2 changes: 1 addition & 1 deletion imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def to_segmentation(self, reference_image: sitk.Image,
labels = self._assign_labels(roi_names, force_missing)
print("labels:", labels)
if not labels:
if ignore_missing_regex:
if not ignore_missing_regex:
raise ValueError(f"No ROIs matching {roi_names} found in {self.roi_names}.")
else:
return None
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ tqdm
torch
torchio
scikit-learn
pyyaml

0 comments on commit 23b76e4

Please sign in to comment.