Skip to content

Commit

Permalink
fix global inplace (#7903)
Browse files Browse the repository at this point in the history
* fix global inplace

* Update test_consistent_linear.py

* Update test_consistent_linear.py
  • Loading branch information
hjchen2 authored Apr 6, 2022
1 parent d94bc0d commit c688b54
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
JUST(SrcOpConsistentTensorMetaInferArgs::New(ctx.attrs, parallel_desc, JUST(ctx.nd_sbp)));
result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
} else {
for (int i = 0; i < outputs->size(); ++i) {
if ((*outputs)[i]) {
const auto& nd_sbp = JUST((*outputs)[i]->nd_sbp());
JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(nd_sbp));
}
}
const auto& infer_args = JUST(ConsistentTensorMetaInferArgs::New(ctx.attrs, inputs));
result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
}
Expand All @@ -133,7 +139,9 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
if (!outputs->at(i)) {
const auto& tensor_impl = JUST(EagerConsistentTensorImpl::New(
output_tensor_metas.at(i), tensor_device, parallel_id, false, false));
outputs->at(i).reset(new ConsistentTensor(tensor_impl));
(*outputs)[i].reset(new ConsistentTensor(tensor_impl));
} else {
JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(NullOpt));
}
}
// Do nothing if output_tensors has 0-size shape. Since the input of some ops is 0-size but the
Expand Down
10 changes: 6 additions & 4 deletions oneflow/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
OF_UNIMPLEMENTED();
}
virtual Maybe<MirroredTensor> cur_rank_phy_tensor() const { OF_UNIMPLEMENTED(); }
virtual Maybe<void> set_consumer_nd_sbp_constraint(Symbol<NdSbp> val) { OF_UNIMPLEMENTED(); }
virtual Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) {
OF_UNIMPLEMENTED();
}

// Getters for autograd
virtual bool requires_grad() const = 0;
Expand Down Expand Up @@ -168,7 +170,7 @@ class StaticZerosTensor final : public Tensor {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<MirroredTensor> cur_rank_phy_tensor() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<void> set_consumer_nd_sbp_constraint(Symbol<NdSbp> val) override {
Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {
RETURN_ERROR_WITH_BUG_PROMPT();
}

Expand Down Expand Up @@ -336,7 +338,7 @@ class ProxyTensor : public TensorIf<DerivedT> {
virtual Maybe<MirroredTensor> cur_rank_phy_tensor() const override {
return tensor_->cur_rank_phy_tensor();
}
virtual Maybe<void> set_consumer_nd_sbp_constraint(Symbol<NdSbp> val) override {
virtual Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {
return tensor_->set_consumer_nd_sbp_constraint(val);
}

Expand Down Expand Up @@ -573,7 +575,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }

// Setters
Maybe<void> set_consumer_nd_sbp_constraint(Symbol<NdSbp> val) override {
Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {
impl_->set_consumer_nd_sbp_constraint(val);
return Maybe<void>::Ok();
}
Expand Down
4 changes: 3 additions & 1 deletion oneflow/core/framework/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ class ConsistentTensorImpl : public TensorImpl {
Maybe<bool> has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }

// Setters
void set_consumer_nd_sbp_constraint(Symbol<NdSbp> val) { consumer_nd_sbp_constraint_ = val; }
void set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) {
consumer_nd_sbp_constraint_ = val;
}

ConsistentTensorMeta* mut_tensor_meta() {
PRINT_BUG_PROMPT_AND_ABORT();
Expand Down
19 changes: 7 additions & 12 deletions python/oneflow/test/modules/test_consistent_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@

@autotest(n=1, check_graph=False)
def _test_linear_with_random_data(test_case, placement, weight_sbp, input_sbp):
print(placement)
print(weight_sbp)
input_size = 8
m = torch.nn.Linear(in_features=input_size, out_features=8, bias=random())
m.train(random())
Expand All @@ -40,22 +38,19 @@ def _test_linear_with_random_data(test_case, placement, weight_sbp, input_sbp):
# bias is 1-d tensor
bias_sbp = random_sbp(placement, max_dim=1)
m.bias = torch.nn.Parameter(m.bias.to_global(placement=placement, sbp=bias_sbp))
x = random_tensor(ndim=2, dim1=input_size, dim2=8).to_global(
x = random_tensor(ndim=2, dim0=input_size, dim1=8).to_global(
placement=placement, sbp=input_sbp
)
y = m(x)
return y


# class TestLinearModule(flow.unittest.TestCase):
# @globaltest
# def test_linear_with_random_data(test_case):
# for placement in all_placement():
# # TODO(): Fix 2d sbp
# if len(placement.ranks.shape) != 1:
# continue
# for sbp in all_sbp(placement, max_dim=2):
# _test_linear_with_random_data(test_case, placement, sbp, sbp)
class TestLinearModule(flow.unittest.TestCase):
@globaltest
def test_linear_with_random_data(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=2):
_test_linear_with_random_data(test_case, placement, sbp, sbp)


if __name__ == "__main__":
Expand Down

0 comments on commit c688b54

Please sign in to comment.