Skip to content

Commit

Permalink
fix attr dtype, headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu committed Nov 28, 2023
1 parent 6508708 commit 7ac10fc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 39 deletions.
39 changes: 5 additions & 34 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
Original file line number Diff line number Diff line change
@@ -1,36 +1,14 @@
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include "torch_xla/csrc/runtime/stablehlo_composite_helper.h"

#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Location.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
#include "mlir/include/mlir/IR/TypeUtilities.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Pass/Pass.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Support/TypeID.h"
#include "single_include/nlohmann/json.hpp"
#include "single_include/nlohmann/json_fwd.hpp"
#include "stablehlo/dialect/StablehloOps.h"

namespace torch_xla {
Expand All @@ -40,7 +18,7 @@ namespace {

using nlohmann::json;

bool IsXlaMarkTensorOp(mlir::Operation* op) {
static bool IsXlaMarkTensorOp(mlir::Operation* op) {
if (op == nullptr) {
return false;
}
Expand Down Expand Up @@ -175,12 +153,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
case json::value_t::number_float:
named_attrs.push_back(
{builder.getStringAttr(key),
builder.getI64IntegerAttr(j.template get<float>())});
builder.getF32FloatAttr(j.template get<float>())});
break;
case json::value_t::boolean:
named_attrs.push_back(
{builder.getStringAttr(key),
builder.getI64IntegerAttr(j.template get<bool>())});
builder.getBoolAttr(j.template get<bool>())});
break;
case json::value_t::string:
named_attrs.push_back(
Expand Down Expand Up @@ -253,7 +231,6 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
auto scope_ops = scope_ops_setvec.takeVector();
for (auto& op : scope_ops) {
if (!op_line_num.contains(op)) {
LOG(ERROR) << "!!!! Op line number not found";
return;
}
}
Expand All @@ -277,12 +254,6 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
args.push_back(arg);
}

LOG(ERROR) << "-- ARGS:: " << args.size();
LOG(ERROR) << "-- SCOPE_OPS:: " << scope_ops.size();
for (auto scope_op : scope_ops) {
LOG(ERROR) << "---- " << std::string(scope_op->getName().getStringRef());
}

// Creates composite impl function and duplicates all ops within the
// boundary in the function.
llvm::SmallVector<mlir::Location> arg_locs;
Expand Down
9 changes: 4 additions & 5 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#ifndef STABLEHLO_COMPOSITE_HELPER_H_
#define STABLEHLO_COMPOSITE_HELPER_H_
#include <utility>

#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Pass/Pass.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace torch_xla {
namespace runtime {
Expand All @@ -18,4 +17,4 @@ CreateRemoveXlaMarkTensorOpsPass();
} // namespace runtime
} // namespace torch_xla

#endif
#endif

0 comments on commit 7ac10fc

Please sign in to comment.