
Commit

Merge pull request PaddlePaddle#50 from Superjomn/refine/runtime-api
refine/runtime api
Superjomn authored Feb 28, 2020
2 parents 198ad34 + fff4f6a commit cfaaa3a
Showing 5 changed files with 111 additions and 38 deletions.
6 changes: 6 additions & 0 deletions cinn/ir/node.cc
@@ -49,5 +49,11 @@ Expr &Expr::operator=(const Expr &other) {
return *this;
}

+Expr::operator Var() {
+auto *x = As<ir::_Var_>();
+CHECK(x);
+return ir::Var(x);
+}
+
} // namespace ir
} // namespace cinn
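
The conversion operator added above lets an Expr whose underlying node is an ir::_Var_ be handed back as a Var (the CHECK aborts otherwise). A minimal sketch of the intended use, not taken from this commit; it assumes Var can be constructed from a name hint and converts implicitly to Expr, as the compute.cc change below relies on:

// Hypothetical snippet against the CINN ir headers.
ir::Var i("i");    // a named iteration variable (constructor assumed)
ir::Expr e(i);     // an Expr wrapping the ir::_Var_ node
ir::Var back = e;  // uses the new Expr::operator Var()
// Converting an Expr that is not a _Var_, e.g. ir::Expr(1), would trip CHECK(x).
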
2 changes: 2 additions & 0 deletions cinn/ir/node.h
@@ -235,6 +235,8 @@ struct Expr : public IrNodeRef {
double as_double() const;
// @}

+operator Var();
+
const Type& type() const { return p_->type(); }
};

28 changes: 16 additions & 12 deletions cinn/lang/compute.cc
@@ -8,63 +8,67 @@
namespace cinn {
namespace lang {

-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var)> fn, const std::string &name) {
+ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Expr)> fn, const std::string &name) {
return Compute(
dims,
-[fn](const std::vector<Var> &axis) -> Expr {
+[fn](const std::vector<Expr> &axis) -> Expr {
CHECK_EQ(axis.size(), 1);
return fn(axis[0]);
},
name);
}

-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var)> fn, const std::string &name) {
+ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Expr, Expr)> fn, const std::string &name) {
return Compute(
dims,
-[fn](const std::vector<Var> &axis) -> Expr {
+[fn](const std::vector<Expr> &axis) -> Expr {
CHECK_EQ(axis.size(), 2);
return fn(axis[0], axis[1]);
},
name);
}

-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var)> fn, const std::string &name) {
+ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Expr, Expr, Expr)> fn, const std::string &name) {
return Compute(
dims,
-[fn](const std::vector<Var> &axis) -> Expr {
+[fn](const std::vector<Expr> &axis) -> Expr {
CHECK_EQ(axis.size(), 3);
return fn(axis[0], axis[1], axis[2]);
},
name);
}

-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var, Var)> fn, const std::string &name) {
+ir::Tensor Compute(const std::vector<int> &dims,
+std::function<Expr(Expr, Expr, Expr, Expr)> fn,
+const std::string &name) {
return Compute(
dims,
-[fn](const std::vector<Var> &axis) -> Expr {
+[fn](const std::vector<Expr> &axis) -> Expr {
CHECK_EQ(axis.size(), 4);
return fn(axis[0], axis[1], axis[2], axis[3]);
},
name);
}

ir::Tensor Compute(const std::vector<int> &dims,
-std::function<Expr(Var, Var, Var, Var, Var)> fn,
+std::function<Expr(Expr, Expr, Expr, Expr, Expr)> fn,
const std::string &name) {
return Compute(
dims,
-[fn](const std::vector<Var> &axis) -> Expr {
+[fn](const std::vector<Expr> &axis) -> Expr {
CHECK_EQ(axis.size(), 5);
return fn(axis[0], axis[1], axis[2], axis[3], axis[4]);
},
name);
}

ir::Tensor Compute(const std::vector<int> &dims,
-std::function<Expr(const std::vector<Var> &)> fn,
+std::function<Expr(const std::vector<Expr> &)> fn,
const std::string &name) {
auto axis = common::GenDefaultAxis(dims.size());
-Expr expr = fn(axis);
+std::vector<Expr> _axis;
+for (auto &x : axis) _axis.push_back(x);
+Expr expr = fn(_axis);

std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);
14 changes: 8 additions & 6 deletions cinn/lang/compute.h
@@ -14,17 +14,19 @@ namespace lang {

//! Compute methods for one to five Vars as arguments.
// @{
-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var)> fn, const std::string &name = "");
-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var)> fn, const std::string &name = "");
-ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var)> fn, const std::string &name = "");
+ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Expr)> fn, const std::string &name = "");
+ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Expr, Expr)> fn, const std::string &name = "");
ir::Tensor Compute(const std::vector<int> &dims,
-std::function<Expr(Var, Var, Var, Var)> fn,
+std::function<Expr(Expr, Expr, Expr)> fn,
const std::string &name = "");
ir::Tensor Compute(const std::vector<int> &dims,
-std::function<Expr(Var, Var, Var, Var, Var)> fn,
+std::function<Expr(Expr, Expr, Expr, Expr)> fn,
const std::string &name = "");
ir::Tensor Compute(const std::vector<int> &dims,
-std::function<Expr(const std::vector<Var> &)> fn,
+std::function<Expr(Expr, Expr, Expr, Expr, Expr)> fn,
const std::string &name = "");
+ir::Tensor Compute(const std::vector<int> &dims,
+std::function<Expr(const std::vector<Expr> &)> fn,
+const std::string &name = "");
// @}
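
With the declarations above now taking Expr axis arguments (plus the new overload that receives the whole std::vector<Expr>), a call site looks roughly like the sketch below. This is not part of the commit: the input tensor A, its operator() indexing, and the dimensions M and N are illustrative assumptions.

// Hypothetical: a stage that reads tensor A over a 2-D iteration domain.
int M = 32, N = 64;
ir::Tensor B = lang::Compute(
    {M, N},
    [&](Expr i, Expr j) -> Expr { return A(i, j); },  // axes now arrive as Expr, not Var
    "B");
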

99 changes: 79 additions & 20 deletions cinn/runtime/cinn_runtime.h
@@ -13,8 +13,8 @@ extern "C" {
typedef enum cinn_type_code_t {
cinn_type_int = 0, //! signed int
cinn_type_uint = 1, //! unsigned int
-cinn_type_float = 1, //! floating point
-cinn_type_handle = 1 //! void*
+cinn_type_float = 2, //! floating point
+cinn_type_handle = 3 //! void*
} cinn_type_code_t;

#ifndef CINN_ATTRIBUTE_ALIGN
@@ -60,11 +60,68 @@ typedef enum cinn_buffer_kind_t {
cinn_buffer_on_device = 1 << 1 // ! buffer on device e.g. GPU.
} cinn_buffer_kind_t;

+struct cinn_buffer_t;
+
+/**
+ * All CINN backends implementation should provide an interface to be used.
+ */
+struct cinn_device_interface_impl_t;
+
+struct cinn_device_interface_t {
+int (*malloc)(void* context, struct cinn_buffer_t* buf, const struct cinn_device_interface_t* device_interface);
+int (*free)(void* context, struct cinn_device_interface_t* buf);
+int (*sync)(void* context, struct cinn_buffer_t* buf);
+void (*release)(void* context, const struct cinn_device_interface_t* device_interface);
+int (*copy_to_host)(void* context, struct cinn_buffer_t* buf);
+int (*copy_to_device)(void* context,
+struct cinn_buffer_t* buf,
+const struct cinn_device_interface_t* device_interface);
+int (*buffer_copy)(void* context,
+struct cinn_buffer_t* src,
+struct cinn_buffer_t* dst,
+const struct cinn_device_interface_t* dest_device_interface);
+};
+
+/**
+ * Release all data associated with the given interface.
+ */
+extern void cinn_device_release(void* context, const struct cinn_device_interface_t* device_interface);
+
+/*
+ * Copy image data from device to host memory.
+ */
+extern int cinn_copy_to_host(void* context, struct cinn_buffer_t* buf);
+
+//! Copy data from host to device memory.
+extern int cinn_copy_to_device(void* context,
+struct cinn_buffer_t* buf,
+const struct cinn_device_interface_t* device_interface);
+
+//! Copy data from one buffer to another.
+extern int cinn_buffer_copy(void* context,
+struct cinn_buffer_t* src,
+struct cinn_buffer_t* dst,
+const struct cinn_device_interface_t* dest_device_interface);
+
+//! Wait for current device operations to complete.
+extern int cinn_device_sync(void* context, struct cinn_buffer_t* buf);
+
+//! Allocate device memory.
+extern int cinn_device_malloc(void* context,
+struct cinn_buffer_t* buf,
+const struct cinn_device_interface_t* device_interface);
+
+//! Free device memory.
+extern int cinn_device_free(void* context, struct cinn_buffer_t* buf);
+
//! The raw representation of a buffer,used in the generated code/lib.
typedef struct cinn_buffer_t {
//! A device handle.
uint64_t device;

+//! The interface used to operate on device.
+const struct cinn_device_interface_t* device_interface;
+
//! A pointer to the memory in host.
uint8_t* host_memory;

@@ -75,38 +132,40 @@ typedef struct cinn_buffer_t {
cinn_type_t type;

//! Number of dimensions.
-int32_t ndims;
-cinn_buffer_t* dims;
+int32_t dimensions;
+cinn_dimension_t* dims;

#ifdef __cplusplus
-int num_elements() const {
+CINN_ALWAYS_INLINE int num_elements() const {
int res = 1;
-for (int i = 0; i < ndims; i++) {
+for (int i = 0; i < dimensions; i++) {
res *= dims[i];
}
return res;
}

CINN_ALWAYS_INLINE bool on_host() const { return get_flag(cinn_buffer_on_host); }
CINN_ALWAYS_INLINE bool on_device() const { return get_flag(cinn_buffer_on_device); }
-CINN_ALWAYS_INLINE void set_on_host(bool x = true) {
-if (x) {
-set_flag(cinn_buffer_on_host);
-} else {
-flag &= ~cinn_buffer_on_host;
-}
-}
-CINN_ALWAYS_INLINE void set_on_device(bool x = true) {
-if (x) {
-set_flag(cinn_buffer_on_device);
-} else {
-flag &= ~cinn_buffer_on_device;
+CINN_ALWAYS_INLINE void set_on_host(bool x = true) { set_flag(cinn_buffer_on_host, x); }
+CINN_ALWAYS_INLINE void set_on_device(bool x = true) { set_flag(cinn_buffer_on_device, x); }
+
+CINN_ALWAYS_INLINE int device_sync(void* ctx = NULL) {
+if (device_interface && device_interface->sync) {
+return device_interface->sync(ctx, this);
}
+return 0;
}
-CINN_ALWAYS_INLINE uint8_t* begin() const {}

+CINN_ALWAYS_INLINE uint8_t* begin() const { return 0; }
+CINN_ALWAYS_INLINE uint8_t* end() const { return host_memory + num_elements() * type.bytes(); }
+
CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { return (this->flag & flag) != 0; }
-CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag) { this->flag |= flag; }
+CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag, bool value) {
+if (value)
+this->flag |= flag;
+else
+this->flag &= ~flag;
+}

#endif // __cplusplus

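
For orientation, here is a small sketch of how the reworked cinn_buffer_t helpers compose. It is not from this commit, and it assumes cinn_dimension_t is an integral typedef and that the fields not shown in this hunk (such as flag) can be zero-initialized:

#include "cinn/runtime/cinn_runtime.h"  // include path assumed from the repository layout

int demo() {
  cinn_dimension_t dims[2] = {4, 8};  // assumes an integral cinn_dimension_t
  cinn_buffer_t buf = {};             // aggregate zero-init, flag bits cleared
  buf.dims = dims;
  buf.dimensions = 2;                 // renamed from ndims in this commit
  buf.set_on_host();                  // forwards to set_flag(cinn_buffer_on_host, true)
  int n = buf.num_elements();         // 4 * 8 = 32
  int rc = buf.device_sync();         // no device_interface wired up, so returns 0
  return n + rc;                      // 32
}
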
