Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cast_from_pyobject to throw on unsupported types rather than returning null #451

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/mrc/_pymrc/include/pymrc/utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -54,6 +54,13 @@ void from_import_as(pybind11::module_& dest, const std::string& from, const std:
*/
const std::type_info* cpptype_info_from_object(pybind11::object& obj);

/**
* @brief Given a pybind11 object, return the Python type name essentially the same as `str(type(obj))`
* @param obj : pybind11 object
* @return std::string.
*/
std::string get_py_type_name(const pybind11::object& obj);

void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level = 1);

#pragma GCC visibility pop
Expand Down
50 changes: 44 additions & 6 deletions python/mrc/_pymrc/src/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -17,6 +17,9 @@

#include "pymrc/utils.hpp"

#include "pymrc/utilities/acquire_gil.hpp"

#include <glog/logging.h>
#include <nlohmann/json.hpp>
#include <pybind11/cast.h>
#include <pybind11/detail/internals.h>
Expand All @@ -25,6 +28,7 @@
#include <pyerrors.h>
#include <warnings.h>

#include <sstream>
#include <string>
#include <utility>

Expand Down Expand Up @@ -72,6 +76,18 @@
return nullptr;
}

std::string get_py_type_name(const pybind11::object& obj)
{
if (!obj)
{
// calling py::type::of on a null object will trigger an abort
return "";
}

const auto py_type = py::type::of(obj);
return py_type.attr("__name__").cast<std::string>();
}

py::object cast_from_json(const json& source)
{
if (source.is_null())
Expand Down Expand Up @@ -123,7 +139,7 @@
// throw std::runtime_error("Unsupported conversion type.");
}

json cast_from_pyobject(const py::object& source)
json cast_from_pyobject_impl(const py::object& source, const std::string& parent_path = "")
{
// Dont return via initializer list with JSON. It performs type deduction and gives different results
// NOLINTBEGIN(modernize-return-braced-init-list)
Expand All @@ -137,7 +153,9 @@
auto json_obj = json::object();
for (const auto& p : py_dict)
{
json_obj[py::cast<std::string>(p.first)] = cast_from_pyobject(p.second.cast<py::object>());
std::string key{p.first.cast<std::string>()};
std::string path{parent_path + "/" + key};
json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), path);
}

return json_obj;
Expand All @@ -148,7 +166,7 @@
auto json_arr = json::array();
for (const auto& p : py_list)
{
json_arr.push_back(cast_from_pyobject(p.cast<py::object>()));
json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), parent_path));
}

return json_arr;
Expand All @@ -170,11 +188,31 @@
return json(py::cast<std::string>(source));
}

// else unsupported return null
return json();
// else unsupported return throw a type error
{
AcquireGIL gil;
std::ostringstream error_message;
std::string path{parent_path};
if (path.empty())
{
path = "/";
}

error_message << "Object (" << py::str(source).cast<std::string>() << ") of type: " << get_py_type_name(source)
<< " at path: " << path << " is not JSON serializable";

DVLOG(5) << error_message.str();

Check warning on line 204 in python/mrc/_pymrc/src/utils.cpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/src/utils.cpp#L204

Added line #L204 was not covered by tests
throw py::type_error(error_message.str());
}

// NOLINTEND(modernize-return-braced-init-list)
}

json cast_from_pyobject(const py::object& source)
{
return cast_from_pyobject_impl(source);
}

void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level)
{
PyErr_WarnEx(PyExc_DeprecationWarning, deprecation_message.c_str(), stack_level);
Expand Down
29 changes: 28 additions & 1 deletion python/mrc/_pymrc/tests/test_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -41,6 +41,7 @@
namespace py = pybind11;
namespace pymrc = mrc::pymrc;
using namespace std::string_literals;
using namespace pybind11::literals; // to bring in the `_a` literal

// Create values too big to fit in int & float types to ensure we can pass
// long & double types to both nlohmann/json and python
Expand Down Expand Up @@ -143,6 +144,32 @@ TEST_F(TestUtils, CastFromPyObject)
}
}

TEST_F(TestUtils, CastFromPyObjectSerializeErrors)
{
// Test to verify that cast_from_pyobject throws a python TypeError when encountering something that is not json
// serializable issue #450

// decimal.Decimal is not serializable
py::object Decimal = py::module_::import("decimal").attr("Decimal");
py::object o = Decimal("1.0");
EXPECT_THROW(pymrc::cast_from_pyobject(o), py::type_error);

// Test with object in a nested dict
py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = 2);
EXPECT_THROW(pymrc::cast_from_pyobject(d), py::type_error);
}

TEST_F(TestUtils, GetTypeName)
{
// invalid objects should return an empty string
EXPECT_EQ(pymrc::get_py_type_name(py::object()), "");
EXPECT_EQ(pymrc::get_py_type_name(py::none()), "NoneType");

py::object Decimal = py::module_::import("decimal").attr("Decimal");
py::object o = Decimal("1.0");
EXPECT_EQ(pymrc::get_py_type_name(o), "Decimal");
}

TEST_F(TestUtils, PyObjectWrapper)
{
py::list test_list;
Expand Down
Loading