Skip to content

Commit

Permalink
frame data precision issue fix (#544)
Browse files Browse the repository at this point in the history
* fix frame precision issue

* add .xmake to .gitignore

* update frame precision lost warning message

* add assert to frame precision checking

* typo fix

* add TODO for future Long data type issue fix
  • Loading branch information
Jinyu-W authored Jun 9, 2022
1 parent 1b6e370 commit ed15f69
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ htmlcov/
.coverage
.coveragerc
.tmp/
.xmake/
2 changes: 1 addition & 1 deletion maro/backends/backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ctypedef float ATTR_FLOAT
ctypedef double ATTR_DOUBLE

# Type for snapshot querying.
ctypedef float QUERY_FLOAT
ctypedef double QUERY_FLOAT

# TYPE of node and attribute
ctypedef unsigned short NODE_TYPE
Expand Down
35 changes: 33 additions & 2 deletions maro/backends/np_backend.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ attribute_type_mapping = {
AttributeType.Double: "d"
}

attribute_type_range = {
"b": ("AttributeType.Byte", -128, 127),
"B": ("AttributeType.UByte", 0, 255),
"h": ("AttributeType.Short", -32768, 32767),
"H": ("AttributeType.UShort", 0, 65535),
"i": ("AttributeType.Int", -2147483648, 2147483647),
"I": ("AttributeType.UInt", 0, 4294967295),
"q": ("AttributeType.Long", -9223372036854775808, 9223372036854775807),
"Q": ("AttributeType.ULong", 0, 18446744073709551615),
}


IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
# with this flag, we will allocate a big enough memory for all node types, then use this block construct numpy array
Expand Down Expand Up @@ -167,6 +178,13 @@ cdef class NumpyBackend(BackendAbc):

cdef AttrInfo attr = self._attrs_list[attr_type]

cdef bytes dtype = attr.dtype.encode()
if dtype in attribute_type_range:
assert value >= attribute_type_range[dtype][1] and value <= attribute_type_range[dtype][2], (
f"Value {value} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)

if attr.node_type >= len(self._nodes_list):
raise Exception("Invalid node type.")

Expand Down Expand Up @@ -208,9 +226,22 @@ cdef class NumpyBackend(BackendAbc):
cdef AttrInfo attr = self._attrs_list[attr_type]
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]

cdef bytes dtype = attr.dtype.encode()

if attr.slot_number == 1:
if dtype in attribute_type_range:
assert value[0] >= attribute_type_range[dtype][1] and value[0] <= attribute_type_range[dtype][2], (
f"Value {value[0]} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)
attr_array[0][node_index, slot_index[0]] = value[0]
else:
if dtype in attribute_type_range:
for val in value:
assert val >= attribute_type_range[dtype][1] and val <= attribute_type_range[dtype][2], (
f"Value {val} out of range ({attribute_type_range[dtype][0]}: "
f"[{attribute_type_range[dtype][1]}, {attribute_type_range[dtype][2]}])"
)
attr_array[0][node_index, slot_index] = value

cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_indices) except +:
Expand Down Expand Up @@ -500,10 +531,10 @@ cdef class NPSnapshotList(SnapshotListAbc):

# since we have a clear tick to index mapping, do not need additional checking here
if tick in self._tick2index_dict:
retq.append(data_arr[attr.name][self._tick2index_dict[tick], node_index].astype("f").flatten())
retq.append(data_arr[attr.name][self._tick2index_dict[tick], node_index].astype(np.double).flatten())
else:
# padding for tick which not exist
retq.append(np.zeros(attr.slot_number, dtype='f'))
retq.append(np.zeros(attr.slot_number, dtype=np.double))

return np.concatenate(retq)

Expand Down
3 changes: 2 additions & 1 deletion maro/backends/raw/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ namespace maro

bool Attribute::is_nan() const noexcept
{
return _type == AttrDataType::AFLOAT && isnan(get_value<ATTR_FLOAT>());
return (_type == AttrDataType::AFLOAT && isnan(get_value<ATTR_FLOAT>()))
|| (_type == AttrDataType::ADOUBLE && isnan(get_value<ATTR_DOUBLE>()));
}

template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion maro/backends/raw/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace maro

using NODE_INDEX = uint32_t;
using SLOT_INDEX = uint32_t;
using QUERY_FLOAT = float;
using QUERY_FLOAT = double; // TODO: Precision issue for Long data type.

using ATTR_CHAR = char;
using ATTR_UCHAR = unsigned char;
Expand Down
5 changes: 5 additions & 0 deletions maro/backends/raw/test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# How to run

1. Install xmake according to <https://xmake.io/#/guide/installation>
2. Go to directory: maro/backends/raw
3. Run commands: `xmake; xmake run`
7 changes: 7 additions & 0 deletions maro/backends/raw/test/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <gtest/gtest.h>


int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
34 changes: 34 additions & 0 deletions maro/backends/raw/test/test_attribute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <gtest/gtest.h>

#include "../attribute.h"

using namespace maro::backends::raw;


// test attribute creation
TEST(Attribute, Creation) {
Attribute attr;

EXPECT_EQ(attr.get_type(), AttrDataType::ACHAR);
EXPECT_FALSE(attr.is_nan());
EXPECT_EQ(attr.slot_number, 0);

}

// test create attribute with other type value.
TEST(Attribute, CreateWithTypedValue) {
Attribute attr{ ATTR_UINT(12)};

EXPECT_EQ(attr.get_type(), AttrDataType::AUINT);
EXPECT_EQ(attr.get_value<ATTR_UINT>(), 12);
EXPECT_EQ(attr.slot_number, 0);
EXPECT_FALSE(attr.is_nan());
}

// test is nan case
TEST(Attribute, CreateWithNan) {
Attribute attr{ nan("nan")};

EXPECT_TRUE(attr.is_nan());
}

84 changes: 84 additions & 0 deletions maro/backends/raw/test/test_frame.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include <iomanip>
#include <array>

#include <gtest/gtest.h>

#include "../common.h"
#include "../frame.h"
#include "../snapshotlist.h"

using namespace maro::backends::raw;


TEST(test, correct) {
EXPECT_EQ(1, 1);
}

// show how to use frame and snapshot at c++ end
TEST(test, show_case) {
// a new frame
Frame frame;

// add a new node with a name
auto node_type = frame.add_node("test_node", 1);

// add an attribute to this node, this is a list attribute, it has different value to change the value
// NOTE: list means is it dynamic array, that the size can be changed even after setting up
auto attr_type_1 = frame.add_attr(node_type, "a1", AttrDataType::AUINT, 10, false, true);

// this is a normal attribute
// NOTE: list == false means it is a fixed array that cannot change the size after setting up
auto attr_type_2 = frame.add_attr(node_type, "a2", AttrDataType::AUINT, 2);

// setup means initialize the frame with node definitions (allocate memory)
// NOTE: call this method before accessing the attributes
frame.setup();

// list and normal attribute have different method to set value
frame.set_value<ATTR_UINT>(0, attr_type_2, 0, 33554441);
frame.insert_to_list<ATTR_UINT>(0, attr_type_1, 0, 33554442);

// but they have same get method
auto v1 = frame.get_value<ATTR_UINT>(0, attr_type_1, 0);
auto v2 = frame.get_value<ATTR_UINT>(0, attr_type_2, 0);

// test with true type
EXPECT_EQ(v2, 33554441);
EXPECT_EQ(v1, 33554442);

// test with query result type
EXPECT_EQ(QUERY_FLOAT(v2), 3.3554441e+07);
EXPECT_EQ(QUERY_FLOAT(v1), 3.3554442e+07);

// snapshot instance
SnapshotList ss;

// NOTE: we need following 2 method to initialize the snapshot instance, or accessing will cause exception
// which frame we will use to copy the values
ss.setup(&frame);
// max snapshot it will keep, oldeat one will be delete when reading the limitation
ss.set_max_size(10);

// take a snapshot for a tick
ss.take_snapshot(0);

// query parameters
std::array<int, 1> ticks{ 0 };
std::array<NODE_INDEX, 1> indices{ 0 };
std::array< ATTR_TYPE, 1> attributes{attr_type_1};

// we need use the parameter to get how many items we need to hold the results
auto shape = ss.prepare(node_type, &(ticks[0]), ticks.size(), &(indices[0]), indices.size(), &(attributes[0]), attributes.size());

auto total = shape.tick_number * shape.max_node_number * shape.max_slot_number * shape.attr_number;

// then query (the snapshot instance will remember the latest query parameters, so just pass the result array
QUERY_FLOAT* results = new QUERY_FLOAT[total];

ss.query(results);

// 1st slot value of first node
EXPECT_EQ(results[0], 3.3554442e+07);

delete[] results;
}
7 changes: 7 additions & 0 deletions maro/backends/raw/xmake.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_requires("gtest")

target("test")
set_kind("binary")
add_files("test/*.cpp")
add_files("./*.cpp")
add_packages("gtest")
Loading

0 comments on commit ed15f69

Please sign in to comment.