diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc index eebcea6aeaa84..82bae21b759d3 100644 --- a/paddle/cinn/backends/compiler.cc +++ b/paddle/cinn/backends/compiler.cc @@ -37,6 +37,7 @@ PD_DECLARE_string(cinn_dump_group_lowered_func); PD_DECLARE_string(cinn_dump_group_source_code); PD_DECLARE_string(cinn_dump_group_ptx); PD_DECLARE_string(cinn_dump_group_instruction); +PD_DECLARE_string(cinn_custom_code_path); namespace cinn { namespace backends { @@ -267,6 +268,22 @@ void Compiler::BuildDefault(const Module& module) { }); } +std::string getFileContent(const std::string& filePath) { + std::ifstream file(filePath); + + if (!file.is_open()) { + std::cerr << "Unable to open file: " << filePath << std::endl; + return ""; + } + + std::ostringstream contentStream; + contentStream << file.rdbuf(); + std::string content = contentStream.str(); + + file.close(); + return content; +} + void Compiler::CompileCudaModule(const Module& module, const std::string& code) { #ifdef CINN_WITH_CUDA @@ -278,12 +295,17 @@ void Compiler::CompileCudaModule(const Module& module, VLOG(3) << "[CUDA] device module:\n" << device_module; std::string source_code; - if (code.empty()) { + + if (!FLAGS_cinn_custom_code_path.empty()) { + std::string filePath = FLAGS_cinn_custom_code_path; + source_code = getFileContent(filePath); + } else if (code.empty()) { CodeGenCUDA_Dev codegen(target_); source_code = codegen.Compile(device_module); } else { source_code = code; } + CHECK(!source_code.empty()) << "Compile CUDA C code failed from device module:\n" << device_module; diff --git a/paddle/cinn/ir/ir_printer.cc b/paddle/cinn/ir/ir_printer.cc index 0959cc265f48e..c6c6f4110de7f 100644 --- a/paddle/cinn/ir/ir_printer.cc +++ b/paddle/cinn/ir/ir_printer.cc @@ -94,10 +94,25 @@ void IrPrinter::Visit(const UIntImm *x) { namespace { template <typename T> bool isCloseEqualMaxValue(T value) { - T max_value = std::numeric_limits<T>::max(); + T maxValue = std::numeric_limits<T>::max(); + T minValue = 
std::numeric_limits<T>::lowest(); T tol = std::numeric_limits<T>::denorm_min(); - return (max_value - value) < tol; + return (maxValue - value) < tol || (value - minValue) < tol; } + +template <typename T> +T truncateInfinity(T value) { + T maxValue = std::numeric_limits<T>::max(); + T minValue = std::numeric_limits<T>::lowest(); + if (value > maxValue) { + return maxValue; + } + if (value < minValue) { + return minValue; + } + return value; +} + } // namespace void IrPrinter::Visit(const FloatImm *x) { @@ -123,11 +138,12 @@ void IrPrinter::Visit(const FloatImm *x) { ss << static_cast<float>(x->value) << "f"; } } else if (x->type().is_float(32)) { - if (isCloseEqualMaxValue(x->value)) std::fesetround(FE_TOWARDZERO); + float v = truncateInfinity<float>(x->value); + if (isCloseEqualMaxValue(v)) std::fesetround(FE_TOWARDZERO); ss << std::setprecision(std::numeric_limits<float>::max_digits10); ss << std::showpoint; - ss << x->value; - if (std::isfinite(x->value)) { + ss << v; + if (std::isfinite(v)) { ss << "f"; } } else if (x->type().is_float(64)) { diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index b8846da4acc26..8811b9b482a12 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -207,6 +207,10 @@ PD_DEFINE_string( StringFromEnv("FLAGS_cinn_dump_group_instruction", ""), "Specify the path for dump instruction by group, which is used for debug."); +PD_DEFINE_string(cinn_custom_code_path, + StringFromEnv("FLAGS_cinn_custom_code_path", ""), + "Specify custom code path for cinn."); + PD_DEFINE_string(cinn_pass_visualize_dir, StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""), "Specify the directory path of pass visualize file of graph, " diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 2c7c1e80b3c6f..ceaed2ebda5ab 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -1039,20 +1039,27 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x, } template <typename T>
-Tensor clip_decomp(const Tensor& x, const Tensor& min_, const Tensor& max_) { - Tensor min_t = min_; - Tensor max_t = max_; - if (min_.dtype() == x.dtype()) { - min_t = cast<T>(min_, x.dtype()); +Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) { + auto min_reshape = min; + auto max_reshape = max; + + if (has_dynamic_shape(x.shape())) { + min_reshape = backend::expand_with_tensor<T>(min, shape<T>(x)); + max_reshape = backend::expand_with_tensor<T>(max, shape<T>(x)); + } else { + min_reshape = expand<T>(min, x.shape()); + max_reshape = expand<T>(max, x.shape()); } - if (max_.dtype() == x.dtype()) { - max_t = cast<T>(max_, x.dtype()); + if (min_reshape.dtype() != x.dtype()) { + min_reshape = cast<T>(min_reshape, x.dtype()); } - if (x.size() == 0) { - min_t = reshape<T>(min_t, empty_shape); - max_t = reshape<T>(max_t, empty_shape); + + if (max_reshape.dtype() != x.dtype()) { + max_reshape = cast<T>(max_reshape, x.dtype()); } - return maximum<T>(minimum<T>(x, max_t), min_t); + + auto ans = maximum<T>(minimum<T>(x, max_reshape), min_reshape); + return ans; } template <typename T>