forked from msr-fiddle/pipedream
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpre_hook_pytorch_latest.patch
288 lines (270 loc) · 11.4 KB
/
pre_hook_pytorch_latest.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index da59648e13..f24503df16 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -69,6 +69,14 @@ class _HookMixin(object):
backward_hooks[handle.id] = hook
return backward_hooks, handle
+ @staticmethod
+ def _register_pre_hook(backward_pre_hooks, hook):
+ if backward_pre_hooks is None:
+ backward_pre_hooks = OrderedDict()
+ handle = hooks.RemovableHandle(backward_pre_hooks)
+ backward_pre_hooks[handle.id] = hook
+ return backward_pre_hooks, handle
+
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
_is_legacy = False
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp
index 58341e95e7..f78a3aa5f8 100644
--- a/torch/csrc/autograd/python_cpp_function.cpp
+++ b/torch/csrc/autograd/python_cpp_function.cpp
@@ -158,6 +158,25 @@ PyObject* THPCppFunction_name(PyObject* self) {
return THPUtils_packString(fn.name());
}
+PyObject* THPCppFunction_register_pre_hook_dict(PyObject* self, PyObject* _var)
+{
+ if (!THPVariable_Check(_var)) {
+ return PyErr_Format(PyExc_TypeError, "_register_pre_hook_dict expected a variable");
+ }
+ auto var = (THPVariable*)_var;
+ auto& fn = *((THPCppFunction*)self)->cdata;
+ std::unique_ptr<FunctionPreHook> hook(
+ new PyFunctionPreHook(var->backward_pre_hooks, var->cdata.output_nr()));
+ fn.add_pre_hook(std::move(hook));
+ Py_RETURN_NONE;
+}
+
+PyObject* THPCppFunction_register_pre_hook(PyObject* self, PyObject* hook)
+{
+ auto& fn = *((THPCppFunction*)self)->cdata;
+ return registerFunctionPreHook(fn, hook);
+}
+
static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,
{nullptr}
@@ -268,4 +287,30 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook)
return handle;
}
+PyObject* registerFunctionPreHook(Node& fn, PyObject* hook)
+{
+ PyObject* dict = Py_None;
+ for (const auto& hook : fn.pre_hooks()) {
+ if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
+ dict = pyhook->dict;
+ break;
+ }
+ }
+
+ THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, "_register_pre_hook"));
+ if (!register_fn) return nullptr;
+ THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
+ if (!res) return nullptr;
+
+ if (dict == Py_None) {
+ dict = PyTuple_GET_ITEM(res.get(), 0);
+ std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict, 0));
+ fn.add_pre_hook(std::move(hook));
+ }
+
+ PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
+ Py_INCREF(handle);
+ return handle;
+}
+
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h
index 6a4228ecd6..c4a9e05c8b 100644
--- a/torch/csrc/autograd/python_cpp_function.h
+++ b/torch/csrc/autograd/python_cpp_function.h
@@ -33,7 +33,9 @@ PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
#define THP_FUNCTION_DEFAULT_METHODS \
{(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, nullptr}, \
{(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr}, \
- {(char*)"name", (PyCFunction)THPCppFunction_name, METH_NOARGS, nullptr}
+ {(char*)"name", (PyCFunction)THPCppFunction_name, METH_NOARGS, nullptr}, \
+ {(char*)"_register_pre_hook_dict", (PyCFunction)THPCppFunction_register_pre_hook_dict, METH_O, nullptr}, \
+ {(char*)"register_pre_hook", (PyCFunction)THPCppFunction_register_pre_hook, METH_O, nullptr}
#define THP_FUNCTION_DEFAULT_PROPERTIES \
{(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \
@@ -46,11 +48,14 @@ PyObject* THPCppFunction_requires_grad(THPCppFunction* self);
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_name(PyObject* self);
+PyObject* THPCppFunction_register_pre_hook_dict(PyObject* self, PyObject* _var);
+PyObject* THPCppFunction_register_pre_hook(PyObject* self, PyObject* hook);
PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
PyGetSetDef* function_properties, PyMethodDef* function_methods);
PyObject* registerFunctionHook(Node& fn, PyObject* hook);
+PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
template<typename Ctor>
PyTypeObject* createForwardFunctionPyTypeObject(PyTypeObject& type, const char* name,
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 517bb068f7..7232ba4594 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -885,6 +885,14 @@ PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
END_HANDLE_TH_ERRORS
}
+PyObject* THPFunction_register_pre_hook(THPFunction *self, PyObject *hook)
+{
+ HANDLE_TH_ERRORS
+ auto cdata = self->cdata.lock();
+ return torch::autograd::registerFunctionPreHook(*cdata, hook);
+ END_HANDLE_TH_ERRORS
+}
+
static PyObject *unpack_saved_variables(
THPFunction *self,
const std::function<PyObject*(const Variable&)>& unpack_fn)
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 3f72bb2d53..1c07305207 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -405,6 +405,17 @@ PyObject *THPVariable_get_backwards_hooks(THPVariable *self)
END_HANDLE_TH_ERRORS
}
+PyObject *THPVariable_get_backwards_pre_hooks(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ if (self->backward_pre_hooks) {
+ Py_INCREF(self->backward_pre_hooks);
+ return self->backward_pre_hooks;
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
{
HANDLE_TH_ERRORS
@@ -423,6 +434,23 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
END_HANDLE_TH_ERRORS_RET(-1)
}
+int THPVariable_set_backwards_pre_hooks(THPVariable *self, PyObject *obj)
+{
+ HANDLE_TH_ERRORS
+ if (obj == Py_None) {
+ obj = nullptr;
+ }
+ Py_XINCREF(obj);
+ Py_XDECREF(self->backward_pre_hooks);
+ self->backward_pre_hooks = obj;
+ self->cdata.clear_hooks();
+ if (obj) {
+ self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
+ }
+ return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
+}
+
PyObject *THPVariable_get_base(THPVariable *self)
{
HANDLE_TH_ERRORS
@@ -508,6 +536,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
{"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, nullptr, nullptr},
{"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, nullptr, nullptr},
+ {"_backward_pre_hooks", (getter)THPVariable_get_backwards_pre_hooks, (setter)THPVariable_set_backwards_pre_hooks, nullptr, nullptr},
{"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h
index 708a9c4e0a..bfe25eed19 100644
--- a/torch/csrc/autograd/python_variable.h
+++ b/torch/csrc/autograd/python_variable.h
@@ -16,6 +16,7 @@ struct THPVariable {
// Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks = nullptr;
+ PyObject* backward_pre_hooks = nullptr;
};
THP_API PyObject *THPVariableClass;
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 6de8198ac4..f832c4735e 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -83,6 +83,7 @@ class Module(object):
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
+ self._backward_pre_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
@@ -431,6 +432,12 @@ class Module(object):
return self._apply(convert)
+ def reset_hooks(self):
+ self._backward_hooks = OrderedDict()
+ self._backward_pre_hooks = OrderedDict()
+ self._forward_hooks = OrderedDict()
+ self._backward_hooks = OrderedDict()
+
def register_backward_hook(self, hook):
r"""Registers a backward hook on the module.
@@ -464,6 +471,15 @@ class Module(object):
self._backward_hooks[handle.id] = hook
return handle
+ def register_backward_pre_hook(self, hook):
+ r"""Registers a backward pre-hook on the module.
+
+ The hook will be called every time before :func:`backward` is invoked.
+ """
+ handle = hooks.RemovableHandle(self._backward_pre_hooks)
+ self._backward_pre_hooks[handle.id] = hook
+ return handle
+
def register_forward_pre_hook(self, hook):
r"""Registers a forward pre-hook on the module.
@@ -549,6 +565,24 @@ class Module(object):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
+ if len(self._backward_pre_hooks) > 0:
+ var = result
+ while not isinstance(var, torch.Tensor):
+ if isinstance(var, dict):
+ var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
+ else:
+ var = var[0]
+ grad_fn = var.grad_fn
+ if grad_fn is not None:
+ for hook in self._backward_pre_hooks.values():
+ wrapper = functools.partial(hook, self)
+ functools.update_wrapper(wrapper, hook)
+ try:
+ grad_fn.register_pre_hook(wrapper)
+ except Exception as e:
+ print("Error in registering pre-hook")
+ print("Error: %s" % e)
+ continue
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
diff --git a/torch/tensor.py b/torch/tensor.py
index 016a11cccb..79d49b46bb 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -158,6 +158,22 @@ class Tensor(torch._C._TensorBase):
self._backward_hooks[handle.id] = hook
return handle
+ def register_pre_hook(self, hook):
+ r"""Registers a backward pre-hook.
+
+ The hook will be called every time a backward pass is invoked.
+ """
+ if not self.requires_grad:
+ raise RuntimeError("cannot register a pre_hook on a tensor that "
+ "doesn't require gradient")
+ if self._backward_pre_hooks is None:
+ self._backward_pre_hooks = OrderedDict()
+ if self.grad_fn is not None:
+ self.grad_fn._register_pre_hook_dict(self)
+ handle = hooks.RemovableHandle(self._backward_pre_hooks)
+ self._backward_pre_hooks[handle.id] = hook
+ return handle
+
def reinforce(self, reward):
def trim(str):
return '\n'.join([line.strip() for line in str.split('\n')])