Error in prediction for Segmentation configs #379

Open
KartikeyKansal1 opened this issue Apr 23, 2024 · 1 comment

@KartikeyKansal1

Hi, I'm running python cyto_dl/eval.py experiment=im2im/segmentation.yaml ckpt_path='xyz' to run prediction on saved checkpoints from segmentation training, but I'm getting the following error. This command usually works for labelfree experiments, where it runs prediction on the complete data set.

[2024-04-23 07:22:36,432][cyto_dl.utils.template_utils][INFO] - Closing loggers...
Error executing job with overrides: ['experiment=im2im/segmentation.yaml', 'trainer=cpu', 'experiment_name=240422_exp2_actinseg_batch_100', 'run_name=predict_run_1', 'data.batch_size=100', 'ckpt_path=/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/logs/train/runs/240422_exp2_actinseg_batch_100/train_run_1/2024-04-22_20-05-34/checkpoints/last.ckpt']
Traceback (most recent call last):
  File "/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/cyto_dl_actinseg_exp2_staging_240422/cyto_dl/eval.py", line 99, in main
    evaluate(cfg)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/cyto_dl/utils/template_utils.py", line 56, in wrap
    raise ex
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/cyto_dl/utils/template_utils.py", line 53, in wrap
    out = task_func(cfg=cfg)
  File "/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/cyto_dl_actinseg_exp2_staging_240422/cyto_dl/eval.py", line 87, in evaluate
    output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 852, in predict
    return call._call_and_handle_interrupt(
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 894, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 946, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 400, in _restore_modules_and_callbacks
    self.restore_model()
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 280, in restore_model
    trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 364, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"])
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MultiTaskIm2Im:
	Missing key(s) in state_dict: "task_heads.seg.loss.dice.class_weight". 
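
The missing key points at the loss module itself: a loss that registers class_weight as a buffer contributes that buffer to the LightningModule's state_dict, so a checkpoint saved under a different loss configuration cannot be loaded strictly. A minimal sketch of the mechanism (generic PyTorch, not cyto-dl code):

import torch
import torch.nn as nn

# Minimal sketch, not cyto-dl code: a loss that registers `class_weight`
# as a buffer, the way a class-weighted dice loss can.
class WeightedLoss(nn.Module):
    def __init__(self, class_weight):
        super().__init__()
        self.register_buffer("class_weight", torch.as_tensor(class_weight))

    def forward(self, pred, target):
        return (self.class_weight * (pred - target) ** 2).mean()

class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = WeightedLoss([1.0, 2.0])

head = Head()
print(list(head.state_dict().keys()))  # ['loss.class_weight']

# Loading a state_dict that was saved without the buffer reproduces the error:
try:
    head.load_state_dict({})
except RuntimeError as err:
    print(err)  # Missing key(s) in state_dict: "loss.class_weight"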
@benjijamorris
Contributor

Good find! I'll work on a bug fix. In the meantime, if you replace the loss function in your model config with something that doesn't require parameters (like torch.nn.MSELoss), prediction should work.
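
As a quick sanity check on that workaround (the exact config path for the loss depends on your experiment yaml), torch.nn.MSELoss holds no parameters or buffers, so it adds no loss-related keys for load_state_dict to miss:

import torch.nn as nn

# torch.nn.MSELoss carries no parameters or buffers, so swapping it in for the
# weighted dice loss leaves no extra "task_heads.seg.loss.*" keys to restore.
print(list(nn.MSELoss().state_dict().keys()))  # []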
