From e326000330ad7ed174393f46b475060bb53e2f69 Mon Sep 17 00:00:00 2001 From: Ben Koopman Date: Sun, 11 Jun 2023 21:37:14 -0400 Subject: [PATCH 1/3] Basic dispatcher impl --- check.py | 21 ++++++++----- cpp/lltm.cpp | 74 ++++++++++++++++++++++++++++++++++++++++++++-- cpp/lltm.py | 7 +++-- cpp/setup.py | 2 +- cuda/lltm.py | 7 +++-- cuda/lltm_cuda.cpp | 14 +++++++-- grad_check.py | 15 ++++++---- 7 files changed, 113 insertions(+), 27 deletions(-) diff --git a/check.py b/check.py index 8fad6d1..e08dca0 100644 --- a/check.py +++ b/check.py @@ -6,8 +6,9 @@ import torch import python.lltm_baseline -import cpp.lltm - +#import cpp.lltm +torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") +torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") def check_equal(first, second, verbose): if verbose: @@ -19,8 +20,7 @@ def check_equal(first, second, verbose): print("x = {}".format(x.flatten())) print("y = {}".format(y.flatten())) print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i)) - + np.testing.assert_allclose(x, y, rtol=2e-6, atol=2e-7, err_msg="Index: {}".format(i)) def zero_grad(variables): for variable in variables: @@ -33,14 +33,19 @@ def get_grads(variables): def check_forward(variables, with_cuda, verbose): baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables) - cpp_values = cpp.lltm.LLTMFunction.apply(*variables) + cpp_variables = [v.cpu() for v in variables] + cpp_values = torch.ops.myops.lltm(*cpp_variables) +# cpp_values = cpp.lltm.LLTMFunction.apply(*variables) + print('Forward: Baseline (Python) vs. C++ ... ', end='') check_equal(baseline_values, cpp_values, verbose) print('Ok') if with_cuda: - cuda_values = cuda.lltm.LLTMFunction.apply(*variables) + cuda_variables = [v.cuda() for v in variables] + cuda_values = torch.ops.myops.lltm(*cuda_variables) +# cuda_values = cuda.lltm.LLTMFunction.apply(*variables) print('Forward: Baseline (Python) vs. CUDA ... ', end='') check_equal(baseline_values, cuda_values, verbose) print('Ok') @@ -53,7 +58,7 @@ def check_backward(variables, with_cuda, verbose): zero_grad(variables) - cpp_values = cpp.lltm.LLTMFunction.apply(*variables) + cpp_values = torch.ops.myops.lltm(*variables) (cpp_values[0] + cpp_values[1]).sum().backward() grad_cpp = get_grads(variables) @@ -63,7 +68,7 @@ def check_backward(variables, with_cuda, verbose): if with_cuda: zero_grad(variables) - cuda_values = cuda.lltm.LLTMFunction.apply(*variables) + cuda_values = torch.ops.myops.lltm(*variables) (cuda_values[0] + cuda_values[1]).sum().backward() grad_cuda = get_grads(variables) diff --git a/cpp/lltm.cpp b/cpp/lltm.cpp index 9bdfe0c..1e28fbe 100644 --- a/cpp/lltm.cpp +++ b/cpp/lltm.cpp @@ -26,6 +26,8 @@ std::vector lltm_forward( torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { + + std::cout << "CPU!!!" 
<< std::endl; auto X = torch::cat({old_h, input}, /*dim=*/1); auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); @@ -84,7 +86,73 @@ std::vector lltm_backward( return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward"); - m.def("backward", &lltm_backward, "LLTM backward"); +// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +// m.def("forward", &lltm_forward, "LLTM forward"); +// m.def("backward", &lltm_backward, "LLTM backward"); +// } + + + +std::vector lltm_op(torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell){ + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myops::lltm", "") + .typed(); + return op.call(input, weights, bias, old_h, old_cell); +} + +std::vector lltm_op_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights){ + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myops::lltm", "backward") + .typed(); + return op.call(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights); +} + +class LLTMFunction : public torch::autograd::Function { +public: + static std::vector forward( + torch::autograd::AutogradContext *ctx, torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell){ + at::AutoNonVariableTypeMode g; + std::vector outputs = lltm_op(input, weights, bias, old_h, old_cell); + ctx->save_for_backward({outputs[1], outputs[2], outputs[3], outputs[4], outputs[5], outputs[6], weights}); + + return {outputs[0], outputs[1]}; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list grad_outputs){ + auto saved = ctx->get_saved_variables(); + auto outputs = lltm_op_backward(grad_outputs[0].contiguous(), grad_outputs[1].contiguous(), saved[0], saved[1], saved[2], saved[3], saved[4], saved[5], saved[6]); + return {outputs[1], outputs[2], outputs[3], outputs[0], outputs[4]}; + } +}; + +std::vector lltm_autograd(torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { + return LLTMFunction::apply(input, weights, bias, old_h, old_cell); +} + +TORCH_LIBRARY(myops, m){ + m.def("lltm(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> Tensor[]"); + m.def("lltm.backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(myops, CPU, m){ + m.impl(TORCH_SELECTIVE_NAME("lltm"), TORCH_FN(lltm_forward)); + m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward)); +} + + +TORCH_LIBRARY_IMPL(myops, Autograd, m) { + m.impl("lltm", lltm_autograd); } diff --git a/cpp/lltm.py b/cpp/lltm.py index 24cf82d..4a430b6 100644 --- a/cpp/lltm.py +++ b/cpp/lltm.py @@ -3,7 +3,8 @@ from torch.autograd import Function import torch -import lltm_cpp +#import lltm_cpp +torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") torch.manual_seed(42) @@ -11,7 +12,7 @@ class LLTMFunction(Function): @staticmethod def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = 
lltm_cpp.forward(input, weights, bias, old_h, old_cell) + outputs = torch.ops.myops.lltm(input, weights, bias, old_h, old_cell) new_h, new_cell = outputs[:2] variables = outputs[1:] + [weights] ctx.save_for_backward(*variables) @@ -20,7 +21,7 @@ def forward(ctx, input, weights, bias, old_h, old_cell): @staticmethod def backward(ctx, grad_h, grad_cell): - d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward( + d_old_h, d_input, d_weights, d_bias, d_old_cell = torch.ops.myops.lltm.backward( grad_h, grad_cell, *ctx.saved_variables) return d_input, d_weights, d_bias, d_old_h, d_old_cell diff --git a/cpp/setup.py b/cpp/setup.py index 7a4c164..d663805 100644 --- a/cpp/setup.py +++ b/cpp/setup.py @@ -4,7 +4,7 @@ setup( name='lltm_cpp', ext_modules=[ - CppExtension('lltm_cpp', ['lltm.cpp']), + CppExtension('lltm_cpp', ['lltm.cpp'], library_dirs=['/lib/x86_64-linux-gnu/'], runtime_library_dirs=['/lib/x86_64-linux-gnu/']), ], cmdclass={ 'build_ext': BuildExtension diff --git a/cuda/lltm.py b/cuda/lltm.py index c740b88..3f19c02 100644 --- a/cuda/lltm.py +++ b/cuda/lltm.py @@ -3,7 +3,8 @@ from torch.autograd import Function import torch -import lltm_cuda +#import lltm_cuda +torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") torch.manual_seed(42) @@ -11,7 +12,7 @@ class LLTMFunction(Function): @staticmethod def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell) + outputs = torch.ops.myops.lltm(input, weights, bias, old_h, old_cell) new_h, new_cell = outputs[:2] variables = outputs[1:] + [weights] ctx.save_for_backward(*variables) @@ -20,7 +21,7 @@ def forward(ctx, input, weights, bias, old_h, old_cell): @staticmethod def backward(ctx, grad_h, grad_cell): - outputs = lltm_cuda.backward( + outputs = torch.ops.myops.lltm.backward( grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables) d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs return d_input, d_weights, d_bias, d_old_h, d_old_cell diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp index 2434776..016824f 100644 --- a/cuda/lltm_cuda.cpp +++ b/cuda/lltm_cuda.cpp @@ -35,6 +35,8 @@ std::vector lltm_forward( torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { + std::cout << "CUDA!!!!" 
<< std::endl; + CHECK_INPUT(input); CHECK_INPUT(weights); CHECK_INPUT(bias); @@ -75,7 +77,13 @@ std::vector lltm_backward( weights); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward (CUDA)"); - m.def("backward", &lltm_backward, "LLTM backward (CUDA)"); +// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +// m.def("forward", &lltm_forward, "LLTM forward (CUDA)"); +// m.def("backward", &lltm_backward, "LLTM backward (CUDA)"); +// } + +TORCH_LIBRARY_IMPL(myops, CUDA, m){ + m.impl(TORCH_SELECTIVE_NAME("lltm"), TORCH_FN(lltm_forward)); + m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward)); } + diff --git a/grad_check.py b/grad_check.py index caf3b36..eebed08 100644 --- a/grad_check.py +++ b/grad_check.py @@ -13,6 +13,9 @@ parser.add_argument('-c', '--cuda', action='store_true') options = parser.parse_args() +torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") +torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") + if options.example == 'py': from python.lltm_baseline import LLTMFunction elif options.example == 'cpp': @@ -27,14 +30,14 @@ 'device': device, 'requires_grad': True} -X = torch.randn(options.batch_size, options.features, **kwargs) -h = torch.randn(options.batch_size, options.state_size, **kwargs) -C = torch.randn(options.batch_size, options.state_size, **kwargs) -W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs) -b = torch.randn(1, 3 * options.state_size, **kwargs) +X = torch.randn(options.batch_size, options.features, **kwargs).to(device) +h = torch.randn(options.batch_size, options.state_size, **kwargs).to(device) +C = torch.randn(options.batch_size, options.state_size, **kwargs).to(device) +W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs).to(device) +b = torch.randn(1, 3 * options.state_size, **kwargs).to(device) variables = [X, W, b, h, C] -if gradcheck(LLTMFunction.apply, variables): +if gradcheck(torch.ops.myops.lltm, variables): print('Ok') From 23f3d41045779e52e5c5c0565deeb4b04c9c4ce0 Mon Sep 17 00:00:00 2001 From: Ben Koopman Date: Sun, 11 Jun 2023 23:46:19 -0400 Subject: [PATCH 2/3] Convert pybind11 -> dispatcher --- benchmark.py | 6 +++++ check.py | 24 ++++++++++++++----- cpp/lltm.cpp | 57 ++++++++++++++++++++++++++-------------------- cpp/lltm.py | 36 ++++++++++++----------------- cuda/lltm.py | 43 ++++++++++++++++------------------ cuda/lltm_cuda.cpp | 8 +------ grad_check.py | 24 +++++++++++++------ 7 files changed, 108 insertions(+), 90 deletions(-) diff --git a/benchmark.py b/benchmark.py index 212da08..4a99c59 100644 --- a/benchmark.py +++ b/benchmark.py @@ -3,9 +3,14 @@ import argparse import math +import os +import glob import time + import torch +import torch.utils.cpp_extension +import pkg_resources TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000} @@ -20,6 +25,7 @@ parser.add_argument('-d', '--double', action='store_true') options = parser.parse_args() +LIB_EXT = torch.utils.cpp_extension.LIB_EXT if options.example == 'py': from python.lltm import LLTM elif options.example == 'cpp': diff --git a/check.py b/check.py index e08dca0..743ee12 100644 --- a/check.py +++ b/check.py @@ -3,12 +3,13 @@ import argparse import numpy as np +import os +import glob import torch +import torch.utils.cpp_extension +import pkg_resources import python.lltm_baseline -#import cpp.lltm 
-torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") -torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") def check_equal(first, second, verbose): if verbose: @@ -35,8 +36,6 @@ def check_forward(variables, with_cuda, verbose): baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables) cpp_variables = [v.cpu() for v in variables] cpp_values = torch.ops.myops.lltm(*cpp_variables) -# cpp_values = cpp.lltm.LLTMFunction.apply(*variables) - print('Forward: Baseline (Python) vs. C++ ... ', end='') check_equal(baseline_values, cpp_values, verbose) @@ -45,7 +44,6 @@ def check_forward(variables, with_cuda, verbose): if with_cuda: cuda_variables = [v.cuda() for v in variables] cuda_values = torch.ops.myops.lltm(*cuda_variables) -# cuda_values = cuda.lltm.LLTMFunction.apply(*variables) print('Forward: Baseline (Python) vs. CUDA ... ', end='') check_equal(baseline_values, cuda_values, verbose) print('Ok') @@ -86,9 +84,22 @@ def check_backward(variables, with_cuda, verbose): parser.add_argument('-v', '--verbose', action='store_true') options = parser.parse_args() +LIB_EXT = torch.utils.cpp_extension.LIB_EXT +cpp_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py")) +cpp_lib_path = glob.glob(os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0] +torch.ops.load_library(cpp_lib_path) + if options.cuda: import cuda.lltm device = torch.device("cuda") + + cuda_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cuda'), "lltm_cuda.py")) + cuda_lib_path = glob.glob(os.path.join(cuda_module_path, f"lltm_cuda*{LIB_EXT}"))[0] + torch.ops.load_library(cuda_lib_path) else: device = torch.device("cpu") @@ -105,6 +116,7 @@ def check_backward(variables, with_cuda, verbose): variables = [X, W, b, h, C] + if 'forward' in options.direction: check_forward(variables, options.cuda, options.verbose) diff --git a/cpp/lltm.cpp b/cpp/lltm.cpp index 1e28fbe..37f93f9 100644 --- a/cpp/lltm.cpp +++ b/cpp/lltm.cpp @@ -27,7 +27,6 @@ std::vector lltm_forward( torch::Tensor old_h, torch::Tensor old_cell) { - std::cout << "CPU!!!" 
<< std::endl; auto X = torch::cat({old_h, input}, /*dim=*/1); auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); @@ -86,26 +85,18 @@ std::vector lltm_backward( return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; } -// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -// m.def("forward", &lltm_forward, "LLTM forward"); -// m.def("backward", &lltm_backward, "LLTM backward"); -// } - - - std::vector lltm_op(torch::Tensor input, - torch::Tensor weights, - torch::Tensor bias, - torch::Tensor old_h, - torch::Tensor old_cell){ + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell){ static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("myops::lltm", "") .typed(); return op.call(input, weights, bias, old_h, old_cell); } -std::vector lltm_op_backward( - torch::Tensor grad_h, +std::vector lltm_op_backward(torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, torch::Tensor input_gate, @@ -117,34 +108,51 @@ std::vector lltm_op_backward( static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("myops::lltm", "backward") .typed(); - return op.call(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights); + return op.call(grad_h, grad_cell, new_cell, input_gate, + output_gate, candidate_cell, X, gate_weights, weights); } class LLTMFunction : public torch::autograd::Function { public: - static std::vector forward( - torch::autograd::AutogradContext *ctx, torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell){ - at::AutoNonVariableTypeMode g; + static std::vector forward(torch::autograd::AutogradContext *ctx, + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell){ + at::AutoDispatchBelowADInplaceOrView g; std::vector outputs = lltm_op(input, weights, bias, old_h, old_cell); - ctx->save_for_backward({outputs[1], outputs[2], outputs[3], outputs[4], outputs[5], outputs[6], weights}); + ctx->save_for_backward({outputs[1], outputs[2], outputs[3], + outputs[4], outputs[5], outputs[6], weights}); return {outputs[0], outputs[1]}; } - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list grad_outputs){ + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs){ auto saved = ctx->get_saved_variables(); - auto outputs = lltm_op_backward(grad_outputs[0].contiguous(), grad_outputs[1].contiguous(), saved[0], saved[1], saved[2], saved[3], saved[4], saved[5], saved[6]); + auto outputs = lltm_op_backward(grad_outputs[0].contiguous(), + grad_outputs[1].contiguous(), + saved[0], saved[1], saved[2], saved[3], + saved[4], saved[5], saved[6]); return {outputs[1], outputs[2], outputs[3], outputs[0], outputs[4]}; } }; -std::vector lltm_autograd(torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { +std::vector lltm_autograd(torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell) { return LLTMFunction::apply(input, weights, bias, old_h, old_cell); } TORCH_LIBRARY(myops, m){ - m.def("lltm(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> Tensor[]"); - m.def("lltm.backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor 
gate_weights, Tensor weights) -> Tensor[]"); + m.def("lltm(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell)" \ + "-> Tensor[]"); + m.def("lltm.backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, " \ + "Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, " \ + "Tensor gate_weights, Tensor weights) -> Tensor[]"); } TORCH_LIBRARY_IMPL(myops, CPU, m){ @@ -152,7 +160,6 @@ TORCH_LIBRARY_IMPL(myops, CPU, m){ m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward)); } - TORCH_LIBRARY_IMPL(myops, Autograd, m) { m.impl("lltm", lltm_autograd); } diff --git a/cpp/lltm.py b/cpp/lltm.py index 4a430b6..79d6e87 100644 --- a/cpp/lltm.py +++ b/cpp/lltm.py @@ -1,31 +1,23 @@ import math +import os from torch import nn from torch.autograd import Function +import glob import torch - -#import lltm_cpp -torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") +import torch.utils.cpp_extension +import pkg_resources + +# Get the location of shared library for the lltm op, and load it. +LIB_EXT = torch.utils.cpp_extension.LIB_EXT +cpp_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py")) +cpp_lib_path = glob.glob( + os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0] +torch.ops.load_library(cpp_lib_path) torch.manual_seed(42) - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = torch.ops.myops.lltm(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - d_old_h, d_input, d_weights, d_bias, d_old_cell = torch.ops.myops.lltm.backward( - grad_h, grad_cell, *ctx.saved_variables) - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - class LLTM(nn.Module): def __init__(self, input_features, state_size): super(LLTM, self).__init__() @@ -42,4 +34,4 @@ def reset_parameters(self): weight.data.uniform_(-stdv, +stdv) def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) + return torch.ops.myops.lltm(input, self.weights, self.bias, *state) diff --git a/cuda/lltm.py b/cuda/lltm.py index 3f19c02..5fa07f2 100644 --- a/cuda/lltm.py +++ b/cuda/lltm.py @@ -1,32 +1,29 @@ import math +import os from torch import nn from torch.autograd import Function import torch - -#import lltm_cuda -torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") +import glob +import torch.utils.cpp_extension +import pkg_resources + +# Get the location of shared library for the lltm op, and load it. 
+LIB_EXT = torch.utils.cpp_extension.LIB_EXT +# Note: currently there is a dependency on the CPP lib, due to the schema definition +# Eventually, this should move to use a single library registering both CPP and CUDA ops +cpp_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py")) +cpp_lib_path = glob.glob(os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0] +torch.ops.load_library(cpp_lib_path) +cuda_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cuda'), "lltm_cuda.py")) +cuda_lib_path = glob.glob(os.path.join(cuda_module_path, f"lltm_cuda*{LIB_EXT}"))[0] +torch.ops.load_library(cuda_lib_path) torch.manual_seed(42) - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = torch.ops.myops.lltm(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - outputs = torch.ops.myops.lltm.backward( - grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables) - d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - class LLTM(nn.Module): def __init__(self, input_features, state_size): super(LLTM, self).__init__() @@ -43,4 +40,4 @@ def reset_parameters(self): weight.data.uniform_(-stdv, +stdv) def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) + return torch.ops.myops.lltm(input, self.weights, self.bias, *state) diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp index 016824f..e907161 100644 --- a/cuda/lltm_cuda.cpp +++ b/cuda/lltm_cuda.cpp @@ -35,8 +35,7 @@ std::vector lltm_forward( torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { - std::cout << "CUDA!!!!" 
<< std::endl; - + CHECK_INPUT(input); CHECK_INPUT(weights); CHECK_INPUT(bias); @@ -77,11 +76,6 @@ std::vector lltm_backward( weights); } -// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -// m.def("forward", &lltm_forward, "LLTM forward (CUDA)"); -// m.def("backward", &lltm_backward, "LLTM backward (CUDA)"); -// } - TORCH_LIBRARY_IMPL(myops, CUDA, m){ m.impl(TORCH_SELECTIVE_NAME("lltm"), TORCH_FN(lltm_forward)); m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward)); diff --git a/grad_check.py b/grad_check.py index eebed08..728a63c 100644 --- a/grad_check.py +++ b/grad_check.py @@ -13,16 +13,26 @@ parser.add_argument('-c', '--cuda', action='store_true') options = parser.parse_args() -torch.ops.load_library("cpp/build/lib.linux-x86_64-cpython-39/lltm_cpp.cpython-39-x86_64-linux-gnu.so") -torch.ops.load_library("cuda/build/lib.linux-x86_64-cpython-39/lltm_cuda.cpython-39-x86_64-linux-gnu.so") +cpp_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py")) +cpp_lib_path = glob.glob(os.path.join(cpp_module_path, "lltm_cpp*.so"))[0] +torch.ops.load_library(cpp_lib_path) + +cuda_module_path = os.path.dirname( + pkg_resources.resource_filename( + pkg_resources.Requirement.parse('lltm_cuda'), "lltm_cuda.py")) +cuda_lib_path = glob.glob(os.path.join(cuda_module_path, "lltm_cuda*.so"))[0] +torch.ops.load_library(cuda_lib_path) + if options.example == 'py': from python.lltm_baseline import LLTMFunction -elif options.example == 'cpp': - from cpp.lltm import LLTMFunction + lltm_func = LLTMFunction.apply else: - from cuda.lltm import LLTMFunction - options.cuda = True + lltm_func = torch.ops.myops.lltm + +options.cuda |= (options.example == "cuda") device = torch.device("cuda") if options.cuda else torch.device("cpu") @@ -39,5 +49,5 @@ variables = [X, W, b, h, C] -if gradcheck(torch.ops.myops.lltm, variables): +if gradcheck(lltm_func, variables): print('Ok') From a638a72d2680e87a77309cc8195d7941d0bb49b8 Mon Sep 17 00:00:00 2001 From: Ben Koopman Date: Mon, 12 Jun 2023 09:31:54 -0400 Subject: [PATCH 3/3] Remove cruft from gradcheck --- grad_check.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/grad_check.py b/grad_check.py index 728a63c..5f097fc 100644 --- a/grad_check.py +++ b/grad_check.py @@ -2,8 +2,11 @@ from __future__ import print_function import argparse +import os +import pkg_resources import torch from torch.autograd import gradcheck +import glob parser = argparse.ArgumentParser() parser.add_argument('example', choices=['py', 'cpp', 'cuda']) @@ -40,11 +43,13 @@ 'device': device, 'requires_grad': True} -X = torch.randn(options.batch_size, options.features, **kwargs).to(device) -h = torch.randn(options.batch_size, options.state_size, **kwargs).to(device) -C = torch.randn(options.batch_size, options.state_size, **kwargs).to(device) -W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs).to(device) -b = torch.randn(1, 3 * options.state_size, **kwargs).to(device) +X = torch.randn(options.batch_size, options.features, **kwargs) +h = torch.randn(options.batch_size, options.state_size, **kwargs) +C = torch.randn(options.batch_size, options.state_size, **kwargs) +W = torch.randn(3 * options.state_size, + options.features + options.state_size, + **kwargs) +b = torch.randn(1, 3 * options.state_size, **kwargs) variables = [X, W, b, h, C]
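
A minimal usage sketch for the dispatcher-registered op introduced by this series; it is an illustration under assumptions, not part of the patches. It assumes the extensions have been installed via the included setup.py files (so the lltm_cpp shared library can be located the same way the patched cpp/lltm.py does) and uses arbitrary toy sizes. The call runs on CPU here; moving the tensors to CUDA would route the same torch.ops.myops.lltm call to the CUDA kernels registered in cuda/lltm_cuda.cpp.

    import glob
    import os

    import pkg_resources
    import torch
    import torch.utils.cpp_extension

    # Locate the installed lltm_cpp shared library and load it so the
    # TORCH_LIBRARY registrations for myops::lltm run (same lookup as cpp/lltm.py).
    LIB_EXT = torch.utils.cpp_extension.LIB_EXT
    cpp_module_path = os.path.dirname(
        pkg_resources.resource_filename(
            pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py"))
    cpp_lib_path = glob.glob(os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0]
    torch.ops.load_library(cpp_lib_path)

    # Toy sizes, chosen only for illustration.
    batch_size, features, state_size = 4, 8, 16
    X = torch.randn(batch_size, features, requires_grad=True)
    h = torch.randn(batch_size, state_size, requires_grad=True)
    C = torch.randn(batch_size, state_size, requires_grad=True)
    W = torch.randn(3 * state_size, features + state_size, requires_grad=True)
    b = torch.randn(1, 3 * state_size, requires_grad=True)

    # The call goes through the dispatcher: CPU tensors hit the CPU kernel,
    # and the Autograd registration wires up the custom backward, so no
    # torch.autograd.Function is needed on the Python side.
    new_h, new_cell = torch.ops.myops.lltm(X, W, b, h, C)
    (new_h.sum() + new_cell.sum()).backward()
    print(X.grad.shape)  # gradients flow back through the registered backward op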