diff --git a/scaaml/capture/input_generators/attack_point_iterator.py b/scaaml/capture/input_generators/attack_point_iterator.py index ffb2c8d1..1baf6697 100644 --- a/scaaml/capture/input_generators/attack_point_iterator.py +++ b/scaaml/capture/input_generators/attack_point_iterator.py @@ -57,10 +57,6 @@ def __len__(self) -> int: def __iter__(self): """Start iterating.""" - @abstractmethod - def __next__(self) -> Dict[str, List[int]]: - """Next iterated element.""" - class AttackPointIteratorInternalConstants(AttackPointIteratorInternalBase): """Attack point iterator class that iterates over a constant.""" @@ -69,18 +65,10 @@ def __init__(self, name: str, values: List[List[int]]) -> None: """Initialize the constants to iterate.""" self._name = name self._values = values - self._index = 0 def __len__(self) -> int: return len(self._values) def __iter__(self): - self._index = 0 - return self + return iter({self._name: value} for value in self._values) - def __next__(self) -> Dict[str, List[int]]: - if self._index < self.__len__(): - self._index += 1 - return {self._name: self._values[self._index - 1]} - else: - raise StopIteration diff --git a/tests/capture/input_generators/test_attack_point_iterator.py b/tests/capture/input_generators/test_attack_point_iterator.py index 0a4a8e60..ea42779c 100644 --- a/tests/capture/input_generators/test_attack_point_iterator.py +++ b/tests/capture/input_generators/test_attack_point_iterator.py @@ -8,9 +8,7 @@ def attack_point_iterator_constants(values): input = {"operation": "constants", "name": "key", "values": values} - output = [] - for constant in AttackPointIterator(input): - output.append(constant[input["name"]]) + output = [obj['key'] for obj in list(iter(AttackPointIterator(input)))] assert output == values @@ -51,21 +49,7 @@ def test_attack_point_iterator_constant_lengths(): def repeated_iteration(config): rep_iterator = AttackPointIterator(config) - expected_elements = len(rep_iterator) - - first_iteration_elements = 0 - first_iteration = AttackPointIterator(config) - for r, f in zip(rep_iterator, first_iteration): - assert r == f - first_iteration_elements += 1 - assert first_iteration_elements == expected_elements - - second_iteration_elements = 0 - second_iteration = AttackPointIterator(config) - for r, s in zip(rep_iterator, first_iteration): - assert r == s - second_iteration_elements += 1 - assert second_iteration_elements == expected_elements + assert list(iter(rep_iterator)) == list(iter(rep_iterator)) def test_repeated_iteration_constants():