diff --git a/python/mrc/_pymrc/include/pymrc/utils.hpp b/python/mrc/_pymrc/include/pymrc/utils.hpp index f80838c3d..fbfe2e02f 100644 --- a/python/mrc/_pymrc/include/pymrc/utils.hpp +++ b/python/mrc/_pymrc/include/pymrc/utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -54,6 +54,13 @@ void from_import_as(pybind11::module_& dest, const std::string& from, const std: */ const std::type_info* cpptype_info_from_object(pybind11::object& obj); +/** + * @brief Given a pybind11 object, return the Python type name essentially the same as `str(type(obj))` + * @param obj : pybind11 object + * @return std::string. + */ +std::string get_py_type_name(const pybind11::object& obj); + void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level = 1); #pragma GCC visibility pop diff --git a/python/mrc/_pymrc/src/utils.cpp b/python/mrc/_pymrc/src/utils.cpp index ba6a70584..02b94a269 100644 --- a/python/mrc/_pymrc/src/utils.cpp +++ b/python/mrc/_pymrc/src/utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,9 @@ #include "pymrc/utils.hpp" +#include "pymrc/utilities/acquire_gil.hpp" + +#include #include #include #include @@ -25,6 +28,7 @@ #include #include +#include #include #include @@ -72,6 +76,18 @@ const std::type_info* cpptype_info_from_object(py::object& obj) return nullptr; } +std::string get_py_type_name(const pybind11::object& obj) +{ + if (!obj) + { + // calling py::type::of on a null object will trigger an abort + return ""; + } + + const auto py_type = py::type::of(obj); + return py_type.attr("__name__").cast(); +} + py::object cast_from_json(const json& source) { if (source.is_null()) @@ -123,7 +139,7 @@ py::object cast_from_json(const json& source) // throw std::runtime_error("Unsupported conversion type."); } -json cast_from_pyobject(const py::object& source) +json cast_from_pyobject_impl(const py::object& source, const std::string& parent_path = "") { // Dont return via initializer list with JSON. It performs type deduction and gives different results // NOLINTBEGIN(modernize-return-braced-init-list) @@ -137,7 +153,9 @@ json cast_from_pyobject(const py::object& source) auto json_obj = json::object(); for (const auto& p : py_dict) { - json_obj[py::cast(p.first)] = cast_from_pyobject(p.second.cast()); + std::string key{p.first.cast()}; + std::string path{parent_path + "/" + key}; + json_obj[key] = cast_from_pyobject_impl(p.second.cast(), path); } return json_obj; @@ -148,7 +166,7 @@ json cast_from_pyobject(const py::object& source) auto json_arr = json::array(); for (const auto& p : py_list) { - json_arr.push_back(cast_from_pyobject(p.cast())); + json_arr.push_back(cast_from_pyobject_impl(p.cast(), parent_path)); } return json_arr; @@ -170,11 +188,31 @@ json cast_from_pyobject(const py::object& source) return json(py::cast(source)); } - // else unsupported return null - return json(); + // else unsupported return throw a type error + { + AcquireGIL gil; + std::ostringstream error_message; + std::string path{parent_path}; + if (path.empty()) + { + path = "/"; + } + + error_message << "Object (" << py::str(source).cast() << ") of type: " << get_py_type_name(source) + << " at path: " << path << " is not JSON serializable"; + + DVLOG(5) << error_message.str(); + throw py::type_error(error_message.str()); + } + // NOLINTEND(modernize-return-braced-init-list) } +json cast_from_pyobject(const py::object& source) +{ + return cast_from_pyobject_impl(source); +} + void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level) { PyErr_WarnEx(PyExc_DeprecationWarning, deprecation_message.c_str(), stack_level); diff --git a/python/mrc/_pymrc/tests/test_utils.cpp b/python/mrc/_pymrc/tests/test_utils.cpp index 713bdc5f4..e518bbd87 100644 --- a/python/mrc/_pymrc/tests/test_utils.cpp +++ b/python/mrc/_pymrc/tests/test_utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -41,6 +41,7 @@ namespace py = pybind11; namespace pymrc = mrc::pymrc; using namespace std::string_literals; +using namespace pybind11::literals; // to bring in the `_a` literal // Create values too big to fit in int & float types to ensure we can pass // long & double types to both nlohmann/json and python @@ -143,6 +144,32 @@ TEST_F(TestUtils, CastFromPyObject) } } +TEST_F(TestUtils, CastFromPyObjectSerializeErrors) +{ + // Test to verify that cast_from_pyobject throws a python TypeError when encountering something that is not json + // serializable issue #450 + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_THROW(pymrc::cast_from_pyobject(o), py::type_error); + + // Test with object in a nested dict + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = 2); + EXPECT_THROW(pymrc::cast_from_pyobject(d), py::type_error); +} + +TEST_F(TestUtils, GetTypeName) +{ + // invalid objects should return an empty string + EXPECT_EQ(pymrc::get_py_type_name(py::object()), ""); + EXPECT_EQ(pymrc::get_py_type_name(py::none()), "NoneType"); + + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_EQ(pymrc::get_py_type_name(o), "Decimal"); +} + TEST_F(TestUtils, PyObjectWrapper) { py::list test_list;