diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 3a1049d..0cdcf19 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -554,9 +554,6 @@ def predict( test_labels: Optional[list[Path]] = None, output_dir: Path = None, tissue_dict: dict[str, int] = None, - channels: tuple[int, ...] = (16, 32, 64, 128, 256), - strides: tuple[int, ...] = (2, 2, 2, 2), - dropout: float = 0.0, spacing: Sequence[float] = [], gpu_ids: list[int] = [], ) -> None: @@ -568,9 +565,7 @@ def predict( settings = json.load(json_file) net: Net = Net.load_from_checkpoint(f"{model_file}", **settings) else: - net = Net.load_from_checkpoint( - f"{model_file}", channels=channels, strides=strides, dropout=dropout - ) + net = Net.load_from_checkpoint(f"{model_file}") num_classes = net.num_classes net.freeze() @@ -765,7 +760,7 @@ def cross_validate( for config_file in Path(config_files_dir).iterdir(): assert config_file.suffix in [".json", ".yml"], f"suffix: {config_file}" - is_json = config_file and config_file.suffix.lower() == ".json" + is_json = config_file.suffix.lower() == ".json" dumps = partial(config.dumps, is_json=is_json) loads = partial(config.loads, is_json=is_json) @@ -823,9 +818,6 @@ def cross_validate( test_images=test_images, test_labels=test_labels, tissue_dict=tissue_dict, - # channels=current_layers, - # strides=current_strides, - dropout=0.0, spacing=[1, 1, 1], gpu_ids=gpu_ids, )