forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#24 from Superjomn/fea/add-operation
add operation
- Loading branch information
Showing
22 changed files
with
241 additions
and
323 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#include "cinn/common/axis.h" | ||
|
||
#include <glog/logging.h> | ||
|
||
namespace cinn { | ||
namespace common { | ||
|
||
const std::vector<std::string> kAxises({ | ||
"i", // level 0 | ||
"j", // level 1 | ||
"k", // level 2 | ||
"a", // level 3 | ||
"b", // level 4 | ||
"c", // level 5 | ||
"d", // level 6 | ||
"e", // level 7 | ||
"f", // level 8 | ||
"g", // level 9 | ||
"h" // level 10 | ||
}); | ||
|
||
const std::string &axis_name(int level) { | ||
CHECK_LT(level, kAxises.size()); | ||
return kAxises[level]; | ||
} | ||
|
||
} // namespace common | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#pragma once | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace cinn { | ||
namespace common { | ||
|
||
//! Get the predifined axis name. | ||
const std::string& axis_name(int level); | ||
|
||
} // namespace common | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,34 @@ | ||
#include "cinn/ir/operation.h" | ||
|
||
#include "cinn/common/common.h" | ||
|
||
namespace cinn { | ||
namespace ir {} // namespace ir | ||
namespace ir { | ||
|
||
Operation PlaceholderOp::Make(const std::string &name, const std::vector<Expr> &shape, Type dtype) { | ||
auto n = make_shared<PlaceholderOp>(); | ||
n->name = name; | ||
n->shape = shape; | ||
n->set_type(dtype); | ||
return Operation(n); | ||
} | ||
const char *PlaceholderOp::func_type() const { return __func_type__; } | ||
|
||
const char *ComputeOp::func_type() const { return __func_type__; } | ||
|
||
Operation ComputeOp::Make(std::string name, | ||
std::string tag, | ||
std::map<std::string, IrNodeRef> attrs, | ||
std::vector<Var> axis, | ||
std::vector<Expr> body) { | ||
auto n = make_shared<ComputeOp>(); | ||
n->name = name; | ||
n->tag = tag; | ||
n->attrs = attrs; | ||
n->axis = axis; | ||
n->body = body; | ||
return Operation(n); | ||
} | ||
|
||
} // namespace ir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,54 @@ | ||
#include "cinn/ir/tensor.h" | ||
|
||
#include "cinn/common/common.h" | ||
#include "cinn/ir/ir_visitor.h" | ||
#include "cinn/ir/operation.h" | ||
|
||
namespace cinn { | ||
namespace ir { | ||
|
||
Tensor::Tensor(const std::vector<Var> &shape, Type type) : IrNodeRef(common::make_shared<_Tensor_>()) { | ||
operator->()->shape.clear(); | ||
for (auto &v : shape) { | ||
operator->()->shape.push_back(Expr(v)); | ||
} | ||
const _Operation_ *Operation::operator->() const { return static_cast<_Operation_ *>(get()); } | ||
|
||
operator->()->set_type(type); | ||
} | ||
Tensor::Tensor(const std::vector<Expr> &shape, Type type) : IrNodeRef(common::make_shared<_Tensor_>()) { | ||
operator->()->shape = shape; | ||
operator->()->set_type(type); | ||
} | ||
|
||
const _Tensor_ *Tensor::operator->() const { | ||
auto *p = Object::As<_Tensor_>(); | ||
CHECK(p) << "type not match"; | ||
return p; | ||
} | ||
_Tensor_ *Tensor::operator->() { | ||
auto *p = Object::As<_Tensor_>(); | ||
CHECK(p) << "type not match"; | ||
return p; | ||
} | ||
size_t Tensor::ndims() const { return operator->()->shape.size(); } | ||
|
||
Expr Tensor::operator()(const std::vector<Expr> &indices) const { | ||
CHECK_EQ(indices.size(), ndims()) << "dimension not match"; | ||
auto n = Call::Make(operator->()->type().ElementOf(), // | ||
// operator->()->op->name, // | ||
"cinn_buffer_get_element", | ||
indices, // | ||
Call::Halide, // | ||
operator->()->op, // | ||
operator->()->value_index); | ||
n->set_type(operator->()->type()); | ||
return n; | ||
} | ||
Tensor _Tensor_::Make(const std::string &name, | ||
const std::string &tag, | ||
const std::vector<Expr> &shape, | ||
const std::vector<Var> &axis, | ||
Type dtype, | ||
const std::map<std::string, IrNodeRef> &attrs, | ||
const std::vector<Expr> &body) { | ||
auto op = ComputeOp::Make(name, tag, attrs, axis, body); | ||
auto *compute_op = const_cast<ComputeOp *>(op->As<ComputeOp>()); | ||
compute_op->axis = axis; | ||
|
||
Expr Tensor::operator()(const std::vector<Var> &indices) const { | ||
std::vector<Expr> _indices(indices.begin(), indices.end()); | ||
return operator()(_indices); | ||
auto n = make_shared<_Tensor_>(); | ||
n->name = name; | ||
n->operaion = op; | ||
n->shape = shape; | ||
n->set_type(dtype); | ||
return Tensor(n); | ||
} | ||
|
||
bool Tensor::operator==(const Tensor &other) const { | ||
if (get() == other.get()) return true; | ||
if (!get() || !get()) return false; | ||
if (operator->()->op.defined() && other->op.defined()) { | ||
return operator->()->op == other->op && operator->()->value_index == other->value_index; | ||
} | ||
Tensor _Tensor_::Make(const std::string &name, const std::vector<Expr> &shape, FunctionRef fn) { | ||
auto n = make_shared<_Tensor_>(); | ||
n->name = name; | ||
n->shape = shape; | ||
n->operaion = fn; | ||
return Tensor(n); | ||
} | ||
|
||
IrNodeTy Tensor::node_type() const { return ir::IrNodeTy::_Tensor_; } | ||
Tensor::Tensor( | ||
const std::vector<Expr> &shape, const std::vector<Var> &axis, Type dtype, Expr expr, const std::string &name) | ||
: IrNodeRef(_Tensor_::Make(name, "", shape, axis, dtype, {}, {expr})) {} | ||
|
||
void _Tensor_::Accept(IrVisitor *v) const { v->Visit(this); } | ||
size_t Tensor::ndims() const { return operator->()->shape.size(); } | ||
|
||
Tensor _Tensor_::Make(const std::vector<Expr> &shape, Type dtype, Operation op, int value_index) { | ||
auto *node = common::make_shared<_Tensor_>(); | ||
node->shape = shape; | ||
node->set_type(dtype); | ||
node->op = op; | ||
node->value_index = value_index; | ||
return Tensor(node); | ||
Expr Tensor::operator()(const std::vector<Expr> &indices) const { | ||
CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension"; | ||
auto *node = operator->(); | ||
auto n = Call::Make(node->type().ElementOf(), node->name, indices, Call::Halide, node->operaion); | ||
n->set_type(node->type()); | ||
return n; | ||
} | ||
|
||
const _Operation_ *Operation::operator->() const { return ptr()->As<_Operation_>(); } | ||
Tensor Operation::output(size_t i) const { return Tensor(); } | ||
|
||
} // namespace ir | ||
} // namespace cinn |
Oops, something went wrong.