Skip to content

Commit

Permalink
feat: suggest template variables when symbol is unknown (#48)
Browse files Browse the repository at this point in the history
Problem:
 When the user has a format string in their template that references a
template variable that doesn't exist (either at that scope or otherwise)
we just provide a generic "symbol unknown" style error message. We can
do better.

Solution:
 Add a recommender based on the edit distance between the unknown symbol
and the available symbols at that location. This is mostly targetting
typo errors (e.g. "Parm.Foo" instead of "Param.Foo") right now, so we
also have a threshold distance to avoid some misleading suggestions.

Result:

Given a template like:

```yaml
specificationVersion: jobtemplate-2023-09
name: DemoJob
parameterDefinitions:
- name: Foo
  type: INT
steps:
- name: DemoStep
  script:
    actions:
      onRun:
        command: echo
        args:
        - "{{Parm.Foo}}"
```

We generate the error:

```
__root__ -> steps[0] -> script -> actions -> onRun -> args[0]:
        Variable Parm.Foo does not exist at this location. Did you mean: Param.Foo
```

Signed-off-by: Daniel Neilson <53624638+ddneilson@users.noreply.github.com>
  • Loading branch information
ddneilson committed Feb 6, 2024
1 parent 6ba9a72 commit 435971a
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 3 deletions.
72 changes: 72 additions & 0 deletions src/openjd/model/_format_strings/_edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from array import array


def closest(symbols: set[str], match: str) -> tuple[int, set[str]]:
"""Return the set of symbols that most closely match the given match symbol.
Returns:
tuple[int, set[str]]
- [0]: Distance from 'match' to its closest match(es)
- [1]: Empty-set - If there is no such symbol (i.e. this table is empty)
Other -- One or more symbols that match the closest
"""
best_cost = len(match) + 1
best_match = set()
for sym in symbols:
distance = _edit_distance(sym, match)
if distance < best_cost:
best_cost = distance
best_match = set((sym,))
elif distance == best_cost:
best_match.add(sym)
return best_cost, best_match


def _edit_distance(s1: str, s2: str) -> int:
# Levenshtein distance for turning s1 into s2.
# Dynamic programming implementation storing only two rows of the DP matrix.
# Reference: https://www.codeproject.com/Articles/13525/Fast-memory-efficient-Levenshtein-algorithm-2

if len(s1) == 0:
return len(s2)
if len(s2) == 0:
return len(s1)

# Previous row of distances. Initialized to:
# a0[i] = edit distance from "" to s2[0:i]
a0 = array("L", (i for i in range(0, len(s2) + 1)))

# Current row of distances -- initialized values are irrelevant; they'll be overwritten
a1 = array("L", a0.tobytes())

for s1_idx in range(1, len(s1) + 1):
# Calculate a1 as the edit distance from s1[0:s1_idx] to s2

# a1[0] = edit distance from s1[0:s1_idx] to ""
# i.e. delete s1_idx characters from s1
a1[0] = s1_idx
for s2_idx in range(1, len(s2) + 1):
# Calculate a1[s2_idx] as the edit distance from s1[0:s1_idx] to s2[0:s2_idx]
# given:
# a0[s2_idx-1] = edit distance from s1[0:s1_idx-1] to s2[0:s2_idx-1]
# a0[s2_idx] = edit distance from s1[0:s1_idx-1] to s2[0:s2_idx]
# a1[s2_idx-1] = edit distance from s1[0:s1_idx] to s2[0:s2_idx-1]

# If we have s2[0:s2_idx] already then the step would be to delete the s1[s1_idx]
delete_cost = a0[s2_idx] + 1

# If we have s2[0:s2_idx-1] and are inserting s1[s1_idx]
insert_cost = a1[s2_idx - 1] + 1

# If we have s2[0:s2_idx-1] and are changing s1[s1_idx] in to s2[s2_idx-1]
substitution_cost = a0[s2_idx - 1] + (0 if s1[s1_idx - 1] == s2[s2_idx - 1] else 1)

# Cost of going from s2[0:s2_idx-1] to s2[0:s2_idx]
a1[s2_idx] = min(delete_cost, insert_cost, substitution_cost)

# Swap for the next iteration
a0, a1 = a1, a0

return a0[len(s2)]
17 changes: 14 additions & 3 deletions src/openjd/model/_format_strings/_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

from .._symbol_table import SymbolTable
from ._edit_distance import closest


class Node(ABC):
Expand Down Expand Up @@ -54,6 +55,11 @@ def __repr__(self) -> str: # pragma: no cover
pass


# A heuristic. Any closest match with an edit distance greater than this will
# not be returned as a closest match for error reporting purposes.
MAX_MATCH_DISTANCE_THRESHOLD = 5


@dataclass
class FullNameNode(Node):
"""Expression tree node representing a fully qualified identifier name.
Expand All @@ -64,9 +70,14 @@ class FullNameNode(Node):

def validate_symbol_refs(self, *, symbols: set[str]) -> None:
if self.name not in symbols:
raise ValueError(
f"{self.name} is referenced by an expression, but is out of scope or has no value"
)
msg = f"Variable {self.name} does not exist at this location."
distance, closest_matches = closest(symbols, self.name)
if distance < MAX_MATCH_DISTANCE_THRESHOLD:
if len(closest_matches) == 1:
msg += f" Did you mean: {''.join(closest_matches)}"
elif len(closest_matches) > 1:
msg += f" Did you mean one of: {', '.join(sorted(closest_matches))}"
raise ValueError(msg)

def evaluate(self, *, symtab: SymbolTable) -> Any:
if self.name not in symtab:
Expand Down
60 changes: 60 additions & 0 deletions test/openjd/model/format_strings/test_edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import pytest

from openjd.model._format_strings._edit_distance import _edit_distance, closest


class TestEditDistance:
@pytest.mark.parametrize(
"s1,s2,expected",
[
pytest.param("", "", 0, id="empty strings"),
pytest.param("", "a", 1, id="empty s1"),
pytest.param("a", "", 1, id="empty s2"),
pytest.param("a", "bc", 2, id="seq1"),
pytest.param("ab", "bc", 2, id="seq2"),
pytest.param("abc", "bc", 1, id="seq3/delete-start"),
pytest.param("abc", "ac", 1, id="delete inside"),
pytest.param("abc", "ab", 1, id="delete end"),
pytest.param("abc", "zabc", 1, id="insert start"),
pytest.param("abc", "azbc", 1, id="insert inside"),
pytest.param("abc", "abcz", 1, id="insert end"),
pytest.param(
"abcdefghijklmnopqrstuvwxyz", "zyxwvutsrqponmlkjihgfedcba", 26, id="reverse"
),
],
)
def test(self, s1: str, s2: str, expected: int) -> None:
# WHEN
result = _edit_distance(s1, s2)

# THEN
assert result == expected


class TestClosest:
@pytest.mark.parametrize(
"given, match, expected",
[
pytest.param(set(), "Param.Foo", (10, set()), id="no match"),
pytest.param(
set(("Param.Foo", "Param.Boo", "Param.Another")),
"Parm.Foo",
(1, set(("Param.Foo",))),
id="One close",
),
pytest.param(
set(("Param.Foo", "Param.Boo", "Param.Another")),
"Param.Zoo",
(1, set(("Param.Foo", "Param.Boo"))),
id="Two closest",
),
],
)
def test(self, given: set[str], match: str, expected: tuple[int, set[str]]) -> None:
# WHEN
result = closest(given, match)

# THEN
assert result == expected
39 changes: 39 additions & 0 deletions test/openjd/model/format_strings/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,45 @@ def test_init_reraises_tokenizer_error(self):
with pytest.raises(TokenError):
InterpolationExpression(expr)

def test_validate_success(self) -> None:
# GIVEN
symbols = set(("Test.Name",))
expr = InterpolationExpression("Test.Name")

# THEN
expr.validate_symbol_refs(symbols=symbols) # Does not raise

@pytest.mark.parametrize(
"symbols, expr, error_matches",
[
pytest.param(
set(),
"Test.Foo",
"Variable Test.Foo does not exist at this location.",
id="empty set",
),
pytest.param(
set(("Test.Foo", "Test.Boo", "Test.Another")),
"Tst.Foo",
"Variable Tst.Foo does not exist at this location. Did you mean: Test.Foo",
id="one candidate",
),
pytest.param(
set(("Test.Foo", "Test.Boo", "Test.Another")),
"Test.Zoo",
"Variable Test.Zoo does not exist at this location. Did you mean one of: Test.Boo, Test.Foo",
id="two candidates",
),
],
)
def test_validate_error(self, symbols: set[str], expr: str, error_matches: str) -> None:
# GIVEN
test = InterpolationExpression(expr)

# THEN
with pytest.raises(ValueError, match=error_matches):
test.validate_symbol_refs(symbols=symbols)

def test_evaluate_success(self):
# GIVEN
symtab = SymbolTable()
Expand Down

0 comments on commit 435971a

Please sign in to comment.