diff --git a/cflearn/models/base.py b/cflearn/models/base.py index 08fbb42cb..13d699a09 100644 --- a/cflearn/models/base.py +++ b/cflearn/models/base.py @@ -415,7 +415,42 @@ def forward( x_batch = batch["x_batch"] split = self._split_features(x_batch, batch_indices, loader_name) outputs = self.execute(split) - return self.aggregator.reduce(outputs, **kwargs) + # check whether outputs from each pipe are of identical type + return_type = None + for pipe_outputs in outputs.values(): + pipe_outputs_type = type(pipe_outputs) + if return_type is None: + return_type = pipe_outputs_type + elif return_type is not pipe_outputs_type: + raise ValueError( + f"some pipe(s) return `{return_type}` but " + f"other(s) return `{pipe_outputs_type}`" + ) + # if return_type is Tensor, simply reduce them + if return_type is torch.Tensor: + return {"predictions": self.aggregator.reduce(outputs, **kwargs)} + # otherwise, return_type should be dict, and all pipes should hold the same keys + assert return_type is dict + key_set = None + for pipe_outputs in outputs.values(): + pipe_outputs_key_set = set(pipe_outputs) + if key_set is None: + key_set = pipe_outputs_key_set + elif key_set != pipe_outputs_key_set: + raise ValueError( + f"some pipe(s) return `{key_set}` but " + f"other(s) return `{pipe_outputs_key_set}`" + ) + return { + k: self.aggregator.reduce( + { + pipe_key: pipe_outputs[k] + for pipe_key, pipe_outputs in outputs.items() + }, + **kwargs, + ) + for k in key_set + } def loss_function( self, diff --git a/cflearn/modules/aggregators/base.py b/cflearn/modules/aggregators/base.py index 321ad5663..b23fc1593 100644 --- a/cflearn/modules/aggregators/base.py +++ b/cflearn/modules/aggregators/base.py @@ -1,5 +1,6 @@ from abc import abstractmethod from abc import ABCMeta +from torch import Tensor from typing import Any from typing import Dict from typing import Type @@ -17,8 +18,8 @@ def __init__(self, **kwargs: Any): self.config = kwargs @abstractmethod - def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> tensor_dict_type: - """ requires returning the `predictions` key """ + def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> Tensor: + pass @classmethod def register(cls, name: str) -> Callable[[Type], Type]: diff --git a/cflearn/modules/aggregators/sum.py b/cflearn/modules/aggregators/sum.py index 46af393df..18b888ecf 100644 --- a/cflearn/modules/aggregators/sum.py +++ b/cflearn/modules/aggregators/sum.py @@ -1,3 +1,4 @@ +from torch import Tensor from typing import Any from .base import AggregatorBase @@ -6,17 +7,8 @@ @AggregatorBase.register("sum") class Sum(AggregatorBase): - def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> tensor_dict_type: - values = list(outputs.values()) - output = None - for value in values: - if value is None: - continue - if output is None: - output = value - else: - output = output + value - return {"predictions": output} + def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> Tensor: + return sum(outputs.values()) __all__ = ["Sum"] diff --git a/tests/unittests/test_doc.py b/tests/unittests/test_doc.py index 05701e2da..bc343fe34 100644 --- a/tests/unittests/test_doc.py +++ b/tests/unittests/test_doc.py @@ -420,8 +420,8 @@ def reduce( self, outputs: tensor_dict_type, **kwargs: Any, - ) -> tensor_dict_type: - return {"predictions": outputs["linear"] * outputs["linear2"]} + ) -> torch.Tensor: + return outputs["linear"] * outputs["linear2"] cflearn.register_model( "prod",