Skip to content

Commit

Permalink
Dev user op tensor support stride (#7829)
Browse files Browse the repository at this point in the history
* align Parameter::contiguous()

* auto format by CI

* use IsViewApplicable

* auto format by CI

* refine

* auto format by CI

* revise AsStrided to support view

* auto format by CI

* fix view::unfold

* fix doctest

* fix AsStrided

* raw implementation

* refine

* refine

* fix conflict

* refine

* refine

* fix clang check

* auto format by CI

* fix bug

* refine

* auto format by CI

* fix consistent test

* fix check warning

* refine

* refine

* refine

* refine

* remove useless codes

* auto format by CI

* refine

* support stride in attr value and op kernel infer cache

* refactor

* multi tensorview impl

* auto format by CI

* format

* auto format by CI

* refine

* auto format by CI

* refine

* remove checks

* use select when index size is 1

* auto format by CI

* make inputs contiguous in autograd interpreter

* auto format by CI

* fix bug

* auto format by CI

* remove functors' tensorcontiguous

* auto format by CI

* refine

* refine

* auto format by CI

* refine

* revert slice changes

* refine

* auto format by CI

* refine

* refine

* fix clang check

* auto format by CI

* auto format by CI

* remove useless fn

* refine view::Transpose

* refine

* add view ops test cases

* refine

* rename

* refine

* opexpr support stride param

* auto format by CI

* auto format by CI

* fix

* auto format by CI

* refine

* auto format by CI

* fix comments

* fix comments

* fix comments

* refine

* refine

* export oneflow.has_same_tensor_storage api

* auto format by CI

* test has same tensor_storage

* auto format by CI

* fix comments

* fix comment

* refine

* refine

* auto format by CI

* refine

* refine

* remove

* auto format by CI

* auto format by CI

* refine

* auto format by CI

* refine

* fix comments

* refine

* fix comments

* refine

* auto format by CI

* update backward_fn

* auto format by CI

* fix comments

* auto format by CI

* fix diagonal

* refine 0-shape copy check

* refine

* fix

* update

* auto format by CI

* refine

* refine chunk

* refine

* fix

* remove StrideView

* rm useless item

* refine

* remove StrideProto

* refine

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: Chengyu Ma <1802572599@qq.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored May 16, 2022
1 parent 2f68b46 commit cdf34db
Show file tree
Hide file tree
Showing 71 changed files with 485 additions and 198 deletions.
2 changes: 1 addition & 1 deletion oneflow/api/python/framework/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/placement_utils.h"
#include "oneflow/core/functional/functional.h"
Expand Down
2 changes: 1 addition & 1 deletion oneflow/api/python/functional/tensor_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class LocalTensorSharedNumpyDataFunctor {
stride /= element_size_in_bytes;
}
const auto strides = std::make_shared<Stride>(strides_vec);
auto tensor_meta = std::make_shared<MirroredTensorMeta>(shape, data_type, device, strides, 0);
auto tensor_meta = std::make_shared<MirroredTensorMeta>(shape, strides, data_type, device, 0);

// Build TensorBuffer
const auto& Free = [obj](char* dptr) {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/api/python/utils/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/blocking_then_busy.h"
#include "oneflow/core/vm/virtual_machine.h"
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.
#include "oneflow/core/framework/placement_sbp_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/common/stride.h"

namespace oneflow {

Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/common/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ namespace oneflow {

class ShapeView;

namespace cfg {
// TODO: replace ShapeProto with Int64ListProto
class ShapeProto;
} // namespace cfg

class Shape final {
public:
// OF_DISALLOW_COPY_AND_MOVE(Shape);
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/common/shape.proto
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
syntax = "proto2";
package oneflow;

message Int64ListProto {
repeated int64 dim = 1;
}

message ShapeProto {
repeated int64 dim = 1;
}
Expand Down
2 changes: 0 additions & 2 deletions oneflow/core/common/shape_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ namespace oneflow {

typedef std::vector<int64_t> DimVector;
typedef std::vector<int64_t> AxisVector;
typedef std::vector<int64_t> StrideVector;

#else

typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> DimVector;
typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> AxisVector;
typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> StrideVector;

#endif
} // namespace oneflow
Expand Down
83 changes: 83 additions & 0 deletions oneflow/core/common/stride.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cplusplus_17.h"

namespace oneflow {

// Computes the strides of a contiguous layout for `shape`: the stride of the last axis is 1
// and the stride of axis i is the product of all dimensions after axis i.
//
// - An uninitialized shape leaves the stride vector empty.
// - A shape with zero elements is handled by counting every non-positive dimension as 1, so
//   the resulting strides are still well-formed and non-zero.
// - In any other case (e.g. zero axes) the vector is only resized and left value-initialized.
Stride::Stride(const Shape& shape) {
  if (!shape.is_initialized()) { return; }

  const int64_t ndim = shape.NumAxes();
  stride_vec_.resize(ndim);
  if (ndim <= 0) { return; }

  // The scan accumulates in the type of its initial value. Seeding it with a plain `1` (int)
  // would narrow every partial product to int and silently wrap for shapes whose trailing
  // dimensions multiply past INT_MAX, so the seed must be a 64-bit value.
  const int64_t kInitStride = 1;
  const int64_t elem_cnt = shape.elem_cnt();
  if (elem_cnt > 0) {
    // Reverse exclusive product scan over the dimensions, written back to front.
    std::exclusive_scan(shape.dim_vec().rbegin(), shape.dim_vec().rend(), stride_vec_.rbegin(),
                        kInitStride, std::multiplies<>{});
  } else if (elem_cnt == 0) {
    // 0-size shape: substitute 1 for every non-positive dimension before scanning.
    std::vector<int64_t> tmp_shape(ndim);
    for (int64_t i = 0; i < ndim; ++i) {
      const int64_t dim = shape.At(i);
      tmp_shape[i] = dim > 0 ? dim : 1;
    }
    std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), stride_vec_.rbegin(), kInitStride,
                        std::multiplies<>{});
  }
}

// Builds the strides for the pointed-to shape by delegating to Stride(const Shape&).
// NOTE: `shape` is dereferenced without a null check.
Stride::Stride(const std::shared_ptr<Shape>& shape) : Stride(*shape) {}

// The constructors below take the stride values exactly as given; nothing is derived from a
// shape. The values are copied, except for the rvalue overload, which moves them.
Stride::Stride(const std::initializer_list<int64_t>& stride_vec) : stride_vec_(stride_vec) {}
Stride::Stride(const DimVector& stride_vec) : stride_vec_(stride_vec) {}
Stride::Stride(DimVector&& stride_vec) : stride_vec_(std::move(stride_vec)) {}
// Builds a Stride from a serialized Int64ListProto: the entries of its repeated `dim` field
// become the stride values, in the same order.
Stride::Stride(const Int64ListProto& stride_proto) {
  const auto& dims = stride_proto.dim();
  stride_vec_.assign(dims.begin(), dims.end());
}

// Replaces the stored stride values with a copy of `stride_vec` and returns *this so calls
// can be chained. The number of axes is allowed to change.
Stride& Stride::assign(const DimVector& stride_vec) {
stride_vec_ = stride_vec;
return *this;
}

// Overwrites this stride with the values of `stride`, after CHECK_EQ has verified that both
// have the same number of axes. Returns *this.
Stride& Stride::CheckNumAxesIdenticalAndAssign(const Stride& stride) {
  CHECK_EQ(NumAxes(), stride.NumAxes());
  // Use plain copy assignment (as operator= does) rather than assign(first, last): the
  // iterator-range form must not receive iterators into the container being assigned, which
  // is exactly what happens when `stride` aliases *this.
  stride_vec_ = stride.stride_vec_;
  return *this;
}

// Copy assignment: copies the stride values of `stride` into this object and returns *this.
// Unlike CheckNumAxesIdenticalAndAssign, no axis-count check is performed.
Stride& Stride::operator=(const Stride& stride) {
stride_vec_ = stride.stride_vec_;
return *this;
}

// Two strides are equal when their underlying stride vectors compare equal.
bool Stride::operator==(const Stride& rhs) const { return stride_vec_ == rhs.stride_vec_; }

// Formats the stride as a parenthesized, comma-separated list, e.g. "(6,2,1)".
// A single-axis stride keeps a trailing comma ("(4,)") and an empty stride prints as "()".
std::string Stride::ToString() const {
  const size_t num_axes = stride_vec_.size();
  std::stringstream ss;
  ss << "(";
  // Track the position with size_t so it is compared against size() without mixing signed
  // and unsigned operands.
  size_t pos = 0;
  for (int64_t dim : stride_vec_) {
    ss << dim;
    ++pos;
    // Separator after every element but the last; the lone element of a 1-axis stride also
    // gets one.
    if (pos != num_axes || num_axes == 1) { ss << ","; }
  }
  ss << ")";
  return ss.str();
}

// Serializes the stride: assigns the stride values, in order, to the repeated `dim` field of
// `ret`. `ret` is dereferenced without a null check.
void Stride::ToProto(Int64ListProto* ret) const {
*(ret->mutable_dim()) = PbRf<int64_t>(stride_vec_.begin(), stride_vec_.end());
}

} // namespace oneflow
17 changes: 12 additions & 5 deletions oneflow/core/framework/stride.h → oneflow/core/common/stride.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,36 @@ limitations under the License.

namespace oneflow {

class StrideView;

class Stride final {
public:
Stride() = default;
explicit Stride(const Shape& shape);
explicit Stride(const StrideVector& stride_vec) : stride_vec_(stride_vec) {}
explicit Stride(StrideVector&& stride_vec) : stride_vec_(stride_vec) {}
Stride(const std::initializer_list<int64_t>& stride_vec) : stride_vec_(stride_vec) {}
explicit Stride(const std::shared_ptr<Shape>& shape);
explicit Stride(DimVector&& stride_vec);
explicit Stride(const DimVector& stride_vec);
explicit Stride(const Int64ListProto& stride_proto);
Stride(const std::initializer_list<int64_t>& stride_vec);
Stride& operator=(const Stride& stride);
Stride& assign(const DimVector& stride_vec);
Stride& CheckNumAxesIdenticalAndAssign(const Stride& stride);
~Stride() = default;

bool operator==(const Stride& rhs) const;
bool operator!=(const Stride& rhs) const { return !(*this == rhs); }

std::string ToString() const;
void ToProto(Int64ListProto*) const;

// Getters and Setters
const StrideVector& StrideVec() const { return stride_vec_; }
const DimVector& StrideVec() const { return stride_vec_; }
int64_t NumAxes() const { return stride_vec_.size(); }
int64_t At(int64_t index) const { return stride_vec_.at(index); }
void Set(int64_t index, int64_t val) { stride_vec_.at(index) = val; }

private:
StrideVector stride_vec_;
DimVector stride_vec_;
};

} // namespace oneflow
Expand Down
7 changes: 5 additions & 2 deletions oneflow/core/eager/eager_blob_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@ namespace oneflow {
namespace vm {

EagerBlobObject::EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
const std::shared_ptr<Shape>& shape, DataType data_type,
const std::shared_ptr<Shape>& shape,
const std::shared_ptr<Stride>& stride, DataType data_type,
const std::shared_ptr<TensorStorage>& tensor_storage,
const intrusive::shared_ptr<LocalDepObject>& dep_object)
: is_dynamic_(false),
mem_case_(mem_case),
data_type_(data_type),
shape_(shape),
stride_(stride),
storage_offset_(0),
tensor_storage_(tensor_storage),
is_shape_synced_(true),
compute_local_dep_object_(dep_object),
blob_desc_(shape, data_type) {
blob_desc_(shape, stride, data_type) {
CHECK(static_cast<bool>(shape));
CHECK(static_cast<bool>(stride));
CHECK(static_cast<bool>(tensor_storage));
}

Expand Down
13 changes: 9 additions & 4 deletions oneflow/core/eager/eager_blob_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,13 @@ class EagerBlobObject final {
EagerBlobObject(const EagerBlobObject&) = delete;
EagerBlobObject(EagerBlobObject&&) = delete;
EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage)
: EagerBlobObject(mem_case, shape, data_type, tensor_storage,
const std::shared_ptr<Stride>& stride, DataType data_type,
const std::shared_ptr<TensorStorage>& tensor_storage)
: EagerBlobObject(mem_case, shape, stride, data_type, tensor_storage,
intrusive::shared_ptr<LocalDepObject>()) {}
EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage,
const std::shared_ptr<Stride>& stride, DataType data_type,
const std::shared_ptr<TensorStorage>& tensor_storage,
const intrusive::shared_ptr<LocalDepObject>& dep_object);

~EagerBlobObject() { tensor_storage_.reset(); }
Expand Down Expand Up @@ -144,6 +146,9 @@ class EagerBlobObject final {
std::shared_ptr<const Shape> shape_ptr() const { return shape_; }
const Shape& shape() const { return *shape_; }
Shape& mut_shape() { return *shape_; }
std::shared_ptr<const Stride> stride_ptr() const { return stride_; }
const Stride& stride() const { return *stride_; }
Stride& mut_stride() { return *stride_; }

size_t ByteSizeOfBlobBody() const { return shape_->elem_cnt() * GetSizeOfDataType(data_type_); }
size_t AlignedByteSizeOfBlobBody() const {
Expand Down Expand Up @@ -181,7 +186,7 @@ class EagerBlobObject final {
std::shared_ptr<MemoryCase> mem_case_;
DataType data_type_;
std::shared_ptr<Shape> shape_;

std::shared_ptr<Stride> stride_;
int64_t storage_offset_;
std::shared_ptr<TensorStorage> tensor_storage_;
std::atomic<bool> is_shape_synced_;
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/eager/opkernel_instruction_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct LocalCallOpKernelUtil final {
operand->consistent_tensor_infer_result().get());
size_t temp_size = InferTmpSizeFn(op_infer_ctx);
temp_eager_blob_object->mut_shape() = Shape({static_cast<int64_t>(temp_size)});
temp_eager_blob_object->mut_stride() = Stride(temp_eager_blob_object->mut_shape());
temp_eager_blob_object->set_pin_memory(false);
temp_eager_blob_object->set_is_dynamic(true);
op_infer_ctx->Update(nullptr, nullptr, nullptr);
Expand Down
10 changes: 7 additions & 3 deletions oneflow/core/framework/attr_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/protobuf.h"

Expand All @@ -40,7 +41,9 @@ namespace user_op {

#define ENUM_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_data_type, DataType, AttrType::kAtDataType)

#define MESSAGE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, AttrType::kAtShape)
#define MESSAGE_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, AttrType::kAtShape) \
OF_PP_MAKE_TUPLE_SEQ(at_stride, Stride, AttrType::kAtStride)

#define LIST_BASIC_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_int32, std::vector<int32_t>, AttrType::kAtListInt32) \
Expand All @@ -50,8 +53,9 @@ namespace user_op {
#define LIST_ENUM_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_data_type, std::vector<DataType>, AttrType::kAtListDataType)

#define LIST_MESSAGE_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector<Shape>, AttrType::kAtListShape)
#define LIST_MESSAGE_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector<Shape>, AttrType::kAtListShape) \
OF_PP_MAKE_TUPLE_SEQ(at_list_stride, std::vector<Stride>, AttrType::kAtListStride)

#define LIST_STRING_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, AttrType::kAtListString)
Expand Down
27 changes: 26 additions & 1 deletion oneflow/core/framework/attr_value_accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/user_op_conf.h"

Expand Down Expand Up @@ -55,6 +56,16 @@ void AttrValueAccessor<Shape>::Attr(const Shape& cpp_val, AttrValue* attr_val) {
cpp_val.ToProto(attr_val->mutable_at_shape());
}

// Reads a stride attribute: constructs a Stride from the at_stride field of `val`.
template<>
Stride AttrValueAccessor<Stride>::Attr(const AttrValue& val) {
return Stride(val.at_stride());
}

// Writes a stride attribute: serializes `cpp_val` into the at_stride field of `attr_val`.
template<>
void AttrValueAccessor<Stride>::Attr(const Stride& cpp_val, AttrValue* attr_val) {
cpp_val.ToProto(attr_val->mutable_at_stride());
}

// List of Basic Attr
#define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
Expand Down Expand Up @@ -110,7 +121,21 @@ void AttrValueAccessor<std::vector<Shape>>::Attr(const std::vector<Shape>& cpp_v
cpp_val.at(i).ToProto(attr_val->mutable_at_list_shape()->add_val());
}
}

// Reads a list-of-stride attribute: returns one Stride per entry of the at_list_stride field
// of `val`, in the order the entries are stored.
template<>
std::vector<Stride> AttrValueAccessor<std::vector<Stride>>::Attr(const AttrValue& val) {
  const auto& stride_list = val.at_list_stride();
  std::vector<Stride> strides;
  strides.reserve(stride_list.val_size());
  for (const auto& stride_proto : stride_list.val()) { strides.emplace_back(stride_proto); }
  return strides;
}
// Writes a list-of-stride attribute: clears the at_list_stride entries of `attr_val`, then
// appends one serialized entry per element of `cpp_val`, preserving order.
template<>
void AttrValueAccessor<std::vector<Stride>>::Attr(const std::vector<Stride>& cpp_val,
                                                  AttrValue* attr_val) {
  auto* stride_list = attr_val->mutable_at_list_stride();
  stride_list->clear_val();
  for (const Stride& stride : cpp_val) { stride.ToProto(stride_list->add_val()); }
}
// List of String Attr
template<>
std::vector<std::string> AttrValueAccessor<std::vector<std::string>>::Attr(const AttrValue& val) {
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/framework/consistent_tensor_infer_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ Maybe<void> ConsistentTensorMetaInferArgs::MakeInputBlobDescs(
for (int i = 0; i < input_arg_tuple.size(); ++i) {
const auto& tensor_meta = *input_consistent_tensor_metas_.at(i).tensor_meta();
const auto& shape = std::const_pointer_cast<Shape>(tensor_meta.shape_ptr());
blob_descs->emplace_back(shape, tensor_meta.data_type());
const auto& stride = std::const_pointer_cast<Stride>(tensor_meta.stride_ptr());
blob_descs->emplace_back(shape, stride, tensor_meta.data_type());
}
return Maybe<void>::Ok();
}
Expand Down Expand Up @@ -265,7 +266,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr
{
// Infer OpArgMutConsistentTensorMeta.
const auto& input_metas = infer_args.input_consistent_tensor_metas();
JUST(user_op_expr.InferLogicalShapeAndDType(
JUST(user_op_expr.InferLogicalTensorDesc(
infer_args.attrs(), parallel_desc,
[&](int32_t i) { return &*input_metas.at(i).tensor_meta(); },
[&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));
Expand Down Expand Up @@ -329,7 +330,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr
UNIMPLEMENTED();
return nullptr;
};
JUST(user_op_expr.InferLogicalShapeAndDType(
JUST(user_op_expr.InferLogicalTensorDesc(
infer_args.attrs(), parallel_desc, GetInputTensorMeta,
[&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));
}
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/framework/eager_blob_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_FRAMEWORK_EAGER_BLOB_UTIL_H_

#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/framework/object.h"
#include "oneflow/core/framework/blob_register.h"
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/framework/infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class InferContext {
virtual const Shape& InputShape(const std::string&, int32_t) const = 0;
virtual Shape* OutputShape(const std::string&, int32_t) = 0;
virtual Shape* Shape4ArgNameAndIndex(const std::string&, int32_t) = 0;
virtual const Stride& InputStride(const std::string&, int32_t) const = 0;
virtual Stride* OutputStride(const std::string&, int32_t) = 0;
virtual Stride* Stride4ArgNameAndIndex(const std::string&, int32_t) = 0;
virtual const DataType& InputDType(const std::string&, int32_t) const = 0;
virtual DataType* OutputDType(const std::string&, int32_t) = 0;
virtual DataType* Dtype4ArgNameAndIndex(const std::string&, int32_t) = 0;
Expand Down
Loading

0 comments on commit cdf34db

Please sign in to comment.