Skip to content

Commit

Permalink
Merge pull request #211 from TileDB-Inc/kostastsitsimpikos/sc-33666/p…
Browse files Browse the repository at this point in the history
…review-longer-than-2048-prevents-creation

Store Long Model Previews in Array Metadata
  • Loading branch information
Shelnutt2 authored Sep 5, 2023
2 parents a40a555 + 454285a commit 3c42df2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
15 changes: 15 additions & 0 deletions tiledb/ml/models/_array_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import Enum, unique


@unique
class ModelArrayMetadata(Enum):
"""
Enum Class that contains all model array metadata.
"""

TILEDB_ML_MODEL_ML_FRAMEWORK = "TILEDB_ML_MODEL_ML_FRAMEWORK"
TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION = "TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION"
TILEDB_ML_MODEL_STAGE = "TILEDB_ML_MODEL_STAGE"
TILEDB_ML_MODEL_PYTHON_VERSION = "TILEDB_ML_MODEL_PYTHON_VERSION"
TILEDB_ML_MODEL_PREVIEW = "TILEDB_ML_MODEL_PREVIEW"
TILEDB_ML_MODEL_VERSION = "TILEDB_ML_MODEL_VERSION"
18 changes: 15 additions & 3 deletions tiledb/ml/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tiledb

from .. import __version__
from ._array_metadata import ModelArrayMetadata
from ._cloud_utils import get_cloud_uri, update_file_properties
from ._file_properties import ModelFileProperties

Expand Down Expand Up @@ -71,9 +72,14 @@ def __init__(
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION.value: self.Version,
ModelFileProperties.TILEDB_ML_MODEL_STAGE.value: "STAGING",
ModelFileProperties.TILEDB_ML_MODEL_PYTHON_VERSION.value: platform.python_version(),
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview(),
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview_short(),
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
}
# Full/long versions of all properties
self._array_metadata = self._file_properties
self._array_metadata[
ModelArrayMetadata.TILEDB_ML_MODEL_PREVIEW.value
] = self.preview()

@abstractmethod
def save(self, *, meta: Optional[Meta] = None) -> None:
Expand Down Expand Up @@ -109,6 +115,12 @@ def preview(self) -> str:
Creates a string representation of a machine learning model.
"""

@abstractmethod
def preview_short(self) -> str:
"""
Creates a string representation of a machine learning model that is under 2048 characters.
"""

def _create_array(self, fields: Sequence[str]) -> None:
"""Internal method that creates a TileDB array based on the model's spec."""

Expand Down Expand Up @@ -160,7 +172,7 @@ def _write_array(

if meta is None:
meta = {}
if not meta.keys().isdisjoint(self._file_properties.keys()):
if not meta.keys().isdisjoint(self._array_metadata.keys()):
raise ValueError(
"Please avoid using file property key names as metadata keys!"
)
Expand All @@ -185,7 +197,7 @@ def _write_array(
key: np.pad(value, (0, max_len - len(value)))
for key, value in one_d_buffers.items()
}
for mapping in meta, self._file_properties:
for mapping in meta, self._array_metadata:
for key, value in mapping.items():
model_array.meta[key] = value

Expand Down
7 changes: 7 additions & 0 deletions tiledb/ml/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,10 @@ def preview(self) -> str:
:return: str. A string representation of the models internal configuration.
"""
return str(self.artifact) if self.artifact else ""

def preview_short(self) -> str:
"""Create a string representation of the model that is under 2048 characters.
:return: str. A string representation of the models internal configuration.
"""
return self.preview()[0:2048]
10 changes: 10 additions & 0 deletions tiledb/ml/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def preview(self, *, display: str = "text") -> str:
else:
return ""

def preview_short(self, *, display: str = "text") -> str:
"""Create a text representation of the model that is under 2048 characters.
:param display: If ‘diagram’, estimators will be displayed as a diagram in an
HTML format when shown in a jupyter notebook. If ‘text’, estimators will be
displayed as text.
:return: A string representation of the models internal configuration.
"""
return self.preview(display=display)[0:2048]

def _serialize_model(self) -> bytes:
"""Serialize a Sklearn model with pickle.
Expand Down
4 changes: 4 additions & 0 deletions tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ def preview(self) -> str:
return model_summary
return ""

def preview_short(self) -> str:
"""Create a string representation of the Tensorflow model that is under 2048 characters."""
return self.preview()[0:2048]

def _serialize_optimizer_weights(
self,
) -> bytes:
Expand Down

0 comments on commit 3c42df2

Please sign in to comment.