Skip to content

Commit

Permalink
fix: add default formula to greedy filter, update tests
Browse files Browse the repository at this point in the history
Signed-off-by: Francesco Fuggitti <francesco.fuggitti@gmail.com>
  • Loading branch information
francescofuggitti committed Feb 15, 2024
1 parent ae0c243 commit 5efcd2c
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ For instance, Rasa requires a `.tar.gz` format trained model in the
- [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`)
- [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned

**NOTE**: To use OpenAI GPT models don't forget to add your `OPEN_API_KEY` in a `.env` file under your project folder.
The `.env` file should look like:
**NOTE**: To use OpenAI GPT models don't forget to add the `OPEN_API_KEY` environment
variable with:
```bash
OPENAI_API_KEY=your_api_key
export OPENAI_API_KEY=your_api_key
```

## Write your own Engine
Expand Down Expand Up @@ -119,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter)
Contributions are welcome! Here's how to set up the development environment:
- set up your preferred virtualenv environment
- clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl`
- install dependencies: `pip install -e .`
- install dev dependencies: `pip install -e ".[dev]"`
- install pre-commit: `pre-commit install`
- sign-off your commits using the `-s` flag in the commit message to be compliant with
the [DCO](https://developercertificate.org/)

## Tests

Expand Down
3 changes: 2 additions & 1 deletion nl2ltl/filters/simple_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pylogics.syntax.base import Formula

from nl2ltl.declare.base import Template
from nl2ltl.filters.base import Filter
from nl2ltl.filters.utils.conflicts import conflicts
from nl2ltl.filters.utils.subsumptions import subsumptions
Expand Down Expand Up @@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs)
"""
result_set = set()

highest_scoring_formula = max(output, key=output.get)
highest_scoring_formula = max(output, key=output.get, default=Template)
formula_conflicts = conflicts(highest_scoring_formula)
formula_subsumptions = subsumptions(highest_scoring_formula)

Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

class UtterancesFixtures:
utterances = [
"whenever I get a Slack, send a Gmail",
"Invite Sales employees to Thursday's meeting",
"If a new Eventbrite is created, alert me through Slack",
"send me a Slack whenever I get a Gmail",
"whenever I get a Slack, send a Gmail.",
"Invite Sales employees.",
"If a new Eventbrite is created, alert me through Slack.",
"send me a Slack whenever I get a Gmail.",
]
4 changes: 2 additions & 2 deletions tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from nl2ltl import translate
from nl2ltl.engines.gpt.core import GPTEngine, Models
from nl2ltl.engines.gpt.core import GPTEngine
from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter

from .conftest import UtterancesFixtures
Expand All @@ -18,7 +18,7 @@ def setup_class(cls):
"""Setup any state specific to the execution of the given class (which
usually contains tests).
"""
cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value)
cls.gpt_engine = GPTEngine()
cls.basic_filter = BasicFilter()
cls.greedy_filter = GreedyFilter()

Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ commands = ruff check .

[testenv:ruff-check-apply]
skip_install = True
deps = ruff==0.1.9r
deps = ruff==0.1.9
commands = ruff check --fix --show-fixes .

[testenv:ruff-format]
Expand Down

0 comments on commit 5efcd2c

Please sign in to comment.