diff --git a/minerva/models/nets/setr.py b/minerva/models/nets/setr.py index 6e22aa8..7a91827 100644 --- a/minerva/models/nets/setr.py +++ b/minerva/models/nets/setr.py @@ -457,10 +457,18 @@ def __init__( The interpolation mode for upsampling in the decoder. Defaults to "bilinear". loss_fn : nn.Module, optional The loss function to be used during training. Defaults to None. - log_metrics : bool - Whether to log metrics during training. Defaults to True. - metrics : list[MetricTypeSetR], optional - The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU]. + train_metrics : Dict[str, Metric], optional + The metrics to be used for training evaluation. Defaults to None. + val_metrics : Dict[str, Metric], optional + The metrics to be used for validation evaluation. Defaults to None. + test_metrics : Dict[str, Metric], optional + The metrics to be used for testing evaluation. Defaults to None. + aux_output : bool + Whether to include auxiliary output heads in the model. Defaults to True. + aux_output_layers : list[int] | None + The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19]. + aux_weights : list[float] + The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3]. """ super().__init__() diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py index af39eba..0fbab6e 100644 --- a/tests/models/nets/test_setr.py +++ b/tests/models/nets/test_setr.py @@ -27,7 +27,7 @@ def test_setr_predict(): preds = model.predict_step((x, mask), 0) assert preds is not None assert ( - preds[0].shape == mask_shape + preds.shape == mask_shape ), f"Expected shape {mask_shape}, but got {preds[0].shape}"