Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#24 from Superjomn/fea/add-operation
Browse files Browse the repository at this point in the history
add operation
  • Loading branch information
Superjomn authored Feb 11, 2020
2 parents 811c351 + a2708c2 commit faf65e6
Show file tree
Hide file tree
Showing 22 changed files with 241 additions and 323 deletions.
1 change: 1 addition & 0 deletions cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ cc_library(common
object.cc
graph_utils.cc
context.cc
axis.cc
DEPS boost utils)

cc_test(test_pod_value SRCS pod_value_test.cc DEPS common ir)
Expand Down
28 changes: 28 additions & 0 deletions cinn/common/axis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "cinn/common/axis.h"

#include <glog/logging.h>

namespace cinn {
namespace common {

//! Predefined axis names, indexed by loop nesting level (level 0 is the
//! outermost); axis_name() resolves a level against this table.
const std::vector<std::string> kAxises{
    "i",  // level 0
    "j",  // level 1
    "k",  // level 2
    "a",  // level 3
    "b",  // level 4
    "c",  // level 5
    "d",  // level 6
    "e",  // level 7
    "f",  // level 8
    "g",  // level 9
    "h",  // level 10
};

const std::string &axis_name(int level) {
CHECK_LT(level, kAxises.size());
return kAxises[level];
}

} // namespace common
} // namespace cinn
12 changes: 12 additions & 0 deletions cinn/common/axis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once
#include <string>
#include <vector>

