Skip to content

Commit

Permalink
🎨🐛🚧 Re-designed the reduce part
Browse files Browse the repository at this point in the history
Also fixed bugs that occurred when `head` returns a `dict`
  • Loading branch information
carefree0910 committed Mar 22, 2021
1 parent 00bd2c4 commit b99d4a2
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 16 deletions.
37 changes: 36 additions & 1 deletion cflearn/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,42 @@ def forward(
x_batch = batch["x_batch"]
split = self._split_features(x_batch, batch_indices, loader_name)
outputs = self.execute(split)
return self.aggregator.reduce(outputs, **kwargs)
# check whether outputs from each pipe are of identical type
return_type = None
for pipe_outputs in outputs.values():
pipe_outputs_type = type(pipe_outputs)
if return_type is None:
return_type = pipe_outputs_type
elif return_type is not pipe_outputs_type:
raise ValueError(
f"some pipe(s) return `{return_type}` but "
f"other(s) return `{pipe_outputs_type}`"
)
# if return_type is Tensor, simply reduce them
if return_type is torch.Tensor:
return {"predictions": self.aggregator.reduce(outputs, **kwargs)}
# otherwise, return_type should be dict, and all pipes should hold the same keys
assert return_type is dict
key_set = None
for pipe_outputs in outputs.values():
pipe_outputs_key_set = set(pipe_outputs)
if key_set is None:
key_set = pipe_outputs_key_set
elif key_set != pipe_outputs_key_set:
raise ValueError(
f"some pipe(s) return `{key_set}` but "
f"other(s) return `{pipe_outputs_key_set}`"
)
return {
k: self.aggregator.reduce(
{
pipe_key: pipe_outputs[k]
for pipe_key, pipe_outputs in outputs.items()
},
**kwargs,
)
for k in key_set
}

def loss_function(
self,
Expand Down
5 changes: 3 additions & 2 deletions cflearn/modules/aggregators/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from abc import ABCMeta
from torch import Tensor
from typing import Any
from typing import Dict
from typing import Type
Expand All @@ -17,8 +18,8 @@ def __init__(self, **kwargs: Any):
self.config = kwargs

@abstractmethod
def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> Tensor:
    """Reduce the per-pipe `outputs` dict into a single tensor.

    NOTE(review): the fused diff had kept both the pre-commit signature
    (`-> tensor_dict_type`, with the "requires returning the `predictions`
    key" docstring) and the post-commit one; only the post-commit
    `-> Tensor` contract is retained here.
    """

@classmethod
def register(cls, name: str) -> Callable[[Type], Type]:
Expand Down
14 changes: 3 additions & 11 deletions cflearn/modules/aggregators/sum.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from torch import Tensor
from typing import Any

from .base import AggregatorBase
Expand All @@ -6,17 +7,8 @@

@AggregatorBase.register("sum")
class Sum(AggregatorBase):
    """Aggregator that reduces pipe outputs by element-wise summation."""

    def reduce(self, outputs: tensor_dict_type, **kwargs: Any) -> Tensor:
        # Skip `None` placeholders, as the pre-refactor implementation did;
        # a bare `sum(outputs.values())` would raise `TypeError` on
        # `None + Tensor` if any pipe produced no output.
        return sum(v for v in outputs.values() if v is not None)


__all__ = ["Sum"]
4 changes: 2 additions & 2 deletions tests/unittests/test_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,8 @@ def reduce(
self,
outputs: tensor_dict_type,
**kwargs: Any,
) -> tensor_dict_type:
return {"predictions": outputs["linear"] * outputs["linear2"]}
) -> torch.Tensor:
return outputs["linear"] * outputs["linear2"]

cflearn.register_model(
"prod",
Expand Down

0 comments on commit b99d4a2

Please sign in to comment.