Use dispatcher API #83

Open · wants to merge 3 commits into master
6 changes: 6 additions & 0 deletions benchmark.py
@@ -3,9 +3,14 @@

import argparse
import math
import os
import glob
import time


import torch
import torch.utils.cpp_extension
import pkg_resources

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

@@ -20,6 +25,7 @@
parser.add_argument('-d', '--double', action='store_true')
options = parser.parse_args()

LIB_EXT = torch.utils.cpp_extension.LIB_EXT
Contributor:
It looks like none of the changes you made in benchmark.py are actually used.

Author:
The op is called from the LLTM module class's forward(); see {cpp|cuda}/lltm.py.
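Roughly, the call path looks like this (a sketch with arbitrary sizes, assuming the lltm_cpp extension is built and installed so that importing cpp/lltm.py loads its shared library):

```python
import torch
from cpp.lltm import LLTM  # importing this module loads the lltm_cpp shared library

# Arbitrary sizes for illustration.
batch_size, input_features, state_size = 16, 32, 128

rnn = LLTM(input_features, state_size)
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

# LLTM.forward() delegates to the op registered via the dispatcher API.
new_h, new_C = rnn(X, (h, C))
```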

if options.example == 'py':
from python.lltm import LLTM
elif options.example == 'cpp':
33 changes: 25 additions & 8 deletions check.py
@@ -3,11 +3,13 @@

import argparse
import numpy as np
import os
import glob
import torch
import torch.utils.cpp_extension
import pkg_resources

import python.lltm_baseline
import cpp.lltm


def check_equal(first, second, verbose):
if verbose:
@@ -19,8 +21,7 @@ def check_equal(first, second, verbose):
print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten()))
print('-' * 80)
np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i))

np.testing.assert_allclose(x, y, rtol=2e-6, atol=2e-7, err_msg="Index: {}".format(i))

def zero_grad(variables):
for variable in variables:
@@ -33,14 +34,16 @@ def get_grads(variables):

def check_forward(variables, with_cuda, verbose):
baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables)
cpp_values = cpp.lltm.LLTMFunction.apply(*variables)
cpp_variables = [v.cpu() for v in variables]
cpp_values = torch.ops.myops.lltm(*cpp_variables)
Contributor:
It looks like this is something you already call from LLTMFunction in cpp/lltm.py, so why not keep the LLTMFunction object here as it was before and place the backend-dependent code in every module (cpp and cuda)?

Author:
Since this has been migrated to an operator that is registered via the dispatcher API, the Python implementation of LLTMFunction no longer exists (see cpp/lltm.cpp:115).

If I understand the intent of the code correctly, this use of LLTMFunction was a way to call the C++ forward implementation directly from Python; now that we are using the dispatcher API, we can call the forward op directly. (Ideally, the baseline solution would be updated to match this pattern too.)
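For example, once the shared library is loaded, the op can be called without any Python wrapper class. A rough sketch, following the loading code in this PR's check.py, with made-up tensor sizes:

```python
import glob
import os

import pkg_resources
import torch
import torch.utils.cpp_extension

# Locate and load the compiled lltm_cpp extension, as check.py does.
LIB_EXT = torch.utils.cpp_extension.LIB_EXT
module_path = os.path.dirname(
    pkg_resources.resource_filename(
        pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py"))
lib_path = glob.glob(os.path.join(module_path, f"lltm_cpp*{LIB_EXT}"))[0]
torch.ops.load_library(lib_path)

# Made-up sizes: weights are (3 * state_size, input_features + state_size).
X = torch.randn(4, 8)
W = torch.randn(3 * 16, 8 + 16)
b = torch.randn(3 * 16)
h = torch.randn(4, 16)
C = torch.randn(4, 16)

# Dispatches through myops::lltm; the autograd kernel returns new_h and new_cell.
new_h, new_cell = torch.ops.myops.lltm(X, W, b, h, C)
```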

Author:
The C++ LLTMFunction implementation is primarily used for re-dispatching from the autograd op, as described in https://pytorch.org/tutorials/advanced/dispatcher.html#adding-autograd-support
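As a sanity check that the autograd registration re-dispatches correctly, gradcheck can be run straight against the registered op. A sketch (assumes the library is already loaded as in check.py; double precision and small, made-up sizes):

```python
import torch

# Double-precision inputs with requires_grad, for numerical gradient checking.
kwargs = {'dtype': torch.float64, 'requires_grad': True}
X = torch.randn(4, 8, **kwargs)
W = torch.randn(3 * 16, 8 + 16, **kwargs)
b = torch.randn(3 * 16, **kwargs)
h = torch.randn(4, 16, **kwargs)
C = torch.randn(4, 16, **kwargs)

# The call hits the Autograd kernel registered in cpp/lltm.cpp, which
# re-dispatches to the CPU kernels for myops::lltm and myops::lltm.backward.
ok = torch.autograd.gradcheck(
    lambda *args: tuple(torch.ops.myops.lltm(*args)), (X, W, b, h, C))
print('gradcheck:', ok)
```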


print('Forward: Baseline (Python) vs. C++ ... ', end='')
check_equal(baseline_values, cpp_values, verbose)
print('Ok')

if with_cuda:
cuda_values = cuda.lltm.LLTMFunction.apply(*variables)
cuda_variables = [v.cuda() for v in variables]
cuda_values = torch.ops.myops.lltm(*cuda_variables)
print('Forward: Baseline (Python) vs. CUDA ... ', end='')
check_equal(baseline_values, cuda_values, verbose)
print('Ok')
@@ -53,7 +56,7 @@ def check_backward(variables, with_cuda, verbose):

zero_grad(variables)

cpp_values = cpp.lltm.LLTMFunction.apply(*variables)
cpp_values = torch.ops.myops.lltm(*variables)
(cpp_values[0] + cpp_values[1]).sum().backward()
grad_cpp = get_grads(variables)

@@ -63,7 +66,7 @@ def check_backward(variables, with_cuda, verbose):

if with_cuda:
zero_grad(variables)
cuda_values = cuda.lltm.LLTMFunction.apply(*variables)
cuda_values = torch.ops.myops.lltm(*variables)
(cuda_values[0] + cuda_values[1]).sum().backward()
grad_cuda = get_grads(variables)

@@ -81,9 +84,22 @@ def check_backward(variables, with_cuda, verbose):
parser.add_argument('-v', '--verbose', action='store_true')
options = parser.parse_args()

LIB_EXT = torch.utils.cpp_extension.LIB_EXT
cpp_module_path = os.path.dirname(
pkg_resources.resource_filename(
pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py"))
cpp_lib_path = glob.glob(os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0]
torch.ops.load_library(cpp_lib_path)

if options.cuda:
import cuda.lltm
device = torch.device("cuda")

cuda_module_path = os.path.dirname(
pkg_resources.resource_filename(
pkg_resources.Requirement.parse('lltm_cuda'), "lltm_cuda.py"))
cuda_lib_path = glob.glob(os.path.join(cuda_module_path, f"lltm_cuda*{LIB_EXT}"))[0]
torch.ops.load_library(cuda_lib_path)
else:
device = torch.device("cpu")

@@ -100,6 +116,7 @@ def check_backward(variables, with_cuda, verbose):

variables = [X, W, b, h, C]


if 'forward' in options.direction:
check_forward(variables, options.cuda, options.verbose)

81 changes: 78 additions & 3 deletions cpp/lltm.cpp
@@ -26,6 +26,7 @@ std::vector<torch::Tensor> lltm_forward(
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {

auto X = torch::cat({old_h, input}, /*dim=*/1);

auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
@@ -84,7 +85,81 @@ std::vector<torch::Tensor> lltm_backward(
return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward");
m.def("backward", &lltm_backward, "LLTM backward");
std::vector<torch::Tensor> lltm_op(torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell){
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("myops::lltm", "")
.typed<decltype(lltm_op)>();
return op.call(input, weights, bias, old_h, old_cell);
}

std::vector<torch::Tensor> lltm_op_backward(torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights){
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("myops::lltm", "backward")
.typed<decltype(lltm_op_backward)>();
return op.call(grad_h, grad_cell, new_cell, input_gate,
output_gate, candidate_cell, X, gate_weights, weights);
}

class LLTMFunction : public torch::autograd::Function<LLTMFunction> {
public:
static std::vector<torch::Tensor> forward(torch::autograd::AutogradContext *ctx,
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell){
at::AutoDispatchBelowADInplaceOrView g;
std::vector<torch::Tensor> outputs = lltm_op(input, weights, bias, old_h, old_cell);
ctx->save_for_backward({outputs[1], outputs[2], outputs[3],
outputs[4], outputs[5], outputs[6], weights});

return {outputs[0], outputs[1]};
}

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx,
torch::autograd::tensor_list grad_outputs){
auto saved = ctx->get_saved_variables();
auto outputs = lltm_op_backward(grad_outputs[0].contiguous(),
grad_outputs[1].contiguous(),
saved[0], saved[1], saved[2], saved[3],
saved[4], saved[5], saved[6]);
return {outputs[1], outputs[2], outputs[3], outputs[0], outputs[4]};
}
};

std::vector<torch::Tensor> lltm_autograd(torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
return LLTMFunction::apply(input, weights, bias, old_h, old_cell);
}

TORCH_LIBRARY(myops, m){
m.def("lltm(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell)" \
"-> Tensor[]");
m.def("lltm.backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, " \
"Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, " \
"Tensor gate_weights, Tensor weights) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(myops, CPU, m){
m.impl(TORCH_SELECTIVE_NAME("lltm"), TORCH_FN(lltm_forward));
m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward));
}

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
m.impl("lltm", lltm_autograd);
}
35 changes: 14 additions & 21 deletions cpp/lltm.py
@@ -1,30 +1,23 @@
import math
import os
from torch import nn
from torch.autograd import Function
import glob
import torch

import lltm_cpp
import torch.utils.cpp_extension
import pkg_resources

# Get the location of shared library for the lltm op, and load it.
LIB_EXT = torch.utils.cpp_extension.LIB_EXT
cpp_module_path = os.path.dirname(
pkg_resources.resource_filename(
pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py"))
cpp_lib_path = glob.glob(
os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0]
torch.ops.load_library(cpp_lib_path)

torch.manual_seed(42)


class LLTMFunction(Function):
@staticmethod
def forward(ctx, input, weights, bias, old_h, old_cell):
outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
new_h, new_cell = outputs[:2]
variables = outputs[1:] + [weights]
ctx.save_for_backward(*variables)

return new_h, new_cell

@staticmethod
def backward(ctx, grad_h, grad_cell):
d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
grad_h, grad_cell, *ctx.saved_variables)
return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
def __init__(self, input_features, state_size):
super(LLTM, self).__init__()
@@ -41,4 +34,4 @@ def reset_parameters(self):
weight.data.uniform_(-stdv, +stdv)

def forward(self, input, state):
return LLTMFunction.apply(input, self.weights, self.bias, *state)
return torch.ops.myops.lltm(input, self.weights, self.bias, *state)
2 changes: 1 addition & 1 deletion cpp/setup.py
@@ -4,7 +4,7 @@
setup(
name='lltm_cpp',
ext_modules=[
CppExtension('lltm_cpp', ['lltm.cpp']),
CppExtension('lltm_cpp', ['lltm.cpp'], library_dirs=['/lib/x86_64-linux-gnu/'], runtime_library_dirs=['/lib/x86_64-linux-gnu/']),
],
cmdclass={
'build_ext': BuildExtension
42 changes: 20 additions & 22 deletions cuda/lltm.py
@@ -1,31 +1,29 @@
import math
import os
from torch import nn
from torch.autograd import Function
import torch

import lltm_cuda
import glob
import torch.utils.cpp_extension
import pkg_resources

# Get the location of shared library for the lltm op, and load it.
LIB_EXT = torch.utils.cpp_extension.LIB_EXT
# Note: currently there is a dependency on the CPP lib, due to the schema definition
# Eventually, this should move to use a single library registering both CPP and CUDA ops
cpp_module_path = os.path.dirname(
pkg_resources.resource_filename(
pkg_resources.Requirement.parse('lltm_cpp'), "lltm_cpp.py"))
cpp_lib_path = glob.glob(os.path.join(cpp_module_path, f"lltm_cpp*{LIB_EXT}"))[0]
torch.ops.load_library(cpp_lib_path)
cuda_module_path = os.path.dirname(
pkg_resources.resource_filename(
pkg_resources.Requirement.parse('lltm_cuda'), "lltm_cuda.py"))
cuda_lib_path = glob.glob(os.path.join(cuda_module_path, f"lltm_cuda*{LIB_EXT}"))[0]
torch.ops.load_library(cuda_lib_path)

torch.manual_seed(42)


class LLTMFunction(Function):
@staticmethod
def forward(ctx, input, weights, bias, old_h, old_cell):
outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
new_h, new_cell = outputs[:2]
variables = outputs[1:] + [weights]
ctx.save_for_backward(*variables)

return new_h, new_cell

@staticmethod
def backward(ctx, grad_h, grad_cell):
outputs = lltm_cuda.backward(
grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
def __init__(self, input_features, state_size):
super(LLTM, self).__init__()
@@ -42,4 +40,4 @@ def reset_parameters(self):
weight.data.uniform_(-stdv, +stdv)

def forward(self, input, state):
return LLTMFunction.apply(input, self.weights, self.bias, *state)
return torch.ops.myops.lltm(input, self.weights, self.bias, *state)
8 changes: 5 additions & 3 deletions cuda/lltm_cuda.cpp
@@ -35,6 +35,7 @@ std::vector<torch::Tensor> lltm_forward(
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {

CHECK_INPUT(input);
CHECK_INPUT(weights);
CHECK_INPUT(bias);
@@ -75,7 +76,8 @@ std::vector<torch::Tensor> lltm_backward(
weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
TORCH_LIBRARY_IMPL(myops, CUDA, m){
m.impl(TORCH_SELECTIVE_NAME("lltm"), TORCH_FN(lltm_forward));
m.impl(TORCH_SELECTIVE_NAME("lltm.backward"), TORCH_FN(lltm_backward));
}
