Update Python examples
mwydmuch committed Nov 4, 2020
1 parent 6ca020e commit 499a368
Showing 3 changed files with 60 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -33,7 +33,7 @@ Coming soon:

- Possibility to use any type of binary classifier from Python.
- Efficient prediction with different threshold for each label.
- Improved dataset loading from Python.
- Improved dataset loading in Python.
- More datasets from XML Repository.


28 changes: 22 additions & 6 deletions python/examples/basic.py
@@ -1,23 +1,39 @@
#!/usr/bin/env python3

# This demo shows the most basic usage of the napkinXC library.

from napkinxc.datasets import load_dataset
from napkinxc.models import PLT
from napkinxc.measures import precision_at_k

# Use load_dataset function to load one of the benchmark datasets
# from XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html)
# from XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html).
X_train, Y_train = load_dataset("eurlex-4k", "train")
X_test, Y_test = load_dataset("eurlex-4k", "test")

# Create Probabilistic Labels Tree models,
# directory "eurlex-model" will be created and used for model training and storing
# Create a Probabilistic Labels Tree model;
# the directory "eurlex-model" will be created and used during model training.
# napkinXC stores already trained parts of the model to save RAM.
# The model directory is the only required argument for model constructors.
plt = PLT("eurlex-model")

# Fit the model on the train dataset
# Fit the model on the training dataset.
# The model weights and additional data will be stored in the "eurlex-model" directory.
# The feature matrix X must be a SciPy csr_matrix, a NumPy array, or a list of tuples of (idx, value),
# while the label matrix Y should be a list of lists or tuples containing positive labels.
plt.fit(X_train, Y_train)

# Predict only the best label for each datapoint in the test dataset
# After training, the model is not loaded into RAM.
# You can preload the model into RAM before performing prediction.
plt.load()

# Predict only the best label (top-1 label) for each data point in the test dataset.
# This will also load the model if it is not loaded.
Y_pred = plt.predict(X_test, top_k=1)

# Evaluate precision at 1
# Evaluate the prediction with the precision at 1 measure.
print(precision_at_k(Y_test, Y_pred, k=1))

# Unload the model from RAM.
# You can also just delete the object if you do not need it.
plt.unload()
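
The comments in basic.py describe the accepted input formats for `fit`. The sketch below illustrates the list-based form in plain Python, assuming that each row of X is itself a list of (idx, value) tuples and each row of Y is a list of positive label indices. The helper `dense_to_sparse_rows` is hypothetical and not part of napkinXC; the sketch does not call the library.

```python
# Hypothetical helper, not part of napkinXC: turn dense feature rows into
# lists of (index, value) tuples, keeping only the non-zero features.
def dense_to_sparse_rows(dense_rows):
    return [
        [(idx, value) for idx, value in enumerate(row) if value != 0]
        for row in dense_rows
    ]


# Two data points with four features each (toy data, for illustration only).
X_dense = [
    [0.0, 1.5, 0.0, 2.0],
    [0.5, 0.0, 0.0, 0.0],
]

# Positive labels for each data point, one list per row of X.
Y = [[3, 7], [1]]

X_sparse = dense_to_sparse_rows(X_dense)
print(X_sparse)  # [[(1, 1.5), (3, 2.0)], [(0, 0.5)]]
```

Under the assumption above, lists shaped like `X_sparse` and `Y` would correspond to the list-based inputs that the comments mention for `plt.fit`.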
37 changes: 37 additions & 0 deletions python/examples/train_store_load.py
@@ -0,0 +1,37 @@
#!/usr/bin/env python3

# This demo shows how to train, store and later load the napkinXC model.

from napkinxc.datasets import load_dataset
from napkinxc.models import PLT
from napkinxc.measures import precision_at_k

# The beginning is the same as in the basic.py example.

# Use load_dataset function to load one of the benchmark datasets
# from XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html).
X_train, Y_train = load_dataset("eurlex-4k", "train")
X_test, Y_test = load_dataset("eurlex-4k", "test")

# Create a PLT model with the "eurlex-model" directory;
# it will be created and used during model training for storing weights.
# napkinXC stores already trained parts of the model to save RAM.
plt = PLT("eurlex-model")

# Fit the model on the training dataset.
# The model weights and additional data will be stored in the "eurlex-model" directory.
plt.fit(X_train, Y_train)

# Predict.
Y_pred = plt.predict(X_test, top_k=1)
print(precision_at_k(Y_test, Y_pred, k=1))

# Delete the plt object.
del plt

# To load the model, create a new PLT object with the same directory as the previous one.
new_plt = PLT("eurlex-model")

# Predict using a new model object.
Y_pred = new_plt.predict(X_test, top_k=1)
print(precision_at_k(Y_test, Y_pred, k=1))
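
Both examples report precision at 1. As a rough illustration of what that measure means, here is a minimal plain-Python sketch of the standard definition: the fraction of the top-k predicted labels that are true positives, averaged over data points. The function `precision_at_k_sketch` is a hypothetical name used only here; napkinXC's own `precision_at_k` may differ in details such as return type or input handling.

```python
# Hypothetical re-implementation for illustration; not napkinXC's code.
def precision_at_k_sketch(Y_true, Y_pred, k=1):
    total = 0.0
    for true_labels, pred_labels in zip(Y_true, Y_pred):
        true_set = set(true_labels)
        # Count how many of the first k predictions are positive labels.
        hits = sum(1 for label in pred_labels[:k] if label in true_set)
        total += hits / k
    # Average over all data points.
    return total / len(Y_true)


# Toy data: three data points, one predicted label each.
Y_true = [[1, 2], [3], [4, 5]]
Y_pred = [[1], [2], [5]]

result = precision_at_k_sketch(Y_true, Y_pred, k=1)
print(result)  # two of the three top-1 predictions are correct
```

If loading from the model directory works as the example describes, the two printed values in train_store_load.py should match, since both predictions use the same stored model.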
