Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frame data precision issue fix #544

Merged
merged 6 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ htmlcov/
.coverage
.coveragerc
.tmp/
.xmake/
2 changes: 1 addition & 1 deletion maro/backends/backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ctypedef float ATTR_FLOAT
ctypedef double ATTR_DOUBLE

# Type for snapshot querying.
ctypedef float QUERY_FLOAT
ctypedef double QUERY_FLOAT

# TYPE of node and attribute
ctypedef unsigned short NODE_TYPE
Expand Down
35 changes: 33 additions & 2 deletions maro/backends/np_backend.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ attribute_type_mapping = {
AttributeType.Double: "d"
}

attribute_type_range = {
"b": ("AttributeType.Byte", -128, 127),
"B": ("AttributeType.UByte", 0, 255),
"h": ("AttributeType.Short", -32768, 32767),
"H": ("AttributeType.UShort", 0, 65535),
"i": ("AttributeType.Int", -2147483648, 2147483647),
"I": ("AttributeType.UInt", 0, 4294967295),
"q": ("AttributeType.Long", -9223372036854775808, 9223372036854775807),
"Q": ("AttributeType.ULong", 0, 18446744073709551615),
}


IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
# with this flag, we will allocate a big enough memory for all node types, then use this block construct numpy array
Expand Down Expand Up @@ -167,6 +178,13 @@ cdef class NumpyBackend(BackendAbc):

cdef AttrInfo attr = self._attrs_list[attr_type]

cdef bytes dtype = attr.dtype.encode()
if dtype in attribute_type_range:
assert value >= attribute_type_range[dtype][1] and value <= attribute_type_range[dtype][2], (
f"Value {value} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)

if attr.node_type >= len(self._nodes_list):
raise Exception("Invalid node type.")

Expand Down Expand Up @@ -208,9 +226,22 @@ cdef class NumpyBackend(BackendAbc):
cdef AttrInfo attr = self._attrs_list[attr_type]
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]

cdef bytes dtype = attr.dtype.encode()

