From 435971ac240c5fedf1c24310e9a3f50d487abaf6 Mon Sep 17 00:00:00 2001 From: Daniel Neilson <53624638+ddneilson@users.noreply.github.com> Date: Tue, 6 Feb 2024 15:21:42 -0600 Subject: [PATCH] feat: suggest template variables when symbol is unknown (#48) Problem: When the user has a format string in their template that references a template variable that doesn't exist (either at that scope or otherwise) we just provide a generic "symbol unknown" style error message. We can do better. Solution: Add a recommender based on the edit distance between the unknown symbol and the available symbols at that location. This is mostly targetting typo errors (e.g. "Parm.Foo" instead of "Param.Foo") right now, so we also have a threshold distance to avoid some misleading suggestions. Result: Given a template like: ```yaml specificationVersion: jobtemplate-2023-09 name: DemoJob parameterDefinitions: - name: Foo type: INT steps: - name: DemoStep script: actions: onRun: command: echo args: - "{{Parm.Foo}}" ``` We generate the error: ``` __root__ -> steps[0] -> script -> actions -> onRun -> args[0]: Variable Parm.Foo does not exist at this location. Did you mean: Param.Foo ``` Signed-off-by: Daniel Neilson <53624638+ddneilson@users.noreply.github.com> --- .../model/_format_strings/_edit_distance.py | 72 +++++++++++++++++++ src/openjd/model/_format_strings/_nodes.py | 17 ++++- .../format_strings/test_edit_distance.py | 60 ++++++++++++++++ .../model/format_strings/test_expression.py | 39 ++++++++++ 4 files changed, 185 insertions(+), 3 deletions(-) create mode 100644 src/openjd/model/_format_strings/_edit_distance.py create mode 100644 test/openjd/model/format_strings/test_edit_distance.py diff --git a/src/openjd/model/_format_strings/_edit_distance.py b/src/openjd/model/_format_strings/_edit_distance.py new file mode 100644 index 0000000..6eccf4b --- /dev/null +++ b/src/openjd/model/_format_strings/_edit_distance.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from array import array + + +def closest(symbols: set[str], match: str) -> tuple[int, set[str]]: + """Return the set of symbols that most closely match the given match symbol. + + Returns: + tuple[int, set[str]] + - [0]: Distance from 'match' to its closest match(es) + - [1]: Empty-set - If there is no such symbol (i.e. this table is empty) + Other -- One or more symbols that match the closest + """ + best_cost = len(match) + 1 + best_match = set() + for sym in symbols: + distance = _edit_distance(sym, match) + if distance < best_cost: + best_cost = distance + best_match = set((sym,)) + elif distance == best_cost: + best_match.add(sym) + return best_cost, best_match + + +def _edit_distance(s1: str, s2: str) -> int: + # Levenshtein distance for turning s1 into s2. + # Dynamic programming implementation storing only two rows of the DP matrix. + # Reference: https://www.codeproject.com/Articles/13525/Fast-memory-efficient-Levenshtein-algorithm-2 + + if len(s1) == 0: + return len(s2) + if len(s2) == 0: + return len(s1) + + # Previous row of distances. Initialized to: + # a0[i] = edit distance from "" to s2[0:i] + a0 = array("L", (i for i in range(0, len(s2) + 1))) + + # Current row of distances -- initialized values are irrelevant; they'll be overwritten + a1 = array("L", a0.tobytes()) + + for s1_idx in range(1, len(s1) + 1): + # Calculate a1 as the edit distance from s1[0:s1_idx] to s2 + + # a1[0] = edit distance from s1[0:s1_idx] to "" + # i.e. delete s1_idx characters from s1 + a1[0] = s1_idx + for s2_idx in range(1, len(s2) + 1): + # Calculate a1[s2_idx] as the edit distance from s1[0:s1_idx] to s2[0:s2_idx] + # given: + # a0[s2_idx-1] = edit distance from s1[0:s1_idx-1] to s2[0:s2_idx-1] + # a0[s2_idx] = edit distance from s1[0:s1_idx-1] to s2[0:s2_idx] + # a1[s2_idx-1] = edit distance from s1[0:s1_idx] to s2[0:s2_idx-1] + + # If we have s2[0:s2_idx] already then the step would be to delete the s1[s1_idx] + delete_cost = a0[s2_idx] + 1 + + # If we have s2[0:s2_idx-1] and are inserting s1[s1_idx] + insert_cost = a1[s2_idx - 1] + 1 + + # If we have s2[0:s2_idx-1] and are changing s1[s1_idx] in to s2[s2_idx-1] + substitution_cost = a0[s2_idx - 1] + (0 if s1[s1_idx - 1] == s2[s2_idx - 1] else 1) + + # Cost of going from s2[0:s2_idx-1] to s2[0:s2_idx] + a1[s2_idx] = min(delete_cost, insert_cost, substitution_cost) + + # Swap for the next iteration + a0, a1 = a1, a0 + + return a0[len(s2)] diff --git a/src/openjd/model/_format_strings/_nodes.py b/src/openjd/model/_format_strings/_nodes.py index 29dddb6..1cc5ac9 100644 --- a/src/openjd/model/_format_strings/_nodes.py +++ b/src/openjd/model/_format_strings/_nodes.py @@ -7,6 +7,7 @@ from typing import Any from .._symbol_table import SymbolTable +from ._edit_distance import closest class Node(ABC): @@ -54,6 +55,11 @@ def __repr__(self) -> str: # pragma: no cover pass +# A heuristic. Any closest match with an edit distance greater than this will +# not be returned as a closest match for error reporting purposes. +MAX_MATCH_DISTANCE_THRESHOLD = 5 + + @dataclass class FullNameNode(Node): """Expression tree node representing a fully qualified identifier name. @@ -64,9 +70,14 @@ class FullNameNode(Node): def validate_symbol_refs(self, *, symbols: set[str]) -> None: if self.name not in symbols: - raise ValueError( - f"{self.name} is referenced by an expression, but is out of scope or has no value" - ) + msg = f"Variable {self.name} does not exist at this location." + distance, closest_matches = closest(symbols, self.name) + if distance < MAX_MATCH_DISTANCE_THRESHOLD: + if len(closest_matches) == 1: + msg += f" Did you mean: {''.join(closest_matches)}" + elif len(closest_matches) > 1: + msg += f" Did you mean one of: {', '.join(sorted(closest_matches))}" + raise ValueError(msg) def evaluate(self, *, symtab: SymbolTable) -> Any: if self.name not in symtab: diff --git a/test/openjd/model/format_strings/test_edit_distance.py b/test/openjd/model/format_strings/test_edit_distance.py new file mode 100644 index 0000000..6e54e94 --- /dev/null +++ b/test/openjd/model/format_strings/test_edit_distance.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import pytest + +from openjd.model._format_strings._edit_distance import _edit_distance, closest + + +class TestEditDistance: + @pytest.mark.parametrize( + "s1,s2,expected", + [ + pytest.param("", "", 0, id="empty strings"), + pytest.param("", "a", 1, id="empty s1"), + pytest.param("a", "", 1, id="empty s2"), + pytest.param("a", "bc", 2, id="seq1"), + pytest.param("ab", "bc", 2, id="seq2"), + pytest.param("abc", "bc", 1, id="seq3/delete-start"), + pytest.param("abc", "ac", 1, id="delete inside"), + pytest.param("abc", "ab", 1, id="delete end"), + pytest.param("abc", "zabc", 1, id="insert start"), + pytest.param("abc", "azbc", 1, id="insert inside"), + pytest.param("abc", "abcz", 1, id="insert end"), + pytest.param( + "abcdefghijklmnopqrstuvwxyz", "zyxwvutsrqponmlkjihgfedcba", 26, id="reverse" + ), + ], + ) + def test(self, s1: str, s2: str, expected: int) -> None: + # WHEN + result = _edit_distance(s1, s2) + + # THEN + assert result == expected + + +class TestClosest: + @pytest.mark.parametrize( + "given, match, expected", + [ + pytest.param(set(), "Param.Foo", (10, set()), id="no match"), + pytest.param( + set(("Param.Foo", "Param.Boo", "Param.Another")), + "Parm.Foo", + (1, set(("Param.Foo",))), + id="One close", + ), + pytest.param( + set(("Param.Foo", "Param.Boo", "Param.Another")), + "Param.Zoo", + (1, set(("Param.Foo", "Param.Boo"))), + id="Two closest", + ), + ], + ) + def test(self, given: set[str], match: str, expected: tuple[int, set[str]]) -> None: + # WHEN + result = closest(given, match) + + # THEN + assert result == expected diff --git a/test/openjd/model/format_strings/test_expression.py b/test/openjd/model/format_strings/test_expression.py index 49c2dc3..b27d5d2 100644 --- a/test/openjd/model/format_strings/test_expression.py +++ b/test/openjd/model/format_strings/test_expression.py @@ -34,6 +34,45 @@ def test_init_reraises_tokenizer_error(self): with pytest.raises(TokenError): InterpolationExpression(expr) + def test_validate_success(self) -> None: + # GIVEN + symbols = set(("Test.Name",)) + expr = InterpolationExpression("Test.Name") + + # THEN + expr.validate_symbol_refs(symbols=symbols) # Does not raise + + @pytest.mark.parametrize( + "symbols, expr, error_matches", + [ + pytest.param( + set(), + "Test.Foo", + "Variable Test.Foo does not exist at this location.", + id="empty set", + ), + pytest.param( + set(("Test.Foo", "Test.Boo", "Test.Another")), + "Tst.Foo", + "Variable Tst.Foo does not exist at this location. Did you mean: Test.Foo", + id="one candidate", + ), + pytest.param( + set(("Test.Foo", "Test.Boo", "Test.Another")), + "Test.Zoo", + "Variable Test.Zoo does not exist at this location. Did you mean one of: Test.Boo, Test.Foo", + id="two candidates", + ), + ], + ) + def test_validate_error(self, symbols: set[str], expr: str, error_matches: str) -> None: + # GIVEN + test = InterpolationExpression(expr) + + # THEN + with pytest.raises(ValueError, match=error_matches): + test.validate_symbol_refs(symbols=symbols) + def test_evaluate_success(self): # GIVEN symtab = SymbolTable()