diff --git a/cinn/ir/node.cc b/cinn/ir/node.cc index f79aea2b37559..1c352bf390833 100644 --- a/cinn/ir/node.cc +++ b/cinn/ir/node.cc @@ -49,5 +49,11 @@ Expr &Expr::operator=(const Expr &other) { return *this; } +Expr::operator Var() { + auto *x = As(); + CHECK(x); + return ir::Var(x); +} + } // namespace ir } // namespace cinn diff --git a/cinn/ir/node.h b/cinn/ir/node.h index ee3e104e5e8b6..b3f7e9efab6f6 100644 --- a/cinn/ir/node.h +++ b/cinn/ir/node.h @@ -235,6 +235,8 @@ struct Expr : public IrNodeRef { double as_double() const; // @} + operator Var(); + const Type& type() const { return p_->type(); } }; diff --git a/cinn/lang/compute.cc b/cinn/lang/compute.cc index e8491b20d5886..22f65d7118cce 100644 --- a/cinn/lang/compute.cc +++ b/cinn/lang/compute.cc @@ -8,40 +8,42 @@ namespace cinn { namespace lang { -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { +ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { return Compute( dims, - [fn](const std::vector &axis) -> Expr { + [fn](const std::vector &axis) -> Expr { CHECK_EQ(axis.size(), 1); return fn(axis[0]); }, name); } -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { +ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { return Compute( dims, - [fn](const std::vector &axis) -> Expr { + [fn](const std::vector &axis) -> Expr { CHECK_EQ(axis.size(), 2); return fn(axis[0], axis[1]); }, name); } -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { +ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { return Compute( dims, - [fn](const std::vector &axis) -> Expr { + [fn](const std::vector &axis) -> Expr { CHECK_EQ(axis.size(), 3); return fn(axis[0], axis[1], axis[2]); }, name); } -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name) { +ir::Tensor Compute(const std::vector &dims, + std::function fn, + const std::string &name) { return Compute( dims, - [fn](const std::vector &axis) -> Expr { + [fn](const std::vector &axis) -> Expr { CHECK_EQ(axis.size(), 4); return fn(axis[0], axis[1], axis[2], axis[3]); }, @@ -49,11 +51,11 @@ ir::Tensor Compute(const std::vector &dims, std::function &dims, - std::function fn, + std::function fn, const std::string &name) { return Compute( dims, - [fn](const std::vector &axis) -> Expr { + [fn](const std::vector &axis) -> Expr { CHECK_EQ(axis.size(), 5); return fn(axis[0], axis[1], axis[2], axis[3], axis[4]); }, @@ -61,10 +63,12 @@ ir::Tensor Compute(const std::vector &dims, } ir::Tensor Compute(const std::vector &dims, - std::function &)> fn, + std::function &)> fn, const std::string &name) { auto axis = common::GenDefaultAxis(dims.size()); - Expr expr = fn(axis); + std::vector _axis; + for (auto &x : axis) _axis.push_back(x); + Expr expr = fn(_axis); std::vector shape; for (int v : dims) shape.emplace_back(v); diff --git a/cinn/lang/compute.h b/cinn/lang/compute.h index 1e6cbafd85f96..f137872bf8924 100644 --- a/cinn/lang/compute.h +++ b/cinn/lang/compute.h @@ -14,17 +14,19 @@ namespace lang { //! Compute methods for one to five Vars as arguments. // @{ -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name = ""); -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name = ""); -ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name = ""); +ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name = ""); +ir::Tensor Compute(const std::vector &dims, std::function fn, const std::string &name = ""); ir::Tensor Compute(const std::vector &dims, - std::function fn, + std::function fn, const std::string &name = ""); ir::Tensor Compute(const std::vector &dims, - std::function fn, + std::function fn, const std::string &name = ""); ir::Tensor Compute(const std::vector &dims, - std::function &)> fn, + std::function fn, + const std::string &name = ""); +ir::Tensor Compute(const std::vector &dims, + std::function &)> fn, const std::string &name = ""); // @} diff --git a/cinn/runtime/cinn_runtime.h b/cinn/runtime/cinn_runtime.h index 969df083b27b5..d294652f9650c 100644 --- a/cinn/runtime/cinn_runtime.h +++ b/cinn/runtime/cinn_runtime.h @@ -13,8 +13,8 @@ extern "C" { typedef enum cinn_type_code_t { cinn_type_int = 0, //! signed int cinn_type_uint = 1, //! unsigned int - cinn_type_float = 1, //! floating point - cinn_type_handle = 1 //! void* + cinn_type_float = 2, //! floating point + cinn_type_handle = 3 //! void* } cinn_type_code_t; #ifndef CINN_ATTRIBUTE_ALIGN @@ -60,11 +60,68 @@ typedef enum cinn_buffer_kind_t { cinn_buffer_on_device = 1 << 1 // ! buffer on device e.g. GPU. } cinn_buffer_kind_t; +struct cinn_buffer_t; + +/** + * All CINN backends implementation should provide an interface to be used. + */ +struct cinn_device_interface_impl_t; + +struct cinn_device_interface_t { + int (*malloc)(void* context, struct cinn_buffer_t* buf, const struct cinn_device_interface_t* device_interface); + int (*free)(void* context, struct cinn_device_interface_t* buf); + int (*sync)(void* context, struct cinn_buffer_t* buf); + void (*release)(void* context, const struct cinn_device_interface_t* device_interface); + int (*copy_to_host)(void* context, struct cinn_buffer_t* buf); + int (*copy_to_device)(void* context, + struct cinn_buffer_t* buf, + const struct cinn_device_interface_t* device_interface); + int (*buffer_copy)(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst, + const struct cinn_device_interface_t* dest_device_interface); +}; + +/** + * Release all data associated with the given interface. + */ +extern void cinn_device_release(void* context, const struct cinn_device_interface_t* device_interface); + +/* + * Copy image data from device to host memory. + */ +extern int cinn_copy_to_host(void* context, struct cinn_buffer_t* buf); + +//! Copy data from host to device memory. +extern int cinn_copy_to_device(void* context, + struct cinn_buffer_t* buf, + const struct cinn_device_interface_t* device_interface); + +//! Copy data from one buffer to another. +extern int cinn_buffer_copy(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst, + const struct cinn_device_interface_t* dest_device_interface); + +//! Wait for current device operations to complete. +extern int cinn_device_sync(void* context, struct cinn_buffer_t* buf); + +//! Allocate device memory. +extern int cinn_device_malloc(void* context, + struct cinn_buffer_t* buf, + const struct cinn_device_interface_t* device_interface); + +//! Free device memory. +extern int cinn_device_free(void* context, struct cinn_buffer_t* buf); + //! The raw representation of a buffer,used in the generated code/lib. typedef struct cinn_buffer_t { //! A device handle. uint64_t device; + //! The interface used to operate on device. + const struct cinn_device_interface_t* device_interface; + //! A pointer to the memory in host. uint8_t* host_memory; @@ -75,13 +132,13 @@ typedef struct cinn_buffer_t { cinn_type_t type; //! Number of dimensions. - int32_t ndims; - cinn_buffer_t* dims; + int32_t dimensions; + cinn_dimension_t* dims; #ifdef __cplusplus - int num_elements() const { + CINN_ALWAYS_INLINE int num_elements() const { int res = 1; - for (int i = 0; i < ndims; i++) { + for (int i = 0; i < dimensions; i++) { res *= dims[i]; } return res; @@ -89,24 +146,26 @@ typedef struct cinn_buffer_t { CINN_ALWAYS_INLINE bool on_host() const { return get_flag(cinn_buffer_on_host); } CINN_ALWAYS_INLINE bool on_device() const { return get_flag(cinn_buffer_on_device); } - CINN_ALWAYS_INLINE void set_on_host(bool x = true) { - if (x) { - set_flag(cinn_buffer_on_host); - } else { - flag &= ~cinn_buffer_on_host; - } - } - CINN_ALWAYS_INLINE void set_on_device(bool x = true) { - if (x) { - set_flag(cinn_buffer_on_device); - } else { - flag &= ~cinn_buffer_on_device; + CINN_ALWAYS_INLINE void set_on_host(bool x = true) { set_flag(cinn_buffer_on_host, x); } + CINN_ALWAYS_INLINE void set_on_device(bool x = true) { set_flag(cinn_buffer_on_device, x); } + + CINN_ALWAYS_INLINE int device_sync(void* ctx = NULL) { + if (device_interface && device_interface->sync) { + return device_interface->sync(ctx, this); } + return 0; } - CINN_ALWAYS_INLINE uint8_t* begin() const {} + + CINN_ALWAYS_INLINE uint8_t* begin() const { return 0; } + CINN_ALWAYS_INLINE uint8_t* end() const { return host_memory + num_elements() * type.bytes(); } CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { return (this->flag & flag) != 0; } - CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag) { this->flag |= flag; } + CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag, bool value) { + if (value) + this->flag |= flag; + else + this->flag &= ~flag; + } #endif // __cplusplus