Skip to content

Commit

Permalink
Merge pull request #212 from TileDB-Inc/ss/ensure-metadata-and-proper…
Browse files Browse the repository at this point in the history
…ties-are-separate

Ensure metadata is a separate copy from model properties
  • Loading branch information
Shelnutt2 committed Sep 12, 2023
2 parents 3c42df2 + 5e1c357 commit effa1c1
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 6 deletions.
26 changes: 25 additions & 1 deletion tests/models/test_tensorflow_keras_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import tiledb
from tiledb.ml import __version__ as tiledb_ml_version
from tiledb.ml.models import SHORT_PREVIEW_LIMIT
from tiledb.ml.models.tensorflow_keras import TensorflowKerasTileDBModel

try:
Expand All @@ -34,7 +35,6 @@

# Suppress all Tensorflow messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

batch_get_value = tf.keras.backend.batch_get_value


Expand Down Expand Up @@ -453,6 +453,30 @@ def test_preview(self, tmpdir, api, loss, optimizer, metrics):


class TestTensorflowKerasModelCloud:
def test_truncated_file_property_vs_array_meta(self, tmpdir):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(10, 10)))

# create a deep model with summary 2298 long preview summary
layers_num = 15
model.add(tf.keras.layers.Dense(layers_num, activation=tf.nn.relu))

for layer_size in range(layers_num - 1, 2, -1):
model.add(tf.keras.layers.Dense(layer_size, activation=tf.nn.softmax))

# Get model summary in a string
s = io.StringIO()
model.summary(print_fn=lambda x: s.write(x + "\n"))
model_summary = s.getvalue()

uri = os.path.join(tmpdir, "model_array")
tiledb_obj = TensorflowKerasTileDBModel(uri=uri, model=model)

key = "TILEDB_ML_MODEL_PREVIEW"
assert len(tiledb_obj._file_properties[key]) == SHORT_PREVIEW_LIMIT
assert len(tiledb_obj._array_metadata[key]) == len(model_summary)
assert SHORT_PREVIEW_LIMIT != len(model_summary)

def test_file_properties(self, tmpdir):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(10, 10)))
Expand Down
1 change: 1 addition & 0 deletions tiledb/ml/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SHORT_PREVIEW_LIMIT = 2048
3 changes: 1 addition & 2 deletions tiledb/ml/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
Artifact = TypeVar("Artifact")
Meta = Mapping[str, Any]
Timestamp = Tuple[int, int]

Weights = Union[Sequence[np.ndarray], Mapping[str, Any]]


Expand Down Expand Up @@ -76,7 +75,7 @@ def __init__(
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
}
# Full/long versions of all properties
self._array_metadata = self._file_properties
self._array_metadata = self._file_properties.copy()
self._array_metadata[
ModelArrayMetadata.TILEDB_ML_MODEL_PREVIEW.value
] = self.preview()
Expand Down
3 changes: 2 additions & 1 deletion tiledb/ml/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import tiledb

from . import SHORT_PREVIEW_LIMIT
from ._base import Meta, TileDBArtifact, Timestamp


Expand Down Expand Up @@ -166,4 +167,4 @@ def preview_short(self) -> str:
:return: str. A string representation of the models internal configuration.
"""
return self.preview()[0:2048]
return self.preview()[0:SHORT_PREVIEW_LIMIT]
3 changes: 2 additions & 1 deletion tiledb/ml/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import tiledb

from . import SHORT_PREVIEW_LIMIT
from ._base import Meta, TileDBArtifact, Timestamp


Expand Down Expand Up @@ -90,7 +91,7 @@ def preview_short(self, *, display: str = "text") -> str:
displayed as text.
:return: A string representation of the models internal configuration.
"""
return self.preview(display=display)[0:2048]
return self.preview(display=display)[0:SHORT_PREVIEW_LIMIT]

def _serialize_model(self) -> bytes:
"""Serialize a Sklearn model with pickle.
Expand Down
3 changes: 2 additions & 1 deletion tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import tiledb

from . import SHORT_PREVIEW_LIMIT
from ._base import Meta, TileDBArtifact, Timestamp

keras_major, keras_minor, keras_patch = keras.__version__.split(".")
Expand Down Expand Up @@ -302,7 +303,7 @@ def preview(self) -> str:

def preview_short(self) -> str:
"""Create a string representation of the Tensorflow model that is under 2048 characters."""
return self.preview()[0:2048]
return self.preview()[0:SHORT_PREVIEW_LIMIT]

def _serialize_optimizer_weights(
self,
Expand Down

0 comments on commit effa1c1

Please sign in to comment.