namespace cinn {
namespace common {

//! Get the predefined axis name.
const std::string& axis_name(int level);

} // namespace common
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/common/common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "cinn/common/axis.h"
#include "cinn/common/context.h"
#include "cinn/common/domain.h"
#include "cinn/common/graph_utils.h"
Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ cc_library(ir SRCS
function_definition.cc
ir_operators.cc
buffer.cc
#tensor.cc
tensor.cc
function_base.cc
#operation.cc
operation.cc
DEPS common boost
)

Expand Down
14 changes: 0 additions & 14 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "cinn/common/pod_value.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/lang/tensor.h"

namespace cinn {

Expand Down Expand Up @@ -222,19 +221,6 @@ Expr Call::Make(Type type,
return Expr(node);
}

// Build a _Tensor_ node and wrap it in a lang::Tensor handle.
// The shape and the iterator list must have the same rank.
lang::Tensor _Tensor_::Make(const std::vector<Expr> &shape,
                            const std::vector<Var> &iterators,
                            Type dtype,
                            ir::Expr expr) {
  CHECK_EQ(shape.size(), iterators.size()) << "dimension of the shape and the iterators should match";
  auto node = common::make_shared<_Tensor_>();
  node->set_type(dtype);
  node->shape = shape;
  node->iterators = iterators;
  node->expr = expr;
  return lang::Tensor(node);
}

} // namespace ir

namespace common {
Expand Down
23 changes: 1 addition & 22 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,28 +513,7 @@ class _IterVar_ : public IrNode {
void Accept(IrVisitor* v) const override;
IrNodeTy node_type() const override { return _node_type_; }

static const IrNodeTy _node_type_ = IrNodeTy::_Range_;
};

class _Tensor_ : public ExprNode<_Tensor_> {
public:
//! Shape of this tensor.
std::vector<Expr> shape;
//! The expression that generate this tensor.
ir::Expr expr;
//! The iterators, we store the iterators to name the dimensions for better readability.
std::vector<Var> iterators;
//! Polyhedral element for analysis and schedule.
poly::Element* poly_element{};

static lang::Tensor Make(const std::vector<Expr>& shape,
const std::vector<Var>& iterators,
Type dtype,
ir::Expr expr);

_Tensor_() : ExprNode<_Tensor_>(Float(32)) {}

static const IrNodeTy _node_type_ = IrNodeTy::_Tensor_;
static const IrNodeTy _node_type_ = IrNodeTy::_IterVar_;
};

static IterVar thread_axis(Range dom, const std::string& tag) {
Expand Down
1 change: 0 additions & 1 deletion cinn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ class IrNodeRef : public common::Shared<IrNode> {
}
// Checked downcast of the held node to the concrete type T.
// Returns nullptr when the runtime node type does not match T.
// NOTE: the previous unconditional LOG(INFO) on every call was leftover
// debug output — noisy and costly on a hot cast path — so it is removed.
template <typename T>
T* As() {
  if (node_type() == T::_node_type_) return static_cast<T*>(get());
  return nullptr;
}
Expand Down
31 changes: 30 additions & 1 deletion cinn/ir/operation.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
#include "cinn/ir/operation.h"

#include <utility>

#include "cinn/common/common.h"

namespace cinn {
namespace ir {} // namespace ir
namespace ir {

//! Create a placeholder operation with the given name, shape and element type.
Operation PlaceholderOp::Make(const std::string &name, const std::vector<Expr> &shape, Type dtype) {
  auto node = make_shared<PlaceholderOp>();
  node->set_type(dtype);
  node->name = name;
  node->shape = shape;
  return Operation(node);
}
//! Identifier string of this operation kind.
const char *PlaceholderOp::func_type() const { return __func_type__; }

//! Identifier string of this operation kind.
const char *ComputeOp::func_type() const { return __func_type__; }

//! Create a compute operation.
//! \param name  name of the operation.
//! \param tag   schedule tag.
//! \param attrs extra attributes keyed by name.
//! \param axis  iteration axes of the computation.
//! \param body  body expressions, one per output.
Operation ComputeOp::Make(std::string name,
                          std::string tag,
                          std::map<std::string, IrNodeRef> attrs,
                          std::vector<Var> axis,
                          std::vector<Expr> body) {
  auto n = make_shared<ComputeOp>();
  // The parameters are sinks (taken by value); move them into the node
  // instead of copying a second time.
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  return Operation(n);
}

} // namespace ir
} // namespace cinn
12 changes: 10 additions & 2 deletions cinn/ir/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct ExternOp : public _Operation_ {
std::vector<Buffer> output_placeholders,
Stmt body);

static constexpr char* buffer_get_element = "cinn_buffer_get_element";
// Name of the buffer-element read intrinsic.
// NOTE: a string literal is const data; binding it to a non-const `char *`
// is ill-formed in ISO C++ (C++11 and later), so the pointee must be const.
static constexpr const char *buffer_get_element = "cinn_buffer_get_element";
};

/**
Expand All @@ -44,7 +44,11 @@ struct PlaceholderOp : public _Operation_ {
//! The data type of the input.
Type dtype;

static Operation Make(std::string name, std::vector<Expr> shape, Type dtype);
static Operation Make(const std::string &name, const std::vector<Expr> &shape, Type dtype);

const char *func_type() const override;

// Type tag returned by func_type(). A string literal is const data, so the
// pointer must be pointer-to-const; `constexpr char *` does not compile
// under a conforming C++ compiler.
static constexpr const char *__func_type__ = "placeholder";
};

/**
Expand All @@ -65,6 +69,10 @@ struct ComputeOp : public _Operation_ {
std::map<std::string, IrNodeRef> attrs,
std::vector<Var> axis,
std::vector<Expr> body);

const char *func_type() const override;

// Type tag returned by func_type(). A string literal is const data, so the
// pointer must be pointer-to-const; `constexpr char *` does not compile
// under a conforming C++ compiler.
static constexpr const char *__func_type__ = "compute";
};

} // namespace ir
Expand Down
92 changes: 35 additions & 57 deletions cinn/ir/tensor.cc
Original file line number Diff line number Diff line change
@@ -1,76 +1,54 @@
#include "cinn/ir/tensor.h"

#include "cinn/common/common.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/ir/operation.h"

namespace cinn {
namespace ir {

Tensor::Tensor(const std::vector<Var> &shape, Type type) : IrNodeRef(common::make_shared<_Tensor_>()) {
operator->()->shape.clear();
for (auto &v : shape) {
operator->()->shape.push_back(Expr(v));
}
const _Operation_ *Operation::operator->() const { return static_cast<_Operation_ *>(get()); }

operator->()->set_type(type);
}
Tensor::Tensor(const std::vector<Expr> &shape, Type type) : IrNodeRef(common::make_shared<_Tensor_>()) {
operator->()->shape = shape;
operator->()->set_type(type);
}

const _Tensor_ *Tensor::operator->() const {
auto *p = Object::As<_Tensor_>();
CHECK(p) << "type not match";
return p;
}
_Tensor_ *Tensor::operator->() {
auto *p = Object::As<_Tensor_>();
CHECK(p) << "type not match";
return p;
}
size_t Tensor::ndims() const { return operator->()->shape.size(); }

Expr Tensor::operator()(const std::vector<Expr> &indices) const {
CHECK_EQ(indices.size(), ndims()) << "dimension not match";
auto n = Call::Make(operator->()->type().ElementOf(), //
// operator->()->op->name, //
"cinn_buffer_get_element",
indices, //
Call::Halide, //
operator->()->op, //
operator->()->value_index);
n->set_type(operator->()->type());
return n;
}
Tensor _Tensor_::Make(const std::string &name,
const std::string &tag,
const std::vector<Expr> &shape,
const std::vector<Var> &axis,
Type dtype,
const std::map<std::string, IrNodeRef> &attrs,
const std::vector<Expr> &body) {
auto op = ComputeOp::Make(name, tag, attrs, axis, body);
auto *compute_op = const_cast<ComputeOp *>(op->As<ComputeOp>());
compute_op->axis = axis;

Expr Tensor::operator()(const std::vector<Var> &indices) const {
std::vector<Expr> _indices(indices.begin(), indices.end());
return operator()(_indices);
auto n = make_shared<_Tensor_>();
n->name = name;
n->operaion = op;
n->shape = shape;
n->set_type(dtype);
return Tensor(n);
}

bool Tensor::operator==(const Tensor &other) const {
if (get() == other.get()) return true;
if (!get() || !get()) return false;
if (operator->()->op.defined() && other->op.defined()) {
return operator->()->op == other->op && operator->()->value_index == other->value_index;
}
// Build a _Tensor_ node backed by an externally-provided function reference.
Tensor _Tensor_::Make(const std::string &name, const std::vector<Expr> &shape, FunctionRef fn) {
  auto n = make_shared<_Tensor_>();
  n->name = name;
  n->shape = shape;
  // NOTE(review): `operaion` looks like a misspelling of `operation`, but it
  // must match the member declared in tensor.h — confirm before renaming.
  n->operaion = fn;
  return Tensor(n);
}

IrNodeTy Tensor::node_type() const { return ir::IrNodeTy::_Tensor_; }
// Construct a tensor by building a compute-backed _Tensor_ node with an empty
// tag, no attributes, and `expr` as the single body expression.
Tensor::Tensor(
    const std::vector<Expr> &shape, const std::vector<Var> &axis, Type dtype, Expr expr, const std::string &name)
    : IrNodeRef(_Tensor_::Make(name, "", shape, axis, dtype, {}, {expr})) {}

void _Tensor_::Accept(IrVisitor *v) const { v->Visit(this); }
//! Number of dimensions of the tensor (the rank of its shape).
size_t Tensor::ndims() const { return operator->()->shape.size(); }

Tensor _Tensor_::Make(const std::vector<Expr> &shape, Type dtype, Operation op, int value_index) {
auto *node = common::make_shared<_Tensor_>();
node->shape = shape;
node->set_type(dtype);
node->op = op;
node->value_index = value_index;
return Tensor(node);
// Build an element-access expression `tensor(indices...)`.
// The number of indices must equal the tensor's rank.
Expr Tensor::operator()(const std::vector<Expr> &indices) const {
  CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension";
  auto *node = operator->();
  // Emitted as a Halide-style Call against this tensor's operation.
  auto n = Call::Make(node->type().ElementOf(), node->name, indices, Call::Halide, node->operaion);
  n->set_type(node->type());
  return n;
}

//! Access the underlying _Operation_ node (checked downcast via As).
const _Operation_ *Operation::operator->() const { return ptr()->As<_Operation_>(); }
// TODO(review): stub — ignores `i` and always returns an empty Tensor;
// implement real output lookup before callers rely on this.
Tensor Operation::output(size_t i) const { return Tensor(); }

} // namespace ir
} // namespace cinn
Loading

0 comments on commit faf65e6

Please sign in to comment.