Skip to content

Commit

Permalink
Fix codegen bugs: truncate infinite float immediates to the type's finite range in IrPrinter, correct clip_decomp's dtype-cast condition and broadcast min/max to x's shape, and add FLAGS_cinn_custom_code_path to load custom CUDA source for debugging.
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed May 24, 2024
1 parent a84c31e commit a066d35
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 17 deletions.
24 changes: 23 additions & 1 deletion paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ PD_DECLARE_string(cinn_dump_group_lowered_func);
PD_DECLARE_string(cinn_dump_group_source_code);
PD_DECLARE_string(cinn_dump_group_ptx);
PD_DECLARE_string(cinn_dump_group_instruction);
PD_DECLARE_string(cinn_custom_code_path);

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -267,6 +268,22 @@ void Compiler::BuildDefault(const Module& module) {
});
}

// Reads the entire contents of the file at `filePath` into a string.
// Returns an empty string (after printing to stderr) when the file cannot
// be opened. NOTE: an empty return is also produced by a genuinely empty
// file, so callers must treat "" as "no usable content" (the caller in
// CompileCudaModule CHECKs for non-empty source).
std::string getFileContent(const std::string& filePath) {
  std::ifstream file(filePath);
  if (!file.is_open()) {
    std::cerr << "Unable to open file: " << filePath << std::endl;
    return "";
  }

  // Slurp the whole stream in one shot. No explicit close() is needed:
  // std::ifstream releases the handle via RAII when it goes out of scope.
  std::ostringstream contentStream;
  contentStream << file.rdbuf();
  return contentStream.str();
}

void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUDA
Expand All @@ -278,12 +295,17 @@ void Compiler::CompileCudaModule(const Module& module,

VLOG(3) << "[CUDA] device module:\n" << device_module;
std::string source_code;
if (code.empty()) {

if (!FLAGS_cinn_custom_code_path.empty()) {
std::string filePath = FLAGS_cinn_custom_code_path;
source_code = getFileContent(filePath);
} else if (code.empty()) {
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module);
} else {
source_code = code;
}

CHECK(!source_code.empty())
<< "Compile CUDA C code failed from device module:\n"
<< device_module;
Expand Down
26 changes: 21 additions & 5 deletions paddle/cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,25 @@ void IrPrinter::Visit(const UIntImm *x) {
namespace {
// Returns true when `value` lies within one denorm-min of either extreme
// of T's representable range (max or lowest), i.e. the value is at (or
// essentially at) the type's limit. Used by the float printer to decide
// when rounding mode must be forced toward zero so the emitted literal
// does not overflow to infinity.
template <typename T>
bool isCloseEqualMaxValue(T value) {
  T maxValue = std::numeric_limits<T>::max();
  T minValue = std::numeric_limits<T>::lowest();
  T tol = std::numeric_limits<T>::denorm_min();
  return (maxValue - value) < tol || (value - minValue) < tol;
}

// Clamps `value` into T's finite range: anything above T's max (e.g.
// +infinity) becomes max, anything below T's lowest (e.g. -infinity)
// becomes lowest. All other inputs — including NaN, for which both
// comparisons are false — are returned unchanged.
template <typename T>
T truncateInfinity(T value) {
  const T upper = std::numeric_limits<T>::max();
  const T lower = std::numeric_limits<T>::lowest();
  if (value > upper) {
    return upper;
  }
  return (value < lower) ? lower : value;
}

} // namespace

void IrPrinter::Visit(const FloatImm *x) {
Expand All @@ -123,11 +138,12 @@ void IrPrinter::Visit(const FloatImm *x) {
ss << static_cast<bfloat16>(x->value) << "f";
}
} else if (x->type().is_float(32)) {
if (isCloseEqualMaxValue<float>(x->value)) std::fesetround(FE_TOWARDZERO);
float v = truncateInfinity<float>(x->value);
if (isCloseEqualMaxValue<float>(v)) std::fesetround(FE_TOWARDZERO);
ss << std::setprecision(std::numeric_limits<float>::max_digits10);
ss << std::showpoint;
ss << x->value;
if (std::isfinite(x->value)) {
ss << v;
if (std::isfinite(v)) {
ss << "f";
}
} else if (x->type().is_float(64)) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ PD_DEFINE_string(
StringFromEnv("FLAGS_cinn_dump_group_instruction", ""),
"Specify the path for dump instruction by group, which is used for debug.");

// Optional override: when non-empty, Compiler::CompileCudaModule reads the
// CUDA source code from this file instead of using generated code — useful
// for debugging/patching codegen output by hand.
PD_DEFINE_string(cinn_custom_code_path,
StringFromEnv("FLAGS_cinn_custom_code_path", ""),
"Specify custom code path for cinn.");

PD_DEFINE_string(cinn_pass_visualize_dir,
StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""),
"Specify the directory path of pass visualize file of graph, "
Expand Down
29 changes: 18 additions & 11 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,20 +1039,27 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x,
}

// Decomposes clip(x, min, max) into primitive ops:
//   clip(x) = maximum(minimum(x, max), min)
// `min` / `max` arrive as (typically 0-D/1-element) bound tensors; they are
// first broadcast to x's shape and cast to x's dtype so the elementwise
// minimum/maximum see matching operands.
template <typename T>
Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) {
  auto min_reshape = min;
  auto max_reshape = max;

  // Broadcast the bounds to x's shape. With a dynamic (symbolic) shape the
  // expansion must use the runtime shape tensor; otherwise the static shape
  // vector suffices.
  if (has_dynamic_shape(x.shape())) {
    min_reshape = backend::expand_with_tensor<T>(min, shape<T>(x));
    max_reshape = backend::expand_with_tensor<T>(max, shape<T>(x));
  } else {
    min_reshape = expand<T>(min, x.shape());
    max_reshape = expand<T>(max, x.shape());
  }

  // Align dtypes with x only when they actually differ.
  if (min_reshape.dtype() != x.dtype()) {
    min_reshape = cast<T>(min_reshape, x.dtype());
  }
  if (max_reshape.dtype() != x.dtype()) {
    max_reshape = cast<T>(max_reshape, x.dtype());
  }

  return maximum<T>(minimum<T>(x, max_reshape), min_reshape);
}

template <typename T>
Expand Down

0 comments on commit a066d35

Please sign in to comment.