From a6fe28f8c626f5a1ece02742829255a37a1c3297 Mon Sep 17 00:00:00 2001 From: Bryn Lloyd Date: Thu, 7 Nov 2024 16:12:44 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20remove=20model=20parameters=20in?= =?UTF-8?q?=20predict=20(#65)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove model parameters in predict --------- Co-authored-by: Bryn Lloyd <12702862+dyollb@users.noreply.github.com> --- src/segmantic/seg/monai_unet.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 3a1049d..0cdcf19 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -554,9 +554,6 @@ def predict( test_labels: Optional[list[Path]] = None, output_dir: Path = None, tissue_dict: dict[str, int] = None, - channels: tuple[int, ...] = (16, 32, 64, 128, 256), - strides: tuple[int, ...] = (2, 2, 2, 2), - dropout: float = 0.0, spacing: Sequence[float] = [], gpu_ids: list[int] = [], ) -> None: @@ -568,9 +565,7 @@ def predict( settings = json.load(json_file) net: Net = Net.load_from_checkpoint(f"{model_file}", **settings) else: - net = Net.load_from_checkpoint( - f"{model_file}", channels=channels, strides=strides, dropout=dropout - ) + net = Net.load_from_checkpoint(f"{model_file}") num_classes = net.num_classes net.freeze() @@ -765,7 +760,7 @@ def cross_validate( for config_file in Path(config_files_dir).iterdir(): assert config_file.suffix in [".json", ".yml"], f"suffix: {config_file}" - is_json = config_file and config_file.suffix.lower() == ".json" + is_json = config_file.suffix.lower() == ".json" dumps = partial(config.dumps, is_json=is_json) loads = partial(config.loads, is_json=is_json) @@ -823,9 +818,6 @@ def cross_validate( test_images=test_images, test_labels=test_labels, tissue_dict=tissue_dict, - # channels=current_layers, - # strides=current_strides, - dropout=0.0, spacing=[1, 1, 1], gpu_ids=gpu_ids, )