Skip to content

Commit

Permalink
Serialising dynamic arrays in SQL; read-only SQLite connection in SQL…
Browse files Browse the repository at this point in the history
… Dataset

Summary:
1. We may need to store arrays of unknown shape in the database. This commit implements and tests their serialisation.

2. Previously, when a nonexistent metadata file was passed to SqlIndexDataset, it would try to open it and create an empty file, then crash. We now open the file in read-only mode, so the error message is more intuitive. Note that the implementation is SQLite specific.

Reviewed By: bottler

Differential Revision: D46047857

fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
  • Loading branch information
shapovalov authored and facebook-github-bot committed May 22, 2023
1 parent ff80183 commit d2119c2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
32 changes: 30 additions & 2 deletions pytorch3d/implicitron/dataset/orm_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,35 @@


# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape):
def ArrayTypeFactory(shape=None):
if shape is None:

class VariableShapeNumpyArrayType(TypeDecorator):
impl = LargeBinary

def process_bind_param(self, value, dialect):
    """Serialise an array of arbitrary shape into a self-describing blob.

    Blob layout, in order:
      * the number of dimensions as a single int32,
      * one int64 per dimension holding the shape,
      * the elements cast to float32 (``tobytes`` default ordering).

    Args:
        value: array (or anything ``np.asarray`` accepts, e.g. nested
            lists) to store; ``None`` is passed through so that NULL
            columns stay NULL.
        dialect: unused here.

    Returns:
        The serialised bytes, or ``None`` if ``value`` is ``None``.
    """
    if value is None:
        return None

    # Normalise the input so that plain sequences are accepted as well;
    # an ndarray input is used as is.
    array = np.asarray(value)

    ndim_bytes = np.int32(array.ndim).tobytes()
    shape_bytes = np.array(array.shape, dtype=np.int64).tobytes()
    # Data is always stored as float32, whatever the input dtype.
    value_bytes = array.astype(np.float32).tobytes()
    return ndim_bytes + shape_bytes + value_bytes

def process_result_value(self, value, dialect):
    """Deserialise a blob holding a variable-shape float32 array.

    Expected blob layout: an int32 number of dimensions, then that many
    int64 shape entries, then the float32 elements.

    Args:
        value: the stored bytes, or ``None`` for a NULL column.
        dialect: unused here.

    Returns:
        A float32 array reshaped to the stored shape, built with
        ``np.frombuffer`` over ``value``; ``None`` if ``value`` is ``None``.

    Raises:
        ValueError: if the blob is too short for its header, declares a
            negative number of dimensions, or the payload does not match
            the stored shape.
    """
    if value is None:
        return None

    ndim_size = 4  # one int32
    if len(value) < ndim_size:
        raise ValueError(
            f"Array blob too short: {len(value)} bytes, need at least {ndim_size}."
        )

    # Convert to a Python int so the offset arithmetic cannot overflow.
    ndim = int(np.frombuffer(value[:ndim_size], dtype=np.int32)[0])
    if ndim < 0:
        raise ValueError(f"Array blob declares a negative rank: {ndim}.")

    value_start = ndim_size + 8 * ndim  # shape entries are int64
    # Explicit check instead of `assert`, which is stripped under -O.
    if len(value) < value_start:
        raise ValueError(
            f"Array blob too short: {len(value)} bytes, "
            f"header of rank {ndim} needs {value_start}."
        )

    shape = np.frombuffer(value[ndim_size:value_start], dtype=np.int64)
    # reshape raises ValueError if the payload size disagrees with the shape.
    return np.frombuffer(value[value_start:], dtype=np.float32).reshape(
        tuple(shape.tolist())
    )

return VariableShapeNumpyArrayType

class NumpyArrayType(TypeDecorator):
impl = LargeBinary

Expand Down Expand Up @@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base):
mapped_column("_point_cloud_n_points", nullable=True),
)
# the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column()
6 changes: 4 additions & 2 deletions pytorch3d/implicitron/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,10 @@ def __post_init__(self) -> None:
run_auto_creation(self)
self.frame_data_builder.path_manager = self.path_manager

# pyre-ignore
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
self._sql_engine = sa.create_engine(
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
)

sequences = self._get_filtered_sequences_if_any()

Expand Down
27 changes: 26 additions & 1 deletion tests/implicitron/test_orm_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory


class TestOrmTypes(unittest.TestCase):
Expand All @@ -35,3 +35,28 @@ def test_tuple_serialization_2d(self):
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
# we use float32 to serialise
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)

def test_array_serialization_none(self):
    """None must pass through both array column types in both directions."""
    # Cover the fixed-shape type and the variable-shape type (no shape given).
    for factory_args in [((3, 3),), ()]:
        with self.subTest(factory_args=factory_args):
            ttype = ArrayTypeFactory(*factory_args)()
            output = ttype.process_bind_param(None, None)
            self.assertIsNone(output)
            output = ttype.process_result_value(output, None)
            self.assertIsNone(output)

def test_array_serialization(self):
    """Round-trip 1-D and 2-D arrays through both array column types."""
    samples = ([1, 2, 3], [[4.5, 6.7], [8.9, 10.0]])
    for sample in samples:
        original = np.array(sample)

        column_types = (
            # first, dynamic-size array
            ArrayTypeFactory()(),
            # second, fixed-size array
            ArrayTypeFactory(tuple(original.shape))(),
        )
        for column_type in column_types:
            blob = column_type.process_bind_param(original, None)
            restored = column_type.process_result_value(blob, None)
            # values are serialised as float32, hence the loose comparison
            self.assertEqual(restored.dtype, np.float32)
            np.testing.assert_almost_equal(restored, original, decimal=6)

0 comments on commit d2119c2

Please sign in to comment.