if attr.slot_number == 1:
if dtype in attribute_type_range:
assert value[0] >= attribute_type_range[dtype][1] and value[0] <= attribute_type_range[dtype][2], (
f"Value {value[0]} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)
attr_array[0][node_index, slot_index[0]] = value[0]
else:
if dtype in attribute_type_range:
for val in value:
assert val >= attribute_type_range[dtype][1] and val <= attribute_type_range[dtype][2], (
f"Value {val} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)
attr_array[0][node_index, slot_index] = value

cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_indices) except +:
Expand Down Expand Up @@ -500,10 +531,10 @@ cdef class NPSnapshotList(SnapshotListAbc):

# since we have a clear tick to index mapping, do not need additional checking here
if tick in self._tick2index_dict:
retq.append(data_arr[attr.name][self._tick2index_dict[tick], node_index].astype("f").flatten())
retq.append(data_arr[attr.name][self._tick2index_dict[tick], node_index].astype(np.double).flatten())
else:
# padding for tick which not exist
retq.append(np.zeros(attr.slot_number, dtype='f'))
retq.append(np.zeros(attr.slot_number, dtype=np.double))

return np.concatenate(retq)

Expand Down
3 changes: 2 additions & 1 deletion maro/backends/raw/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ namespace maro

bool Attribute::is_nan() const noexcept
{
return _type == AttrDataType::AFLOAT && isnan(get_value<ATTR_FLOAT>());
return (_type == AttrDataType::AFLOAT && isnan(get_value<ATTR_FLOAT>()))
|| (_type == AttrDataType::ADOUBLE && isnan(get_value<ATTR_DOUBLE>()));
}

template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion maro/backends/raw/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace maro

using NODE_INDEX = uint32_t;
using SLOT_INDEX = uint32_t;
using QUERY_FLOAT = float;
using QUERY_FLOAT = double; // TODO: Precision issue for Long data type.

using ATTR_CHAR = char;
using ATTR_UCHAR = unsigned char;
Expand Down
5 changes: 5 additions & 0 deletions maro/backends/raw/test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# How to run

1. Install xmake according to <https://xmake.io/#/guide/installation>
2. Go to directory: maro/backends/raw
3. Run commands: `xmake; xmake run`
7 changes: 7 additions & 0 deletions maro/backends/raw/test/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <gtest/gtest.h>


int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
34 changes: 34 additions & 0 deletions maro/backends/raw/test/test_attribute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <gtest/gtest.h>

#include "../attribute.h"

using namespace maro::backends::raw;


// test attribute creation
TEST(Attribute, Creation) {
Attribute attr;

EXPECT_EQ(attr.get_type(), AttrDataType::ACHAR);
EXPECT_FALSE(attr.is_nan());
EXPECT_EQ(attr.slot_number, 0);

}

// test create attribute with other type value.
TEST(Attribute, CreateWithTypedValue) {
Attribute attr{ ATTR_UINT(12)};

EXPECT_EQ(attr.get_type(), AttrDataType::AUINT);
EXPECT_EQ(attr.get_value<ATTR_UINT>(), 12);
EXPECT_EQ(attr.slot_number, 0);
EXPECT_FALSE(attr.is_nan());
}

// test is nan case
TEST(Attribute, CreateWithNan) {
Attribute attr{ nan("nan")};

EXPECT_TRUE(attr.is_nan());
}

84 changes: 84 additions & 0 deletions maro/backends/raw/test/test_frame.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include <iomanip>
#include <array>

#include <gtest/gtest.h>

#include "../common.h"
#include "../frame.h"
#include "../snapshotlist.h"

using namespace maro::backends::raw;


TEST(test, correct) {
EXPECT_EQ(1, 1);
}

// show how to use frame and snapshot at c++ end
TEST(test, show_case) {
// a new frame
Frame frame;

// add a new node with a name
auto node_type = frame.add_node("test_node", 1);

// add an attribute to this node, this is a list attribute, it has different value to change the value
// NOTE: list means is it dynamic array, that the size can be changed even after setting up
auto attr_type_1 = frame.add_attr(node_type, "a1", AttrDataType::AUINT, 10, false, true);

// this is a normal attribute
// NOTE: list == false means it is a fixed array that cannot change the size after setting up
auto attr_type_2 = frame.add_attr(node_type, "a2", AttrDataType::AUINT, 2);

// setup means initialize the frame with node definitions (allocate memory)
// NOTE: call this method before accessing the attributes
frame.setup();

// list and normal attribute have different method to set value
frame.set_value<ATTR_UINT>(0, attr_type_2, 0, 33554441);
frame.insert_to_list<ATTR_UINT>(0, attr_type_1, 0, 33554442);

// but they have same get method
auto v1 = frame.get_value<ATTR_UINT>(0, attr_type_1, 0);
auto v2 = frame.get_value<ATTR_UINT>(0, attr_type_2, 0);

// test with true type
EXPECT_EQ(v2, 33554441);
EXPECT_EQ(v1, 33554442);

// test with query result type
EXPECT_EQ(QUERY_FLOAT(v2), 3.3554441e+07);
EXPECT_EQ(QUERY_FLOAT(v1), 3.3554442e+07);

// snapshot instance
SnapshotList ss;

// NOTE: we need following 2 method to initialize the snapshot instance, or accessing will cause exception
// which frame we will use to copy the values
ss.setup(&frame);
// max snapshot it will keep, oldeat one will be delete when reading the limitation
ss.set_max_size(10);

// take a snapshot for a tick
ss.take_snapshot(0);

// query parameters
std::array<int, 1> ticks{ 0 };
std::array<NODE_INDEX, 1> indices{ 0 };
std::array< ATTR_TYPE, 1> attributes{attr_type_1};

// we need use the parameter to get how many items we need to hold the results
auto shape = ss.prepare(node_type, &(ticks[0]), ticks.size(), &(indices[0]), indices.size(), &(attributes[0]), attributes.size());

auto total = shape.tick_number * shape.max_node_number * shape.max_slot_number * shape.attr_number;

// then query (the snapshot instance will remember the latest query parameters, so just pass the result array
QUERY_FLOAT* results = new QUERY_FLOAT[total];

ss.query(results);

// 1st slot value of first node
EXPECT_EQ(results[0], 3.3554442e+07);

delete[] results;
}
7 changes: 7 additions & 0 deletions maro/backends/raw/xmake.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_requires("gtest")

target("test")
set_kind("binary")
add_files("test/*.cpp")
add_files("./*.cpp")
add_packages("gtest")
Loading