diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..2b13816 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,37 @@ +name: CI + +on: + push: + branches: + - master + pull_request: {} + +jobs: + build-tiny: + runs-on: ubuntu-20.04 + strategy: + max-parallel: 4 + matrix: + python-version: [3.6, 3.7, 3.8] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install . + - name: Lint with flake8 + run: | + pip install flake8 + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pip install pytest + pytest tests diff --git a/allennlp_optuna/commands/tune.py b/allennlp_optuna/commands/tune.py index 89507d3..1f802bc 100644 --- a/allennlp_optuna/commands/tune.py +++ b/allennlp_optuna/commands/tune.py @@ -63,13 +63,13 @@ def _objective( if "pruner" in optuna_config: pruner_class = getattr(optuna.pruners, optuna_config["pruner"]["type"]) - pruner = pruner_class(**optuna_config["pruner"]["attributes"]) + pruner = pruner_class(**optuna_config["pruner"].get("attributes", {})) else: pruner = None if "sampler" in optuna_config: sampler_class = getattr(optuna.samplers, optuna_config["sampler"]["type"]) - sampler = sampler_class(optuna_config["sampler"]["attributes"]) + sampler = sampler_class(**optuna_config["sampler"].get("attributes", {})) else: sampler = None diff --git a/test_fixtures/config/classifier.jsonnet b/test_fixtures/config/classifier.jsonnet new file mode 100644 index 0000000..f76898c --- /dev/null +++ 
b/test_fixtures/config/classifier.jsonnet @@ -0,0 +1,44 @@ +local dropout = std.parseJson(std.extVar('dropout')); +local embedding_dim = std.parseInt(std.extVar('embedding_dim')); +local lr = std.parseJson(std.extVar('lr')); + + +{ + dataset_reader: { + type: 'text_classification_json', + token_indexers: { + tokens: { + type: 'single_id' + }, + }, + }, + train_data_path: 'test_fixtures/data/train.jsonl', + validation_data_path: 'test_fixtures/data/valid.jsonl', + model: { + type: 'basic_classifier', + text_field_embedder: { + token_embedders: { + tokens: { + type: 'embedding', + embedding_dim: embedding_dim, + }, + }, + }, + seq2vec_encoder: { + type: 'lstm', + input_size: embedding_dim, + hidden_size: 5, + dropout: dropout, + } + }, + data_loader: { + batch_size: 1, + }, + trainer: { + optimizer: { + type: 'adam', + lr: lr, + }, + num_epochs: 5, + }, +} diff --git a/test_fixtures/config/classifier_with_pruning.jsonnet b/test_fixtures/config/classifier_with_pruning.jsonnet new file mode 100644 index 0000000..88c74cb --- /dev/null +++ b/test_fixtures/config/classifier_with_pruning.jsonnet @@ -0,0 +1,47 @@ +local dropout = std.parseJson(std.extVar('dropout')); +local embedding_dim = std.parseInt(std.extVar('embedding_dim')); +local lr = std.parseJson(std.extVar('lr')); + + +{ + dataset_reader: { + type: 'text_classification_json', + token_indexers: { + tokens: { + type: 'single_id' + }, + }, + }, + train_data_path: 'test_fixtures/data/train.jsonl', + validation_data_path: 'test_fixtures/data/valid.jsonl', + model: { + type: 'basic_classifier', + text_field_embedder: { + token_embedders: { + tokens: { + type: 'embedding', + embedding_dim: embedding_dim, + }, + }, + }, + seq2vec_encoder: { + type: 'lstm', + input_size: embedding_dim, + hidden_size: 5, + dropout: dropout, + } + }, + data_loader: { + batch_size: 1, + }, + trainer: { + optimizer: { + type: 'adam', + lr: lr, + }, + num_epochs: 5, + callbacks: [ + { type: 'optuna_pruner' }, + ] + }, +} diff --git 
a/test_fixtures/config/hparams.json b/test_fixtures/config/hparams.json new file mode 100644 index 0000000..18909c7 --- /dev/null +++ b/test_fixtures/config/hparams.json @@ -0,0 +1,27 @@ +[ + { + "type": "int", + "attributes": { + "name": "embedding_dim", + "low": 5, + "high": 20 + } + }, + { + "type": "float", + "attributes": { + "name": "dropout", + "low": 0.0, + "high": 0.5 + } + }, + { + "type": "float", + "attributes": { + "name": "lr", + "low": 5e-3, + "high": 5e-1, + "log": true + } + } +] diff --git a/test_fixtures/config/optuna.json b/test_fixtures/config/optuna.json new file mode 100644 index 0000000..8974778 --- /dev/null +++ b/test_fixtures/config/optuna.json @@ -0,0 +1,15 @@ +{ + "pruner": { + "type": "HyperbandPruner", + "attributes": { + "min_resource": 1, + "reduction_factor": 5 + } + }, + "sampler": { + "type": "TPESampler", + "attributes": { + "n_startup_trials": 5 + } + } +} diff --git a/test_fixtures/config/optuna_without_attribute.json b/test_fixtures/config/optuna_without_attribute.json new file mode 100644 index 0000000..057fea2 --- /dev/null +++ b/test_fixtures/config/optuna_without_attribute.json @@ -0,0 +1,8 @@ +{ + "pruner": { + "type": "HyperbandPruner" + }, + "sampler": { + "type": "TPESampler" + } +} diff --git a/test_fixtures/data/train.jsonl b/test_fixtures/data/train.jsonl new file mode 100644 index 0000000..8b1797a --- /dev/null +++ b/test_fixtures/data/train.jsonl @@ -0,0 +1,4 @@ +{"text": "I like a pan.", "label": 1} +{"text": "I love a pan.", "label": 1} +{"text": "I dislike a pan.", "label": 0} +{"text": "I hate a pan.", "label": 0} diff --git a/test_fixtures/data/valid.jsonl b/test_fixtures/data/valid.jsonl new file mode 100644 index 0000000..95138e4 --- /dev/null +++ b/test_fixtures/data/valid.jsonl @@ -0,0 +1,4 @@ +{"text": "I like a pen.", "label": 1} +{"text": "I love a pen.", "label": 1} +{"text": "I dislike a pen.", "label": 0} +{"text": "I hate a pen.", "label": 0} diff --git a/tests/commands/test_tune.py 
b/tests/commands/test_tune.py new file mode 100644 index 0000000..5905f08 --- /dev/null +++ b/tests/commands/test_tune.py @@ -0,0 +1,61 @@ +import os.path +import subprocess +import tempfile + + +def test_tune(): + with tempfile.TemporaryDirectory() as tmpdir: + storage = "sqlite:///" + os.path.join(tmpdir, "allennlp_optuna.db") + command = [ + "allennlp", + "tune", + "test_fixtures/config/classifier.jsonnet", + "test_fixtures/config/hparams.json", + "--serialization-dir", + tmpdir, + "--storage", + storage, + "--n-trials", + "3", + ] + subprocess.check_call(command) + + +def test_tune_with_pruner(): + with tempfile.TemporaryDirectory() as tmpdir: + storage = "sqlite:///" + os.path.join(tmpdir, "allennlp_optuna.db") + command = [ + "allennlp", + "tune", + "test_fixtures/config/classifier_with_pruning.jsonnet", + "test_fixtures/config/hparams.json", + "--optuna-param-path", + "test_fixtures/config/optuna.json", + "--serialization-dir", + tmpdir, + "--storage", + storage, + "--n-trials", + "3", + ] + subprocess.check_call(command) + + +def test_tune_with_pruner_without_attribute(): + with tempfile.TemporaryDirectory() as tmpdir: + storage = "sqlite:///" + os.path.join(tmpdir, "allennlp_optuna.db") + command = [ + "allennlp", + "tune", + "test_fixtures/config/classifier_with_pruning.jsonnet", + "test_fixtures/config/hparams.json", + "--optuna-param-path", + "test_fixtures/config/optuna_without_attribute.json", + "--serialization-dir", + tmpdir, + "--storage", + storage, + "--n-trials", + "3", + ] + subprocess.check_call(command)