diff --git a/README.md b/README.md index d45ebc4..f0c2a93 100644 --- a/README.md +++ b/README.md @@ -65,10 +65,10 @@ For instance, Rasa requires a `.tar.gz` format trained model in the - [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`) - [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned -**NOTE**: To use OpenAI GPT models don't forget to add your `OPEN_API_KEY` in a `.env` file under your project folder. -The `.env` file should look like: +**NOTE**: To use OpenAI GPT models don't forget to add the `OPEN_API_KEY` environment +variable with: ```bash -OPENAI_API_KEY=your_api_key +export OPENAI_API_KEY=your_api_key ``` ## Write your own Engine @@ -119,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter) Contributions are welcome! Here's how to set up the development environment: - set up your preferred virtualenv environment - clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl` +- install dependencies: `pip install -e .` - install dev dependencies: `pip install -e ".[dev]"` +- install pre-commit: `pre-commit install` +- sign-off your commits using the `-s` flag in the commit message to be compliant with +the [DCO](https://developercertificate.org/) ## Tests diff --git a/nl2ltl/filters/simple_filters.py b/nl2ltl/filters/simple_filters.py index 506b40d..65ae13c 100644 --- a/nl2ltl/filters/simple_filters.py +++ b/nl2ltl/filters/simple_filters.py @@ -3,6 +3,7 @@ from pylogics.syntax.base import Formula +from nl2ltl.declare.base import Template from nl2ltl.filters.base import Filter from nl2ltl.filters.utils.conflicts import conflicts from nl2ltl.filters.utils.subsumptions import subsumptions @@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs) """ result_set = set() - highest_scoring_formula = max(output, key=output.get) + highest_scoring_formula = max(output, key=output.get, default=Template) formula_conflicts = conflicts(highest_scoring_formula) formula_subsumptions = subsumptions(highest_scoring_formula) diff --git a/tests/conftest.py b/tests/conftest.py index 0e2ff51..ca9ca9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,8 @@ class UtterancesFixtures: utterances = [ - "whenever I get a Slack, send a Gmail", - "Invite Sales employees to Thursday's meeting", - "If a new Eventbrite is created, alert me through Slack", - "send me a Slack whenever I get a Gmail", + "whenever I get a Slack, send a Gmail.", + "Invite Sales employees.", + "If a new Eventbrite is created, alert me through Slack.", + "send me a Slack whenever I get a Gmail.", ] diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 643bfa6..841e13a 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -4,7 +4,7 @@ import pytest from nl2ltl import translate -from nl2ltl.engines.gpt.core import GPTEngine, Models +from nl2ltl.engines.gpt.core import GPTEngine from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter from .conftest import UtterancesFixtures @@ -18,7 +18,7 @@ def setup_class(cls): """Setup any state specific to the execution of the given class (which usually contains tests). """ - cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value) + cls.gpt_engine = GPTEngine() cls.basic_filter = BasicFilter() cls.greedy_filter = GreedyFilter() diff --git a/tox.ini b/tox.ini index 7c3febc..906f448 100644 --- a/tox.ini +++ b/tox.ini @@ -35,7 +35,7 @@ commands = ruff check . [testenv:ruff-check-apply] skip_install = True -deps = ruff==0.1.9r +deps = ruff==0.1.9 commands = ruff check --fix --show-fixes . [testenv:ruff-format]