diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index e99c8c2477b..a24ee750a1c 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -56,6 +56,11 @@ jobs: cd $GITHUB_WORKSPACE export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=tosa -v + - name: Lazy Tensor Core - TorchScript end-to-end tests + run: | + cd $GITHUB_WORKSPACE + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v build-out-of-tree: name: Build out-of-tree (Release Asserts) diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 08880f4e20c..10e86a89663 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,20 +15,20 @@ # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig + LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET +from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -40,6 +40,7 @@ def _get_argparse(): "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "eager_mode": run through torch-mlir's eager mode frontend, using RefBackend for execution. +"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. ''') parser.add_argument('-f', '--filter', default='.*', help=''' Regular expression specifying which tests to include in this run. @@ -86,6 +87,9 @@ def main(): elif args.config == 'eager_mode': config = EagerModeTestConfig() xfail_set = EAGER_MODE_XFAIL_SET + elif args.config == 'lazy_tensor_core': + config = LazyTensorCoreTestConfig() + xfail_set = LTC_XFAIL_SET # Find the selected tests, and emit a diagnostic if none are found. tests = [ diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index b03f3abac78..483f52a5af6 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -170,3 +170,312 @@ "NumpyTRankNDynamicModule_basic", "EmbeddingModuleI32Static_basic", } + +LTC_XFAIL_SET = { + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "AddIntModule_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeDtypeFloatModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeIntModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeNegativeStartIntModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartIntModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartNegativeStepIntModule_basic", + "ArangeStartStepFloatModule_basic", + "ArangeStartStepIntModule_basic", + "ArangeZeroElementOutputModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliOnesModule_basic", + "BernoulliTensorModule_basic", + "BernoulliZerosModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "DivFloatModule_basic", + "DropoutTrainModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseWhereScalarModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereSelfModule_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EqIntModule_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullLikeModuleFloat3D_basic", + "FullLikeModuleInt2DStatic_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleFloat2D_basic", + "FullModuleFloat3D_basic", + "FullModuleInt2D_basic", + "FullModuleInt3D_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardTanhIntModule_basic", + "HardTanhModule_basic", + "HardswishModule_basic", + "HardswishRandomModule_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorModule_basic", + "MaskedFillScalarDefaultModule_basic", + "MaskedFillScalarFloatValueModule_basic", + "MaskedFillScalarIntValueModule_basic", + "Matmul_dot", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dWith3dInputModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool2dWithIndicesWith3dInputModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimKeepdimModule_basic", + "MeanDimModule_basic", + "MeanDimNegativeModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "MobilenetV3Module_basic", + "MulIntModule_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeLayerNormDynamicModule_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "QuantizedMLP_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxNegativeDim_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "SelectIntModule_basic", + "SliceEndSleStartModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceSingleIdxModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SliceWholeTensorModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "StdBiasedModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", + "TableBatchEmbeddingModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TestMultipleTensorReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "UniformModule_basic", + "UniformStaticModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "VarBiasedModule_basic", + "VarUnbiasedModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", +} diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp index 8f5b507cdc1..1b51346a6f0 100644 --- a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp @@ -12,8 +12,8 @@ #include #include -#include #include +#include #include #include #include @@ -60,10 +60,13 @@ class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Vendor backend specific lowering can be exec here before returning. for (const auto &instance : instances) { - std::cout << "Instance received at Compile: \n" - << GetComputationBackendText(instance) << std::endl; + // Store computation instance for external access after compilation. + GetLatestComputation() = instance; } + std::cout << "Received " << instances.size() + << " computation instances at Compile!" << std::endl; + return instances; } @@ -133,9 +136,13 @@ class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { * */ std::string GetComputationBackendText(const ComputationPtr computation) const override { - auto mlir_computation = - static_cast(computation.get()); - return mlir_computation->to_string(); + // Store computation instance for external access after compilation. + // We do this in GetComputationBackendText since there may be instances + // where a ComputationPtr does not pass through Compile (e.g. when using + // DumpUtil::ToBackend.) + GetLatestComputation() = computation; + + return computation->to_string(); } private: @@ -154,5 +161,11 @@ void InitExampleMlirBackend() { g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl())); } +ComputationPtr &GetLatestComputation() { + // Store the computation from the most recent compile. + static ComputationPtr computation; + return computation; +} + } // namespace lazy } // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h index 377ae4d219f..4c915fa9fdd 100644 --- a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h @@ -23,5 +23,7 @@ torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl(); void InitExampleMlirBackend(); +ComputationPtr &GetLatestComputation(); + } // namespace lazy } // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp index 1474b4dc907..ff1aa766642 100644 --- a/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp +++ b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp @@ -10,6 +10,8 @@ #include "torch/csrc/jit/python/pybind.h" #include "torch/csrc/lazy/backend/backend_interface.h" +#include + #include #include #include @@ -61,7 +63,16 @@ void Shutdown() { } // anonymous namespace PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) { + py::class_(m, "TorchMlirComputation") + .def("to_string", &torch::lazy::TorchMlirComputation::to_string) + .def("debug_string", &torch::lazy::TorchMlirComputation::debug_string); + m.doc() = ("pybind11 for example MLIR LTC backend."); + m.def("get_latest_computation", []() { + auto computation = static_cast( + torch::lazy::GetLatestComputation().get()); + return py::cast(computation); + }); m.def("_initialize", []() { NoGilSection gil; Initialize(); diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index d8434f5ef14..d309ba87136 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -14,13 +14,19 @@ """ import argparse +import sys +from typing import List + +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch +import torch._C +import torch._lazy +import torch._lazy.ts_backend from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader from transformers import BertForSequenceClassification, \ BertConfig, BertTokenizer, AdamW, get_scheduler -from typing import List def tokenize_dataset(dataset: DatasetDict) -> DatasetDict: @@ -42,8 +48,7 @@ def train(model: BertForSequenceClassification, num_epochs: int, num_training_steps: int, train_dataloader: DataLoader, - device: torch.device, - do_mark_step: bool) -> List[torch.Tensor]: + device: torch.device) -> List[torch.Tensor]: optimizer = AdamW(model.parameters(), lr=5e-5) lr_scheduler = get_scheduler('linear', optimizer=optimizer, num_warmup_steps=0, @@ -63,31 +68,21 @@ def train(model: BertForSequenceClassification, lr_scheduler.step() optimizer.zero_grad() - if do_mark_step and 'lazy' in str(model.device): + if 'lazy' in str(model.device): print("Calling Mark Step") torch._lazy.mark_step() return losses -def main(device, lower_only, full_size): - if device in ("TS", "MLIR_EXAMPLE"): - import torch._lazy - - if device == "TS": - import torch._lazy.ts_backend - - torch._lazy.ts_backend.init() - - elif device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend - - ltc_backend._initialize() +def main(device='lazy', full_size=False): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. - device = "lazy" - print("Initialized backend") - else: - device = device.lower() + :param device: name of device to load tensors to + :param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant + """ + torch.manual_seed(0) tokenized_datasets = tokenize_dataset(load_dataset('imdb')) small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ @@ -117,22 +112,20 @@ def main(device, lower_only, full_size): num_epochs = 3 num_training_steps = num_epochs * len(train_dataloader) - losses = train(model, num_epochs, - num_training_steps, train_dataloader, device, not lower_only) - - if lower_only: - print('\nJIT Graph:') - import torch._C - graph_str = torch._C._lazy._get_tensors_backend([losses[0]]) - print(graph_str) - else: - # Execute computation - print('Loss: ', losses) + losses = train(model, num_epochs, num_training_steps, train_dataloader, device) + # Get debug information from LTC + if 'ltc_backend' in sys.modules: + computation = ltc_backend.get_latest_computation() + if computation: + print(computation.debug_string()) -if __name__ == "__main__": - torch.manual_seed(0) + print('Loss: ', losses) + return model, losses + + +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-d", @@ -142,13 +135,6 @@ def main(device, lower_only, full_size): default="MLIR_EXAMPLE", help="The device type", ) - parser.add_argument( - "-l", - "--lower_only", - action='store_true', - default=False, - help="Only get backend printout -- do not execute computation", - ) parser.add_argument( "-f", "--full_size", @@ -157,4 +143,17 @@ def main(device, lower_only, full_size): help="Use full sized BERT model instead of one with smaller parameterization", ) args = parser.parse_args() - main(args.device, args.lower_only, args.full_size) + + if args.device in ("TS", "MLIR_EXAMPLE"): + if args.device == "TS": + torch._lazy.ts_backend.init() + + elif args.device == "MLIR_EXAMPLE": + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = args.device.lower() + + main(device, args.full_size) diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 7448bbc0b34..65f8a4e5f37 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -6,30 +6,22 @@ Example use of the example Torch MLIR LTC backend. """ import argparse +import sys +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend +import torch +import torch._lazy +import torch._lazy.ts_backend import torch.nn.functional as F -def main(device): - import torch +def main(device='lazy'): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. - if device in ("TS", "MLIR_EXAMPLE"): - import torch._lazy - - if device == "TS": - import torch._lazy.ts_backend - - torch._lazy.ts_backend.init() - - elif device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend - - ltc_backend._initialize() - - device = "lazy" - print("Initialized backend") - else: - device = device.lower() + :param device: name of device to load tensors to + """ + torch.manual_seed(0) inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device) assert inputs.device.type == device @@ -57,24 +49,35 @@ def forward(self, x): criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() + num_epochs = 3 + losses = [] + for _ in range(num_epochs): + optimizer.zero_grad() - if device == "lazy": - print("Calling Mark Step") - torch._lazy.mark_step() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + losses.append(loss) - print() - print(loss) + optimizer.step() + if device == "lazy": + print("Calling Mark Step") + torch._lazy.mark_step() -if __name__ == "__main__": - torch.manual_seed(0) + # Get debug information from LTC + if 'ltc_backend' in sys.modules: + computation = ltc_backend.get_latest_computation() + if computation: + print(computation.debug_string()) + print(losses) + + return model, losses + + +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-d", @@ -85,4 +88,17 @@ def forward(self, x): help="The device type", ) args = parser.parse_args() - main(args.device) + + if args.device in ("TS", "MLIR_EXAMPLE"): + if args.device == "TS": + torch._lazy.ts_backend.init() + + elif args.device == "MLIR_EXAMPLE": + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = args.device.lower() + + main(device) diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 8d57b6149ea..fb21dcd85c9 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -323,24 +323,14 @@ std::shared_ptr TorchMlirComputation::graph() const { MlirOperation TorchMlirComputation::func_op() const { return func_op_; } -const std::string TorchMlirComputation::to_string() const { - // Since we use the C-MLIR API, we need to use a callback to print. - MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { - // user_data is a void ptr to some data structure of our choice -- in this - // case, the string stream where we'll be accumulating the strings. - std::stringstream* ss_ptr = static_cast(user_data); - *ss_ptr << std::string(part.data, part.length); - }; - +const std::string TorchMlirComputation::debug_string() const { std::stringstream ss; // JIT Graph ss << "JIT Graph: \n" << graph_->toString() << "\n\n"; // MLIR - ss << "MLIR: \n"; - mlirOperationPrint(func_op_, print_callback, &ss); - ss << "\n"; + ss << "MLIR: \n" << to_string() << "\n"; // Input/Output Mapping ss << "Input/Output Alias Mapping: \n"; @@ -356,5 +346,18 @@ const std::string TorchMlirComputation::to_string() const { return ss.str(); } +const std::string TorchMlirComputation::to_string() const { + // Since we use the C-MLIR API, we need to use a callback to print. + MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + // user_data is a void ptr to some data structure of our choice -- in this + // case, the string stream where we'll be accumulating the strings. + std::stringstream* ss_ptr = static_cast(user_data); + *ss_ptr << std::string(part.data, part.length); + }; + std::stringstream ss; + mlirOperationPrint(func_op_, print_callback, &ss); + return ss.str(); +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h index 7c4c36a91fc..e17e564b527 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -135,6 +135,8 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { MlirOperation func_op() const; + const std::string debug_string() const; + const std::string to_string() const; private: diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index 14c2f48c36c..63d9a733940 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +from .lazy_tensor_core import LazyTensorCoreTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py new file mode 100644 index 00000000000..9c5b90cda84 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py @@ -0,0 +1,34 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend +import torch +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem + + +class LazyTensorCoreTestConfig(TestConfig): + """TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR""" + + def __init__(self): + super().__init__() + ltc_backend._initialize() + + def compile(self, program: torch.nn.Module) -> torch.nn.Module: + return program.to('lazy') + + def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + result: Trace = [] + + for item in trace: + # We need to move all the inputs to the lazy device before running in LTC. + lazy_inputs = [arg.to('lazy') for arg in item.inputs] + output = getattr(artifact, item.symbol)(*lazy_inputs) + + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output.to('cpu'))) + + return result