Skip to content

Commit

Permalink
feat: load model from file (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang authored Jul 26, 2023
1 parent 89c4424 commit faddb3b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
47 changes: 46 additions & 1 deletion src/vowpal_wabbit_next/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
args (List[str]): VowpalWabbit command line options for configuring the model. An overall list can be found `here <https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/command_line_args.html>`_. Options which affect the driver are not supported. For example:
`--sort_features`, `--ngram`, `--feature_limit`, `--ignore`, `--extra_metrics`, `--dump_json_weights_experimental`
model_data (Optional[bytes], optional): Bytes of a VW model to be loaded.
record_invert_hash (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
record_feature_names (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
record_metrics (bool, optional): If true, reduction metrics will be enabled and can be fetched with :py:attr:`vowpal_wabbit_next.Workspace.metrics`
enable_debug_tree (bool, optional): If true, debug information in the form of the computation tree will be emitted by :py:meth:`~vowpal_wabbit_next.learn_one`, :py:meth:`~vowpal_wabbit_next.predict_one` and :py:meth:`~vowpal_wabbit_next.predict_then_learn_one`. This will affect performance negatively. See :py:class:`~vowpal_wabbit_next.DebugNode` for more information.
Expand Down Expand Up @@ -335,6 +335,51 @@ def serialize(self) -> bytes:
"""
return self._workspace.serialize()

@staticmethod
def load_from_file(
file_path: Union[str, os.PathLike[Any]],
args: List[str] = [],
*,
record_feature_names: bool = False,
record_metrics: bool = False,
enable_debug_tree: bool = False,
) -> Workspace[Any]:
"""Load a VW model from a file.
Args:
file_path (Union[str, os.PathLike[Any]]): Path to file containing serialized model
args (List[str]): VowpalWabbit command line options for configuring the model. An overall list can be found `here <https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/command_line_args.html>`_. Options which affect the driver are not supported. For example:
`--sort_features`, `--ngram`, `--feature_limit`, `--ignore`, `--extra_metrics`, `--dump_json_weights_experimental`
record_feature_names (bool, optional): If true, the invert hash will be recorded for each example. This is required to use :py:meth:`vowpal_wabbit_next.Workspace.json_weights`. This will slow down parsing and learn/predict.
record_metrics (bool, optional): If true, reduction metrics will be enabled and can be fetched with :py:attr:`vowpal_wabbit_next.Workspace.metrics`
enable_debug_tree (bool, optional): If true, debug information in the form of the computation tree will be emitted by :py:meth:`~vowpal_wabbit_next.learn_one`, :py:meth:`~vowpal_wabbit_next.predict_one` and :py:meth:`~vowpal_wabbit_next.predict_then_learn_one`. This will affect performance negatively. See :py:class:`~vowpal_wabbit_next.DebugNode` for more information.
.. warning::
This is an experimental feature.
Returns:
Workspace[Any]: Workspace with the loaded model
"""
with open(file_path, "rb") as f:
model_data = f.read()

if enable_debug_tree:
return Workspace[Literal[True]](
args,
model_data=model_data,
record_feature_names=record_feature_names,
record_metrics=record_metrics,
enable_debug_tree=True,
)
else:
return Workspace[Literal[False]](
args,
model_data=model_data,
record_feature_names=record_feature_names,
record_metrics=record_metrics,
enable_debug_tree=False,
)

def serialize_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Serialize the current workspace as a VW model to a file."""
return self._workspace.serialize_to_file(os.fspath(file_path))
Expand Down
4 changes: 1 addition & 3 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def test_serialize_to_file_and_load() -> None:
model.serialize_to_file(model_path)

try:
with open(model_path, "rb") as f:
model2 = vw.Workspace(model_data=f.read())

model2 = vw.Workspace.load_from_file(model_path)
parser2 = vw.TextFormatParser(model)
test_example2 = parser2.parse_line(test_example_input)
pred2 = model2.predict_one(test_example2)
Expand Down

0 comments on commit faddb3b

Please sign in to comment.