Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Dec 19, 2023
1 parent 4931c16 commit 4178851
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
11 changes: 8 additions & 3 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,11 @@ const phi::DDim &GetValueDims(Value value) {

pir::OpResult apply(Value self, py::object func) {
py::gil_scoped_acquire gil;
auto stop_gradient = self.attribute<BoolAttribute>(kAttrStopGradients);
if (stop_gradient && !stop_gradient.data()) {
PADDLE_THROW(phi::errors::Unavailable(
"Cannot apply function on a tensor that required gradient."));
}
PyObject *py_func = func.release().ptr();
Py_INCREF(py_func);
PyObject *res = nullptr;
Expand All @@ -592,10 +597,10 @@ pir::OpResult apply(Value self, py::object func) {
Py_DECREF(tmp_self);
} catch (std::exception &e) {
PADDLE_THROW(phi::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
"Apply function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(phi::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
"Apply function of Tensor raises an unknown exception."));
}
if (res == Py_None) {
return self.dyn_cast<OpResult>();
Expand Down Expand Up @@ -817,7 +822,7 @@ void BindValue(py::module *m) {
return paddle::dialect::scale(self, -1.0, 0.0, true);
})
.def("apply", &apply);
.def("is_same", &Value::operator==)
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("__repr__", &Value2String);
// For basaic operators
Expand Down
49 changes: 29 additions & 20 deletions test/legacy_test/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,46 @@ def test_dygraph(self):
def test_error(self):
self.x.stop_gradient = False

def fn(x):
def fn_inplace(x):
x.apply_(self.function)

self.assertRaises(RuntimeError, fn, self.x)
def fn_outplace(x, func):
x.apply(func)

def test_to_pir(self):
def fn(x):
y = x.apply(self.function)
self.assertRaises(RuntimeError, fn_inplace, self.x)
self.assertRaises(RuntimeError, fn_outplace, self.x, self.function)
with paddle.jit.api.sot_mode_guard(False):
self.assertRaises(
RuntimeError,
paddle.jit.to_static(fn_outplace),
self.x,
self.function,
)
with paddle.pir_utils.IrGuard():
self.assertRaises(
RuntimeError,
paddle.jit.to_static(fn_outplace),
self.x,
self.function,
)

def test_to_static(self):
def fn(x, func):
y = x.apply(func)
return y

with paddle.jit.api.sot_mode_guard(False):
paddle.disable_static()
jit_g = paddle.jit.to_static(fn)
out = jit_g(self.x)

np.testing.assert_allclose(
self.function(self.x).numpy(), out.numpy(), rtol=1e-05
)

def test_to_legacy_ir(self):
def fn(x):
y = x.apply(self.function)
return y

with paddle.jit.api.sot_mode_guard(False):
out_legacy_ir = jit_g(self.x, self.function)
with paddle.pir_utils.IrGuard():
paddle.disable_static()
jit_g = paddle.jit.to_static(fn)
out = jit_g(self.x)
out_pir = jit_g(self.x, self.function)
np.testing.assert_allclose(
self.function(self.x).numpy(), out_legacy_ir.numpy(), rtol=1e-05
)
np.testing.assert_allclose(
self.function(self.x).numpy(), out.numpy(), rtol=1e-05
self.function(self.x).numpy(), out_pir.numpy(), rtol=1e-05
)


Expand Down

0 comments on commit 4178851

Please sign in to comment.