Merge branch 'master' into feat-logical_slice_assign_support_full_slice
wyg1997 authored Jun 9, 2022
2 parents 4570728 + c10a30c, commit 1188cc0
Showing 4 changed files with 47 additions and 19 deletions.
oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp (17 additions, 8 deletions)
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/common/container_util.h"
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
@@ -269,19 +270,27 @@ class BroadcastMinMax : public BroadcastBinaryGrad
       const auto& x_shape = *(x->shape());
       const Shape& left_extended_x_shape =
           CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
-      const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
-      const std::vector<int32_t> x_axis =
-          std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
-      broad_x_ = JUST(functional::BroadcastLike(x, out_grads.at(0), x_axis));
+      if (left_extended_x_shape == out_shape) {
+        broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
+      } else {
+        const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> x_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
+      }
     }
     if (ctx->broadcast_y) {
       const auto& y_shape = *(y->shape());
       const Shape& left_extended_y_shape =
           CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
-      const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
-      const std::vector<int32_t> y_axis =
-          std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
-      broad_y_ = JUST(functional::BroadcastLike(y, out_grads.at(0), y_axis));
+      if (left_extended_y_shape == out_shape) {
+        broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
+      } else {
+        const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> y_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
+      }
     }
     const auto& broad_grads =
         JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
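
Note on the BroadcastMinMax change: when an input's left-extended shape already equals the output shape, no axis is actually broadcast, so Axes4BroadcastTo would yield an empty axis list and BroadcastLike is unnecessary; reshaping the incoming gradient suffices. A minimal eager-mode sketch of the rank-mismatch case this branch targets (assuming the user-facing oneflow.maximum API routes through this gradient function):

import oneflow as flow

# Ranks differ, but left-extending y's shape to (1, 1, 4) already matches
# the output shape, so no axis is truly broadcast.
x = flow.rand(1, 1, 4, requires_grad=True)
y = flow.rand(1, 4, requires_grad=True)

z = flow.maximum(x, y)   # output shape: (1, 1, 4)
z.sum().backward()

print(x.grad.shape)      # oneflow.Size([1, 1, 4])
print(y.grad.shape)      # oneflow.Size([1, 4])
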
oneflow/core/functional/impl/array_functor.cpp (21 additions, 10 deletions)
@@ -427,25 +427,36 @@ class BroadcastLikeFunctor
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& like,
                            const std::vector<int32_t>& broadcast_axes) const {
+    const Shape& x_shape = *x->shape();
+    const Shape& like_shape = *like->shape();
+    if (x_shape == like_shape) { return x; }
     MutableAttrMap attrs;
     if (broadcast_axes.empty()) {
-      int64_t like_ndim = like->shape()->NumAxes();
-      int64_t x_ndim = x->shape()->NumAxes();
+      int64_t like_ndim = like_shape.NumAxes();
+      int64_t x_ndim = x_shape.NumAxes();
       int64_t num_prepend = like_ndim - x_ndim;
       std::vector<int64_t> prepend_shape(num_prepend, 1);
-      std::vector<int64_t> broadcast_axes;
-      for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x->shape()->At(i)); }
+      std::vector<int32_t> broadcast_axes;
+      for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x_shape.At(i)); }
       for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); }
       for (int i = num_prepend; i < prepend_shape.size(); ++i) {
-        if (prepend_shape[i] != like->shape()->At(i)) {
-          if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); }
-          CHECK_GE_OR_RETURN(prepend_shape[i], 1)
-              << Error::RuntimeError() << "output with shape " << x->shape()->ToString()
-              << " doesn't match the broadcast shape " << like->shape()->ToString();
+        if (prepend_shape[i] != like_shape.At(i)) {
+          if (prepend_shape[i] == 1) {
+            broadcast_axes.emplace_back(i);
+          } else {
+            return Error::RuntimeError() << "The expanded size of the tensor "
+                                         << "(" << like_shape.At(i) << ")"
+                                         << " must match the existing size (" << prepend_shape[i]
+                                         << ") at non-singleton dimension " << i
+                                         << ". Target sizes: " << like_shape.ToString()
+                                         << ". Tensor sizes: " << x_shape.ToString();
+          }
         }
       }
+      JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
+    } else {
+      JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
     }
-    JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
     return OpInterpUtil::Dispatch<Tensor>(*op_, {x, JUST(like->detach())}, attrs);
   }

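
The empty-broadcast_axes branch above infers the axes with NumPy-style broadcasting rules: left-extend x's shape with ones to like's rank, mark every prepended axis and every singleton axis that differs from like, and reject any non-singleton mismatch. A rough Python restatement of the same rule (infer_broadcast_axes is a hypothetical helper written for illustration, not part of the codebase):

def infer_broadcast_axes(x_shape, like_shape):
    # Left-extend x_shape with ones to like's rank.
    num_prepend = len(like_shape) - len(x_shape)
    prepend_shape = [1] * num_prepend + list(x_shape)
    axes = list(range(num_prepend))  # every prepended axis is broadcast
    for i in range(num_prepend, len(prepend_shape)):
        if prepend_shape[i] != like_shape[i]:
            if prepend_shape[i] == 1:
                axes.append(i)  # singleton axis expands to like's size
            else:
                raise RuntimeError(
                    f"The expanded size of the tensor ({like_shape[i]}) must match "
                    f"the existing size ({prepend_shape[i]}) at non-singleton dimension {i}"
                )
    return axes

assert infer_broadcast_axes((1, 4), (2, 3, 4)) == [0, 1]
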
python/oneflow/test/exceptions/test_array_functor.py (1 addition, 1 deletion)
@@ -31,7 +31,7 @@ def test_broadcast_like_runtime_error(test_case):
             like = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True)
             y = flow.broadcast_like(x, like)
         test_case.assertTrue(
-            "doesn't match the broadcast shape" in str(context.exception)
+            "The expanded size of the tensor" in str(context.exception)
         )
 
     def test_concat_index_error(test_case):
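
The assertion now matches the PyTorch-style message emitted by the new else branch in BroadcastLikeFunctor. For illustration, a standalone repro of that error path (the shapes here are chosen for the example; the test's own setup of x is elided in the hunk):

import oneflow as flow

x = flow.ones(2, 3)
like = flow.ones(2, 2, 2)
try:
    flow.broadcast_like(x, like)
except RuntimeError as e:
    # e.g. "The expanded size of the tensor (2) must match the existing
    # size (3) at non-singleton dimension 2. ..."
    print(e)
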
python/oneflow/test/modules/test_max.py (8 additions, 0 deletions)
@@ -98,6 +98,14 @@ def test_max_broadcast_dtype_promotion(test_case):
         y = random_tensor(ndim, *b_dims, dtype=int).to(device)
         return torch.max(x, y)
 
+    @autotest(n=3, auto_backward=True, check_graph=True)
+    def test_max_with_diff_size(test_case):
+        x = flow.rand(1, 1, 4, requires_grad=True)
+        y = flow.rand(1, 4, requires_grad=True)
+        x = random_tensor(3, 1, 1, 4)
+        y = random_tensor(2, 1, 4)
+        return torch.max(x, y)
+
 
 if __name__ == "__main__":
     unittest.main()
