Move the glue demo to a self-contained directory.
PiperOrigin-RevId: 641953632
llcourage authored and LIT team committed Jun 10, 2024
1 parent 08289df commit fc7b0d0
Showing 24 changed files with 329 additions and 928 deletions.
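Most of the hunks below are a mechanical import-path update: the GLUE demo's dataset and model code now lives in a self-contained lit_nlp/examples/glue package. For downstream code that still uses the old paths, the rename pattern is roughly the following (a summary sketch of the change, not part of the commit itself):

# Old import paths (removed by this commit):
from lit_nlp.examples.datasets import glue
from lit_nlp.examples.models import glue_models

# New import paths:
from lit_nlp.examples.glue import data as glue_data
from lit_nlp.examples.glue import models as glue_models

Call sites that used glue.SST2Data(...) become glue_data.SST2Data(...), while glue_models.SST2Model(...) is unchanged because the new import keeps the glue_models alias.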
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -47,7 +47,9 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: python -m pip install -r requirements.txt
-- name: Test Python library
+- name: Install LIT package
+run: python -m pip install -e .
+- name: Test LIT
run: |
python -m pip install pytest
pytest -v
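The added workflow step installs LIT itself in editable mode (python -m pip install -e .) before the test step runs, presumably so that pytest exercises the installed lit_nlp package rather than depending on the repository checkout being on PYTHONPATH. Locally, the equivalent sequence would be python -m pip install -r requirements.txt, python -m pip install -e ., and then pytest -v.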
6 changes: 3 additions & 3 deletions lit_nlp/components/ablation_flip_int_test.py
@@ -19,7 +19,7 @@
from absl.testing import absltest
from lit_nlp.api import types
from lit_nlp.components import ablation_flip
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import models as glue_models
import numpy as np


@@ -66,12 +66,12 @@ def setUp(self):
self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
self.classification_config = {ablation_flip.PREDICTION_KEY: 'probas'}

-# Clasification model with the 'sentence' field marked as
+# Classification model with the 'sentence' field marked as
# non-required.
self.classification_model_non_required_field = SST2ModelNonRequiredField(
BERT_TINY_PATH)

-# Clasification model with a counter to count number of predict calls.
+# Classification model with a counter to count number of predict calls.
# TODO(ataly): Consider setting up a Mock object to count number of
# predict calls.
self.classification_model_with_predict_counter = (
2 changes: 1 addition & 1 deletion lit_nlp/components/hotflip_int_test.py
@@ -18,7 +18,7 @@
from absl.testing import parameterized
from lit_nlp.components import hotflip
# TODO(lit-dev): Move glue_models out of lit_nlp/examples
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import models as glue_models
import numpy as np

from lit_nlp.lib import file_cache
2 changes: 1 addition & 1 deletion lit_nlp/components/tcav_int_test.py
@@ -20,7 +20,7 @@
from absl.testing import parameterized
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.components import tcav
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import models as glue_models
from lit_nlp.lib import caching # for hash id fn
from lit_nlp.lib import testing_utils

2 changes: 1 addition & 1 deletion lit_nlp/components/thresholder_int_test.py
@@ -19,7 +19,7 @@
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
from lit_nlp.components import thresholder
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import models as glue_models
from lit_nlp.lib import caching # for hash id fn


16 changes: 8 additions & 8 deletions lit_nlp/examples/blank_slate_demo.py
@@ -1,4 +1,4 @@
r"""An blank demo ready to load models and datasets.
r"""A blank demo ready to load models and datasets.
The currently supported models and datasets are:
- classification model on SST-2, with the Stanford Sentiment Treebank dataset.
@@ -30,9 +30,9 @@
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.datasets import classification
-from lit_nlp.examples.datasets import glue
from lit_nlp.examples.datasets import lm
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import data as glue_data
+from lit_nlp.examples.glue import models as glue_models
from lit_nlp.examples.models import pretrained_lms
from lit_nlp.examples.penguin import data as penguin_data
from lit_nlp.examples.penguin import model as penguin_model
@@ -99,9 +99,9 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
dataset_loaders: lit_app.DatasetLoadersMap = {}

# glue demo dataset loaders.
dataset_loaders["sst2"] = (glue.SST2Data, glue.SST2Data.init_spec())
dataset_loaders["stsb"] = (glue.STSBData, glue.STSBData.init_spec())
dataset_loaders["mnli"] = (glue.MNLIData, glue.MNLIData.init_spec())
dataset_loaders["sst2"] = (glue_data.SST2Data, glue_data.SST2Data.init_spec())
dataset_loaders["stsb"] = (glue_data.STSBData, glue_data.STSBData.init_spec())
dataset_loaders["mnli"] = (glue_data.MNLIData, glue_data.MNLIData.init_spec())

# penguin demo dataset loaders.
dataset_loaders["penguin"] = (
@@ -111,8 +111,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:

# lm demo dataset loaders.
dataset_loaders["sst (lm)"] = (
-glue.SST2DataForLM,
-glue.SST2DataForLM.init_spec(),
+glue_data.SST2DataForLM,
+glue_data.SST2DataForLM.init_spec(),
)
dataset_loaders["imdb (lm)"] = (
classification.IMDBData,
6 changes: 3 additions & 3 deletions lit_nlp/examples/custom_module/potato_demo.py
@@ -23,8 +23,8 @@
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import layout
-from lit_nlp.examples.datasets import glue
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import data as glue_data
+from lit_nlp.examples.glue import models as glue_models
from lit_nlp.lib import file_cache

# NOTE: additional flags defined in server_flags.py
@@ -84,7 +84,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
model, extract_compressed_file=True)

models = {"sst": glue_models.SST2Model(model)}
datasets = {"sst_dev": glue.SST2Data("validation")}
datasets = {"sst_dev": glue_data.SST2Data("validation")}

# Start the LIT server. See server_flags.py for server options.
lit_demo = dev_server.Server(
File renamed without changes.
43 changes: 29 additions & 14 deletions lit_nlp/examples/glue_demo.py → lit_nlp/examples/glue/demo.py
@@ -1,14 +1,15 @@
r"""Example demo loading a handful of GLUE models.
For a quick-start set of models, run:
-python -m lit_nlp.examples.glue_demo \
+blaze run -c opt --config=cuda examples/glue:demo -- \
--quickstart --port=5432
To run with the 'normal' defaults, including full-size BERT models:
-python -m lit_nlp.examples.glue_demo --port=5432
+blaze run -c opt --config=cuda examples/glue:demo -- --port=5432
Then navigate to localhost:5432 to access the demo UI.
"""

from collections.abc import Sequence
import sys
from typing import Optional
@@ -19,8 +20,8 @@
from lit_nlp import app as lit_app
from lit_nlp import dev_server
from lit_nlp import server_flags
-from lit_nlp.examples.datasets import glue
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import data as glue_data
+from lit_nlp.examples.glue import models as glue_models

# NOTE: additional flags defined in server_flags.py

@@ -29,8 +30,10 @@
FLAGS.set_default("development_demo", True)

_QUICKSTART = flags.DEFINE_bool(
"quickstart", False,
"Quick-start mode, loads smaller models and a subset of the full data.")
"quickstart",
False,
"Quick-start mode, loads smaller models and a subset of the full data.",
)

_MODELS = flags.DEFINE_list(
"models",
@@ -50,9 +53,12 @@
)

_MAX_EXAMPLES = flags.DEFINE_integer(
"max_examples", None, "Maximum number of examples to load into LIT. "
"max_examples",
None,
"Maximum number of examples to load into LIT. "
"Note: MNLI eval set is 10k examples, so will take a while to run and may "
"be slow on older machines. Set --max_examples=200 for a quick start.")
"be slow on older machines. Set --max_examples=200 for a quick start.",
)

MODELS_BY_TASK = {
"sst2": glue_models.SST2Model,
@@ -123,24 +129,33 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# split = 'validation' will also work, but this will cause TDFS to download
# the entire dataset which can be very slow.
split = "https://storage.googleapis.com/what-if-tool-resources/lit-data/sst2.validation.csv"
datasets["sst_dev"] = glue.SST2Data(split)
dataset_loaders["sst2"] = (glue.SST2Data, glue.SST2Data.init_spec())
datasets["sst_dev"] = glue_data.SST2Data(split)
dataset_loaders["sst2"] = (
glue_data.SST2Data,
glue_data.SST2Data.init_spec(),
)

if "stsb" in tasks_to_load:
logging.info("Loading data for STS-B task.")
# split = 'validation' will also work, but this will cause TDFS to download
# the entire dataset which can be very slow.
split = "https://storage.googleapis.com/what-if-tool-resources/lit-data/stsb.validation.csv"
datasets["stsb_dev"] = glue.STSBData(split)
dataset_loaders["stsb"] = (glue.STSBData, glue.STSBData.init_spec())
datasets["stsb_dev"] = glue_data.STSBData(split)
dataset_loaders["stsb"] = (
glue_data.STSBData,
glue_data.STSBData.init_spec(),
)

if "mnli" in tasks_to_load:
logging.info("Loading data for MultiNLI task.")
# split = 'validation_matched' will also work, but this will cause TDFS to
# download the entire dataset which can be very slow.
split = "https://storage.googleapis.com/what-if-tool-resources/lit-data/mnli.validation_matched.csv"
datasets["mnli_dev"] = glue.MNLIData(split)
dataset_loaders["mnli"] = (glue.MNLIData, glue.MNLIData.init_spec())
datasets["mnli_dev"] = glue_data.MNLIData(split)
dataset_loaders["mnli"] = (
glue_data.MNLIData,
glue_data.MNLIData.init_spec(),
)

# Truncate datasets if --max_examples is set.
if _MAX_EXAMPLES.value is not None:
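The new docstring documents only the blaze run invocation. Assuming the moved file keeps its __main__ entry point, the plain-Python equivalent would presumably be python -m lit_nlp.examples.glue.demo --quickstart --port=5432 (module path inferred from the new location lit_nlp/examples/glue/demo.py; it is not stated in this commit).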
@@ -1,8 +1,8 @@
r"""Integration tests for lit_nlp.examples.models.glue_models.
r"""Integration tests for lit_nlp.examples.glue.models.
Test locally with:
-blaze test //third_party/py/lit_nlp/examples/models:integration_tests \
+blaze test //third_party/py/lit_nlp/examples/glue:integration_tests \
--guitar_cluster=LOCAL \
--test_output=streamed \
--guitar_detach
@@ -11,7 +11,7 @@
from typing import Any
from absl.testing import absltest
from absl.testing import parameterized
-from lit_nlp.examples.models import glue_models
+from lit_nlp.examples.glue import models as glue_models
from lit_nlp.lib import file_cache


@@ -24,7 +24,8 @@ def __init__(self, *args: Any, **kwargs: Any):
model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz" # pylint: disable=line-too-long
if model_path.endswith(".tar.gz"):
model_path = file_cache.cached_path(
-model_path, extract_compressed_file=True)
+model_path, extract_compressed_file=True
+)
self.sst2_model = glue_models.SST2Model(model_path)

@parameterized.named_parameters(
@@ -48,31 +49,29 @@ def __init__(self, *args: Any, **kwargs: Any):
# Common multiple cases
dict(
testcase_name="no_attention_or_embeddings",
-config={
-"output_attention": False,
-"output_embeddings": False
-},
+config={"output_attention": False, "output_embeddings": False},
),
dict(
testcase_name="no_attention_or_embeddings_or_gradients",
config={
"compute_grads": False,
"output_attention": False,
"output_embeddings": False
"output_embeddings": False,
},
),
)
def test_sst2_model_predict(self, config: dict[str, bool]):
# Configure model.
if config:
self.sst2_model.config = glue_models.GlueModelConfig(
-# Include the SST-2 defaut config options
+# Include the SST-2 default config options
text_a_name="sentence",
text_b_name=None,
labels=["0", "1"],
null_label_idx=0,
# Add the output-affecting config options
-**config)
+**config
+)

# Run prediction to ensure no failure.
model_in = [{"sentence": "test sentence"}]
@@ -83,5 +82,6 @@ def test_sst2_model_predict(self, config: dict[str, bool]):
for key in self.sst2_model.output_spec().keys():
self.assertIn(key, model_out[0])


if __name__ == "__main__":
absltest.main()