From 6ee2dcbd5e2afce35276dee7f808f132faa0553c Mon Sep 17 00:00:00 2001 From: LucaTomasko <89468158+LucaTomasko@users.noreply.github.com> Date: Thu, 25 Jan 2024 14:44:39 +0100 Subject: [PATCH] Add attack point iterator constants (#189) * Add attack point iterator constant --- scaaml/capture/input_generators/__init__.py | 3 +- .../input_generators/attack_point_iterator.py | 73 +++++++++++++++++++ .../test_attack_point_iterator.py | 59 +++++++++++++++ 3 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 scaaml/capture/input_generators/attack_point_iterator.py create mode 100644 tests/capture/input_generators/test_attack_point_iterator.py diff --git a/scaaml/capture/input_generators/__init__.py b/scaaml/capture/input_generators/__init__.py index a788731a..9e307ef7 100644 --- a/scaaml/capture/input_generators/__init__.py +++ b/scaaml/capture/input_generators/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Attack point generators""" +"""Attack point generators and iterator.""" from scaaml.capture.input_generators.input_generators import balanced_generator, single_bunch, unrestricted_generator +from scaaml.capture.input_generators.attack_point_iterator import AttackPointIterator diff --git a/scaaml/capture/input_generators/attack_point_iterator.py b/scaaml/capture/input_generators/attack_point_iterator.py new file mode 100644 index 00000000..1c4e1e97 --- /dev/null +++ b/scaaml/capture/input_generators/attack_point_iterator.py @@ -0,0 +1,73 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +An Iterator that iterates through attack points +and can be used with config files. +""" + +from abc import ABC, abstractmethod +from typing import List + + +class AttackPointIterator: + """Attack point iterator class that iterates with different configs.""" + + def __init__(self, configuration) -> None: + """Initialize a new iterator.""" + self._attack_point_iterator_internal: AttackPointIteratorInternalBase + if configuration["operation"] == "constants": + constant_iterator = AttackPointIteratorInternalConstants( + name=configuration["name"], + values=configuration["values"], + ) + self._attack_point_iterator_internal = constant_iterator + else: + raise ValueError(f"{configuration['operation']} is not supported") + + def __len__(self) -> int: + """Return the number of iterated elements. + """ + return len(self._attack_point_iterator_internal) + + def __iter__(self): + """Start iterating.""" + return iter(self._attack_point_iterator_internal) + + +class AttackPointIteratorInternalBase(ABC): + "Attack point iterator abstract class." + + @abstractmethod + def __len__(self) -> int: + """Return the number of iterated elements. + """ + + @abstractmethod + def __iter__(self): + """Start iterating.""" + + +class AttackPointIteratorInternalConstants(AttackPointIteratorInternalBase): + """Attack point iterator class that iterates over a constant.""" + + def __init__(self, name: str, values: List[List[int]]) -> None: + """Initialize the constants to iterate.""" + self._name = name + self._values = values + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self): + return iter({self._name: value} for value in self._values) diff --git a/tests/capture/input_generators/test_attack_point_iterator.py b/tests/capture/input_generators/test_attack_point_iterator.py new file mode 100644 index 00000000..ea42779c --- /dev/null +++ b/tests/capture/input_generators/test_attack_point_iterator.py @@ -0,0 +1,59 @@ +"""Test attack point iterator.""" + +import numpy as np +import pytest + +from scaaml.capture.input_generators import AttackPointIterator + + +def attack_point_iterator_constants(values): + input = {"operation": "constants", "name": "key", "values": values} + output = [obj['key'] for obj in list(iter(AttackPointIterator(input)))] + assert output == values + + +def test_attack_point_itarattor_no_legal_operation(): + values = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]] + input = {"operation": "NONE", "name": "key", "values": values} + with pytest.raises(ValueError): + AttackPointIterator(input) + + +def test_attack_point_iterator_constants(): + values = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]] + attack_point_iterator_constants(values=values) + + +def test_single_key_in_iterator_constants(): + values = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]] + input = {"operation": "constants", "name": "key", "values": values} + for constant in AttackPointIterator(input): + assert list(constant.keys()) == ["key"] + + +def test_attack_point_iterator_constants_no_values(): + input = {"operation": "constants", "name": "key"} + output = [] + with pytest.raises(KeyError): + AttackPointIterator(input) + + +def test_attack_point_iterator_constant_lengths(): + for l in range(4): + values = np.random.randint(256, size=(l, 17)) + attack_point_iterator_constants(values=values.tolist()) + + +def repeated_iteration(config): + rep_iterator = AttackPointIterator(config) + assert list(iter(rep_iterator)) == list(iter(rep_iterator)) + + +def test_repeated_iteration_constants(): + values = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]] + config = {"operation": "constants", "name": "key", "values": values} + repeated_iteration(config)