Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielBG0 committed May 14, 2024
1 parent b0b3b2b commit 9424c0b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions minerva/models/nets/setr.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,18 @@ def __init__(
The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
loss_fn : nn.Module, optional
The loss function to be used during training. Defaults to None.
log_metrics : bool
Whether to log metrics during training. Defaults to True.
metrics : list[MetricTypeSetR], optional
The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU].
train_metrics : Dict[str, Metric], optional
The metrics to be used for training evaluation. Defaults to None.
val_metrics : Dict[str, Metric], optional
The metrics to be used for validation evaluation. Defaults to None.
test_metrics : Dict[str, Metric], optional
The metrics to be used for testing evaluation. Defaults to None.
aux_output : bool
Whether to include auxiliary output heads in the model. Defaults to True.
aux_output_layers : list[int] | None
The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19].
aux_weights : list[float]
The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].
"""
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/nets/test_setr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_setr_predict():
preds = model.predict_step((x, mask), 0)
assert preds is not None
assert (
preds[0].shape == mask_shape
preds.shape == mask_shape
), f"Expected shape {mask_shape}, but got {preds[0].shape}"


Expand Down

0 comments on commit 9424c0b

Please sign in to comment.