diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 2ebb963bd13..8d57b6149ea 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( const std::string& name, BackendDevice device) : LoweringContext(name, std::forward(device)), graph_(std::make_shared()), + function_( + std::make_shared(name, graph_, nullptr)), mlir_context_(mlirContextCreate()) { - lowering_ = TorchMlirNodeLoweringInterface::Create(this); RegisterMlirDialects(); } @@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( std::forward>(post_order), std::forward(emit_status)), graph_(std::make_shared()), + function_( + std::make_shared(name, graph_, nullptr)), mlir_context_(mlirContextCreate()) { - lowering_ = TorchMlirNodeLoweringInterface::Create(this); for (auto node : post_order) { - bool ok = lowering_->Lower(node); - CHECK(ok) << "Failed to lower: " << *node; + Lower(node); } RegisterMlirDialects(); } +void TorchMlirLoweringContext::Lower(const Node* node) { + if (auto* torch_mlir_node = + dynamic_cast(node)) { + TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); + CHECK(!ops.empty()) << "Failed to lower: " << *node; + CHECK_EQ(node->num_outputs(), ops.size()); + for (size_t i = 0; i < ops.size(); ++i) { + AssignOutputOp(torch::lazy::Output(node, i), ops[i]); + } + } else { + throw std::runtime_error( + "Expected torch::lazy::TorchMlirNode but could not dynamic cast"); + } +} + void TorchMlirLoweringContext::SetUpAlias( const std::vector& output_index, int64_t param_number, const std::vector& param_index, bool must_alias) { @@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { if (it == emitted_outputs_.end()) { auto post_order = Util::ComputePostOrder(output.node, &emit_status_); for (auto node : post_order) { - bool ok = lowering_->Lower(node); - TORCH_CHECK(ok, "Failed to lower: ", node->ToString()); + Lower(node); } // At this point the output better be present, otherwise there is an issue // with the lowering code. diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h index 4a025b5bb9e..7c4c36a91fc 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -23,22 +23,6 @@ namespace torch { namespace lazy { -class TORCH_API TorchMlirNodeLoweringInterface { - /** - * This interface is only needed for legacy ops, and can be removed once all - * ops implement LtcMlirNode->lower(). - * */ -public: - TorchMlirNodeLoweringInterface() = default; - - virtual ~TorchMlirNodeLoweringInterface() = default; - - virtual bool Lower(const Node* node) = 0; - - static std::unique_ptr - Create(LoweringContext* loctx); -}; - class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { public: // Describes an input/output alias as inserted by the SetUpAlias() API. @@ -61,6 +45,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { c10::ArrayRef post_order, torch::lazy::Util::EmissionMap emit_status); + void Lower(const Node* node); + // Adds a new input/output alias. void SetUpAlias( const std::vector& output_index, int64_t param_number, @@ -120,11 +106,11 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { // Holds the input/output alias information populated by the SetUpAlias() API. InputOutputAliases input_output_aliases_; std::shared_ptr graph_; + std::shared_ptr function_; MlirContext mlir_context_; std::unordered_map parameters_map_; std::vector root_tuple_; OutputMap emitted_outputs_; - std::unique_ptr lowering_; }; class TORCH_API TorchMlirComputation : public torch::lazy::Computation { diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp index 8e4ad40d261..2a56dd0fdd3 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp @@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } -TorchMlirOpVector TorchMlirNode::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - return {}; -} - - OpKind TorchMlirTensorList::ClassOpKind() { // Note: this OpKind is separate from ltc_ops.h since it would be a circular // import otherwise diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp index 174e6808b07..5012a8d69fc 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -32,7 +32,7 @@ namespace torch { namespace lazy { TorchMlirOpVector LowerTorchMlirBuiltin( - std::shared_ptr function, c10::Symbol sym, + TorchMlirFunction function, c10::Symbol sym, const std::vector tensor_types, const std::vector& arguments, const std::vector& kwarguments) { @@ -81,7 +81,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin( } TorchMlirOpVector LowerTorchMlirBuiltin( - std::shared_ptr function, c10::Symbol sym, + TorchMlirFunction function, c10::Symbol sym, const c10::ArrayRef result_shapes, const std::vector& arguments, const std::vector& kwarguments) { @@ -101,6 +101,29 @@ TorchMlirOpVector LowerTorchMlirBuiltin( function, sym, tensor_types, arguments, kwarguments); } +TorchMlirOpVector LowerBuiltin( + const torch::lazy::Node* node, TorchMlirFunction function, + const std::vector& arguments, + const std::vector& kwarguments = {}) { + return LowerTorchMlirBuiltin( + function, node->op().op, node->shapes(), arguments, kwarguments); +} +TorchMlirOpVector LowerBuiltin( + c10::Symbol sym, const c10::ArrayRef result_shapes, + TorchMlirFunction function, + const std::vector& arguments, + const std::vector& kwarguments = {}) { + return LowerTorchMlirBuiltin( + function, sym, result_shapes, arguments, kwarguments); +} +TorchMlirOpVector LowerBuiltin( + c10::Symbol sym, const std::vector types, + TorchMlirFunction function, + const std::vector& arguments, + const std::vector& kwarguments = {}) { + return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments); +} + c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { auto tensor_type = value_type->cast(); TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!"); @@ -174,358 +197,284 @@ std::vector compute_shape_slice( return {Shape(scalar_type.value(), dims)}; } -class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface { -public: - TorchMlirNodeLowering( - const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx) - : loctx_(loctx), function_( - loctx ? std::make_shared( - name, loctx->graph(), nullptr) - : nullptr) {} - - torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; } - - bool Lower(const torch::lazy::Node* node) override { - if (auto* torch_mlir_node = - dynamic_cast(node)) { - // First, we call the node lowering function, which exists for newly - // codegenned or refactored nodes - TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx()); - if (ops.empty()) { - // Then fall back to legacy lowering code, which should be gradually - // removed - ops = LowerNonCodegenOps(node); - } - if (ops.empty()) { - return false; - } - CHECK_EQ(node->num_outputs(), ops.size()); - for (size_t i = 0; i < ops.size(); ++i) { - loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]); - } - return true; - } else { - TorchMlirOpVector ops = LowerNonCodegenOps(node); - if (!ops.empty()) { - CHECK_EQ(node->num_outputs(), ops.size()); - for (size_t i = 0; i < ops.size(); ++i) { - loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]); - } - return true; - } - } - throw std::runtime_error( - "Expected torch::lazy::TorchMlirNode but could not dynamic cast"); - } +torch::jit::Value* +GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { + std::vector clone_arguments; + clone_arguments.emplace_back(val); - // TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops - // to codegen we should delete this and put all the lowering logic into Node - // classes - TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) { + // Type of cloned value should be identical to the original one. + TorchMlirOpVector cloned = + LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments); + CHECK_EQ(cloned.size(), 1); + return cloned.front(); +} - if (node->op().op == at::aten::as_strided) { - return LowerAsStrided(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::as_strided))); - } - if (node->op() == *torch::lazy::ltc_as_strided_view_update) { - return LowerAsStridedViewUpdate( - torch::lazy::NodeCast( - node, *torch::lazy::ltc_as_strided_view_update)); - } - if (node->op() == *torch::lazy::ltc_cast) { - return LowerCast(torch::lazy::NodeCast( - node, *torch::lazy::ltc_cast)); - } - if (node->op() == *torch::lazy::ltc_select_view_update) { - return LowerSelectViewUpdate( - torch::lazy::NodeCast( - node, *torch::lazy::ltc_select_view_update)); - } - if (node->op() == *torch::lazy::ltc_narrow_view_update) { - return LowerNarrowViewUpdate( - torch::lazy::NodeCast( - node, *torch::lazy::ltc_narrow_view_update)); - } - if (node->op().op == at::prim::Constant) { - return LowerScalar(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::prim::Constant))); - } - if (node->op().op == at::aten::bernoulli) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - return LowerBuiltin(node, arguments); - } - if (node->op().op == at::aten::expand) { - return LowerExpand(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::expand))); - } - if (node->op().op == at::aten::narrow) { - return LowerNarrow(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::narrow))); - } - if (node->op().op == at::aten::permute) { - return LowerPermute(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::permute))); - } - if (node->op().op == at::aten::select) { - return LowerSelect(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::select))); - } - if (node->op().op == at::aten::squeeze) { - return LowerSqueeze(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::squeeze))); - } - if (node->op().op == at::aten::unsqueeze) { - return LowerUnsqueeze(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::unsqueeze))); - } - if (node->op().op == at::aten::view) { - return LowerView(torch::lazy::NodeCast( - node, torch::lazy::OpKind(at::aten::view))); - } - if (node->op() == *torch::lazy::ltc_device_data) { - const torch::lazy::DeviceData* device_data_node = - torch::lazy::NodeCast( - node, *torch::lazy::ltc_device_data); - auto infoptr = device_data_node->data()->info(); - auto deviceDataInfoPtr = - (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; - if (GRAPH_DUMP_ENABLED) { - LOG(ERROR) << "Lowering device data node, tensor id " - << deviceDataInfoPtr->tensor_id << std::endl; - } - return {loctx()->GetParameter(device_data_node->data())}; - } +void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source, TorchMlirFunction function) { std::vector arguments; - for (const torch::lazy::Output& output : node->operands()) { - arguments.emplace_back(loctx()->GetOutputOp(output)); - } - return LowerBuiltin(node, arguments); - } + arguments.emplace_back(destination); + arguments.emplace_back(source); + LowerBuiltin( + at::aten::copy_, + c10::ArrayRef(compute_shape_copy(source->type())), function, arguments); +} - TorchMlirOpVector LowerBuiltin( - const torch::lazy::Node* node, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function_, node->op().op, node->shapes(), arguments, kwarguments); - } - TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function_, sym, result_shapes, arguments, kwarguments); - } - TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const std::vector types, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments); - } - TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->size); - arguments.emplace_back(node->stride); - arguments.emplace_back(node->storage_offset); - TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments); - CHECK_EQ(as_strided_out.size(), 1); - return {GenerateClone(as_strided_out.front())}; - } +torch::jit::Value* GenerateSlice( + torch::jit::Value* base, int64_t dim, int64_t start, int64_t end, + int64_t step, TorchMlirFunction function) { + std::vector arguments; + arguments.emplace_back(base); + arguments.emplace_back(dim); + arguments.emplace_back(start); + arguments.emplace_back(end); + arguments.emplace_back(step); + + TorchMlirOpVector selected = LowerBuiltin( + at::aten::slice, + c10::ArrayRef( + compute_shape_slice(base->type(), dim, start, end, step)), + function, + arguments); + CHECK_EQ(selected.size(), 1); + return selected.front(); +} - TorchMlirOpVector - LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) { - torch::jit::Value* destination = - GenerateClone(loctx()->GetOutputOp(node->operand(0))); - const torch::lazy::Output& input_op = node->operand(1); - const torch::lazy::Shape& input_shape = input_op.shape(); - const auto input_dimensions = input_shape.sizes(); - std::vector dest_arguments; - dest_arguments.emplace_back(destination); - dest_arguments.emplace_back( - std::vector(input_dimensions.begin(), input_dimensions.end())); - dest_arguments.emplace_back(node->stride); - dest_arguments.emplace_back(node->storage_offset); - TorchMlirOpVector as_strided_out = - LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments); - CHECK_EQ(as_strided_out.size(), 1); - torch::jit::Value* as_strided = as_strided_out.front(); - GenerateCopy(as_strided, loctx()->GetOutputOp(input_op)); - return {destination}; - } +// Node Lowerings - TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->dtype); - return LowerBuiltin(at::aten::to, node->shapes(), arguments); +// Default Node Lowering +TorchMlirOpVector TorchMlirNode::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + for (const torch::lazy::Output& output : operands()) { + arguments.emplace_back(loctx->GetOutputOp(output)); } + return LowerBuiltin(this, function, arguments); +} - TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->size); - auto expand_out = LowerBuiltin(node, arguments); - if (node->is_scalar_expand) { - // The aten::expand operations sets all strides to 0 when the original - // of rank 0. This leads to false positives when checking for internal - // memory overlap, because at::has_internal_overlap returns - // MemOverlap::YES when a stride is set to 0. - CHECK_EQ(expand_out.size(), 1); - return {GenerateClone(expand_out.front())}; - } - return expand_out; - } +// TorchMlir specific nodes - TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) { - const torch::lazy::Output& input = node->operand(0); - torch::jit::Value* base = loctx()->GetOutputOp(input); - const auto& base_indices = node->base_indices; - const auto& sizes = node->sizes; - const torch::lazy::Shape& input_shape = input.shape(); - CHECK_EQ(sizes.size(), base_indices.size()); - CHECK_EQ(input_shape.dim(), base_indices.size()); - for (size_t dim = 0; dim < base_indices.size(); ++dim) { - int64_t start = base_indices[dim]; - base = GenerateSlice( - /*base=*/base, /*dim=*/dim, /*start=*/start, - /*end=*/start + sizes[dim], /*step=*/1); - } - return {base}; - } +// Non-native nodes - TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.push_back(node->dims); - return LowerBuiltin(node, arguments); - } +TorchMlirOpVector +Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(dtype); + return LowerBuiltin(at::aten::to, shapes(), function, arguments); +} - TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) { - const at::Scalar& value = node->value; - const torch::lazy::Shape& shape = node->shape(); - auto options = - at::TensorOptions() - .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) - .dtype(shape.scalar_type()); - return { - loctx()->graph()->insertConstant(at::scalar_tensor(value, options))}; +TorchMlirOpVector DeviceData::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + auto infoptr = data_->info(); + auto deviceDataInfoPtr = + (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; + if (GRAPH_DUMP_ENABLED) { + LOG(ERROR) << "Lowering device data node, tensor id " + << deviceDataInfoPtr->tensor_id << std::endl; } + return {loctx->GetParameter(data_)}; +} - TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) { - int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride); - torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0)); - return {GenerateSlice( - /*base=*/base, /*dim=*/node->dim, - /*start=*/node->start, /*end=*/node->end, - /*step=*/step)}; +TorchMlirOpVector Expand::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(size); + auto expand_out = LowerBuiltin(this, function, arguments); + if (is_scalar_expand) { + // The aten::expand operations sets all strides to 0 when the original is + // of rank 0. This leads to false positives when checking for internal + // memory overlap, because at::has_internal_overlap returns + // MemOverlap::YES when a stride is set to 0. + CHECK_EQ(expand_out.size(), 1); + return {GenerateClone(expand_out.front(), function)}; } + return expand_out; +} - TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - if (node->dim != -1) { - arguments.push_back(node->dim); - } - return LowerBuiltin(node, arguments); - } +TorchMlirOpVector Scalar::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + auto options = + at::TensorOptions() + .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) + .dtype(shape().scalar_type()); + return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; +} - TorchMlirOpVector - LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) { - torch::jit::Value* dest = - GenerateClone(loctx()->GetOutputOp(node->operand(0))); - int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride); - torch::jit::Value* selected = GenerateSlice( - /*base=*/dest, /*dim=*/node->dim, /*start=*/node->start, - /*end=*/node->end, /*step=*/step); - GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1))); - return {dest}; - } +// View Ops - TorchMlirOpVector - LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) { - torch::jit::Value* dest = - GenerateClone(loctx()->GetOutputOp(node->operand(0))); - const auto& base_indices = node->base_indices; - const torch::lazy::Output& source_argument = node->operand(1); - const torch::lazy::Shape& source_shape = source_argument.shape(); - CHECK_EQ(source_shape.dim(), base_indices.size()); - torch::jit::Value* base = dest; - for (size_t dim = 0; dim < base_indices.size(); ++dim) { - int64_t start = base_indices[dim]; - base = GenerateSlice( - /*base=*/base, /*dim=*/dim, /*start=*/start, - /*end=*/start + source_shape.size(dim), - /*step=*/1); - } - GenerateCopy(base, loctx()->GetOutputOp(source_argument)); - return {dest}; - } +TorchMlirOpVector AsStrided::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.push_back(node->dim); - return LowerBuiltin(node, arguments); + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(size); + arguments.emplace_back(stride); + arguments.emplace_back(storage_offset); + TorchMlirOpVector as_strided_out = LowerBuiltin(this, function, arguments); + CHECK_EQ(as_strided_out.size(), 1); + return {GenerateClone(as_strided_out.front(), function)}; +} + +TorchMlirOpVector AsStridedViewUpdate::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + + torch::jit::Value* destination = + GenerateClone(loctx->GetOutputOp(operand(0)), function); + const torch::lazy::Output& input_op = operand(1); + const torch::lazy::Shape& input_shape = input_op.shape(); + const auto input_dimensions = input_shape.sizes(); + std::vector dest_arguments; + dest_arguments.emplace_back(destination); + dest_arguments.emplace_back( + std::vector(input_dimensions.begin(), input_dimensions.end())); + dest_arguments.emplace_back(stride); + dest_arguments.emplace_back(storage_offset); + TorchMlirOpVector as_strided_out = + LowerBuiltin(at::aten::as_strided, shapes(), function, dest_arguments); + CHECK_EQ(as_strided_out.size(), 1); + torch::jit::Value* as_strided = as_strided_out.front(); + GenerateCopy(as_strided, loctx->GetOutputOp(input_op), function); + return {destination}; +} + +TorchMlirOpVector Diagonal::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(offset); + arguments.emplace_back(dim1); + arguments.emplace_back(dim2); + return LowerBuiltin(this, function, arguments); +} + +TorchMlirOpVector DiagonalViewUpdate::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + // Since we promise the backends that we never generate any aliased + // inplace update IR, therefore we clone the target first and then + // update the clone inplace instead. Since the clone is transient, + // it will never be aliased, and therefore it's safe. + torch::jit::Value* destination = + GenerateClone(loctx->GetOutputOp(operand(0)), function); + + // Replay the diagonal. + std::vector arguments; + arguments.emplace_back(destination); + arguments.emplace_back(offset); + arguments.emplace_back(dim1); + arguments.emplace_back(dim2); + auto diag = LowerBuiltin(at::aten::diagonal, shapes(), function, arguments); + + // Update the replayed diagonal view with the input. + GenerateCopy(diag.front(), loctx->GetOutputOp(operand(1)), function); + + // Destination's diag view should be updated. + return {destination}; +} + +TorchMlirOpVector Narrow::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + const torch::lazy::Output& input = operand(0); + torch::jit::Value* base = loctx->GetOutputOp(input); + const torch::lazy::Shape& input_shape = input.shape(); + CHECK_EQ(sizes.size(), base_indices.size()); + CHECK_EQ(input_shape.dim(), base_indices.size()); + for (size_t dim = 0; dim < base_indices.size(); ++dim) { + int64_t start = base_indices[dim]; + base = GenerateSlice( + /*base=*/base, /*dim=*/dim, /*start=*/start, + /*end=*/start + sizes[dim], /*step=*/1, + /*function=*/function); } + return {base}; +} - TorchMlirOpVector LowerView(const torch::lazy::View* node) { - std::vector arguments; - arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.push_back(node->output_size); - return LowerBuiltin(at::aten::reshape, node->shapes(), arguments); +TorchMlirOpVector NarrowViewUpdate::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + torch::jit::Value* dest = + GenerateClone(loctx->GetOutputOp(operand(0)), function); + const torch::lazy::Output& source_argument = operand(1); + const torch::lazy::Shape& source_shape = source_argument.shape(); + CHECK_EQ(source_shape.dim(), base_indices.size()); + torch::jit::Value* base = dest; + for (size_t dim = 0; dim < base_indices.size(); ++dim) { + int64_t start = base_indices[dim]; + base = GenerateSlice( + /*base=*/base, /*dim=*/dim, /*start=*/start, + /*end=*/start + source_shape.size(dim), /*step=*/1, + /*function=*/function); } + GenerateCopy(base, loctx->GetOutputOp(source_argument), function); + return {dest}; +} - torch::jit::Value* GenerateClone(torch::jit::Value* val) { - std::vector clone_arguments; - clone_arguments.emplace_back(val); +TorchMlirOpVector Permute::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(dims); + return LowerBuiltin(this, function, arguments); +} - // Type of cloned value should be identical to the original one. - TorchMlirOpVector cloned = - LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments); - CHECK_EQ(cloned.size(), 1); - return cloned.front(); - } +TorchMlirOpVector Resize::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) { - std::vector arguments; - arguments.emplace_back(destination); - arguments.emplace_back(source); - LowerBuiltin( - at::aten::copy_, - c10::ArrayRef(compute_shape_copy(source->type())), arguments); + std::vector arguments; + for (const torch::lazy::Output& output : operands()) { + arguments.emplace_back(loctx->GetOutputOp(output)); } + return LowerBuiltin(this, function, arguments); +} - torch::jit::Value* GenerateSlice( - torch::jit::Value* base, int64_t dim, int64_t start, int64_t end, - int64_t step) { - std::vector arguments; - arguments.emplace_back(base); +TorchMlirOpVector Select::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + int64_t step = torch::lazy::GetStride(start, end, stride); + torch::jit::Value* base = loctx->GetOutputOp(operand(0)); + return {GenerateSlice( + /*base=*/base, /*dim=*/dim, + /*start=*/start, /*end=*/end, + /*step=*/step, /*function=*/function)}; +} + +TorchMlirOpVector SelectViewUpdate::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + torch::jit::Value* dest = + GenerateClone(loctx->GetOutputOp(operand(0)), function); + int64_t step = torch::lazy::GetStride(start, end, stride); + torch::jit::Value* selected = GenerateSlice( + /*base=*/dest, /*dim=*/dim, /*start=*/start, + /*end=*/end, /*step=*/step, /*function=*/function); + GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function); + return {dest}; +} + +TorchMlirOpVector Squeeze::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + if (dim != -1) { arguments.emplace_back(dim); - arguments.emplace_back(start); - arguments.emplace_back(end); - arguments.emplace_back(step); - - TorchMlirOpVector selected = LowerBuiltin( - at::aten::slice, - c10::ArrayRef( - compute_shape_slice(base->type(), dim, start, end, step)), - arguments); - CHECK_EQ(selected.size(), 1); - return selected.front(); } - torch::lazy::TorchMlirLoweringContext* loctx_; - std::shared_ptr function_; -}; - -std::unique_ptr -TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) { - return std::make_unique( - "TorchMlirNodeLowering", - static_cast(loctx)); + return LowerBuiltin(this, function, arguments); } + +TorchMlirOpVector Unsqueeze::Lower( + TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(dim); + return LowerBuiltin(this, function, arguments); +} + +TorchMlirOpVector +View::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + std::vector arguments; + arguments.emplace_back(loctx->GetOutputOp(operand(0))); + arguments.emplace_back(output_size); + return LowerBuiltin(at::aten::reshape, shapes(), function, arguments); +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h index 0e3e8d63505..a91cb1a7790 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h @@ -1,5 +1,6 @@ #pragma once +#include "../mlir_lowering_context.h" #include "../mlir_node.h" #include @@ -34,6 +35,8 @@ class TORCH_API DeviceData : public TorchMlirNode { data_ = data; } + TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override; + static const DeviceData* Cast(const Node* node); // To reuse IR nodes, use this method to create DeviceData nodes