Skip to content

Commit

Permalink
Change base implementation
Browse files Browse the repository at this point in the history
Issue #196
  • Loading branch information
LucaTomasko committed Jan 31, 2024
1 parent e38a0ac commit 5eca14f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 59 deletions.
2 changes: 1 addition & 1 deletion scaaml/capture/input_generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
"""Attack point generators and iterator."""

from scaaml.capture.input_generators.input_generators import balanced_generator, single_bunch, unrestricted_generator
from scaaml.capture.input_generators.attack_point_iterator import AttackPointIterator
from scaaml.capture.input_generators.attack_point_iterator import build_attack_points_iterator
97 changes: 55 additions & 42 deletions scaaml/capture/input_generators/attack_point_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,68 @@
"""

from abc import ABC, abstractmethod
from typing import List
import collections
import copy
from typing import Dict, List

from scaaml.capture.input_generators.input_generators import balanced_generator, unrestricted_generator


class AttackPointIterator:
"""Attack point iterator class that iterates with different configs."""

def __init__(self, configuration) -> None:
"""Initialize a new iterator."""
self._attack_point_iterator_internal: AttackPointIteratorInternalBase
if configuration["operation"] == "constants":
constant_iter = AttackPointIteratorInternalConstants(
name=configuration["name"],
values=configuration["values"],
)
self._attack_point_iterator_internal = constant_iter
elif configuration["operation"] == "balanced_generator":
balanced_iter = AttackPointIteratorInternalBalancedGenerator(
**configuration)
self._attack_point_iterator_internal = balanced_iter
elif configuration["operation"] == "unrestricted_generator":
unrestricted = AttackPointIteratorInternalUnrestrictedGenerator(
**configuration)
self._attack_point_iterator_internal = unrestricted
else:
raise ValueError(f"{configuration['operation']} is not supported")
class AttackPointIterator(ABC):
"Attack point iterator abstract class."

@abstractmethod
def __len__(self) -> int:
"""Return the number of iterated elements.
"""
return len(self._attack_point_iterator_internal)
"""Return the number of iterated elements."""

@abstractmethod
def __iter__(self):
"""Start iterating."""
return iter(self._attack_point_iterator_internal)


class AttackPointIteratorInternalBase(ABC):
"Attack point iterator abstract class."

@abstractmethod
def __len__(self) -> int:
"""Return the number of iterated elements.
def get_generated_keys(self) -> List[str]:
"""
Returns an exhaustive list of names this iterator
and its children will create.
"""

@abstractmethod
def __iter__(self):
"""Start iterating."""

def build_attack_points_iterator(configuration: Dict) -> AttackPointIterator:
configuration = copy.deepcopy(configuration)
iterator = _build_attack_points_iterator(configuration)

# Check that all names are unique
names_list = collections.Counter(iterator.get_generated_keys())
duplicates = [name for name, count in names_list.items() if count > 1]
if duplicates:
raise ValueError(f"Duplicated attack point names {duplicates}")

return iterator


def _build_attack_points_iterator(configuration: Dict) -> AttackPointIterator:
supported_operations = {
"constants": AttackPointIteratorConstants,
"balanced_generator": AttackPointIteratorBalancedGenerator,
"unrestricted_generator": AttackPointIteratorUnrestrictedGenerator,
# ...
}
operation = configuration["operation"]
iterator_cls = supported_operations.get(operation)

if iterator_cls is None:
raise ValueError(f"Operation {operation} not supported")

class AttackPointIteratorInternalConstants(AttackPointIteratorInternalBase):
return iterator_cls(**configuration)


class AttackPointIteratorConstants(AttackPointIterator):
"""Attack point iterator class that iterates over a constant."""

def __init__(self, name: str, values: List[List[int]]) -> None:
def __init__(self, operation: str, name: str,
values: List[List[int]]) -> None:
"""Initialize the constants to iterate."""
assert "constants" == operation
self._name = name
self._values = values

Expand All @@ -82,9 +88,11 @@ def __len__(self) -> int:
def __iter__(self):
return iter({self._name: value} for value in self._values)

def get_generated_keys(self) -> List[str]:
return [self._name]


class AttackPointIteratorInternalBalancedGenerator(
AttackPointIteratorInternalBase):
class AttackPointIteratorBalancedGenerator(AttackPointIterator):
"""
Attack point iterator class that iterates over the balanced generator.
"""
Expand Down Expand Up @@ -112,9 +120,11 @@ def __iter__(self):
bunches=self._bunches,
elements=self._elements))

def get_generated_keys(self) -> List[str]:
return [self._name]

class AttackPointIteratorInternalUnrestrictedGenerator(
AttackPointIteratorInternalBase):

class AttackPointIteratorUnrestrictedGenerator(AttackPointIterator):
"""
Attack point iterator class that iterates over the unrestricted generator.
"""
Expand All @@ -140,3 +150,6 @@ def __iter__(self):
return iter({self._name: value} for value in unrestricted_generator(
length=self._length, bunches=self._bunches,
elements=self._elements))

def get_generated_keys(self) -> List[str]:
return [self._name]
34 changes: 18 additions & 16 deletions tests/capture/input_generators/test_attack_point_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import numpy as np
import pytest

from scaaml.capture.input_generators import AttackPointIterator
from scaaml.capture.input_generators import build_attack_points_iterator


def attack_point_iterator_constants(values):
input = {"operation": "constants", "name": "key", "values": values}
output = [obj['key'] for obj in list(iter(AttackPointIterator(input)))]
output = [
obj['key'] for obj in list(iter(build_attack_points_iterator(input)))
]
assert output == values


Expand All @@ -17,7 +19,7 @@ def test_attack_point_iterator_no_legal_operation():
[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]]
input = {"operation": "NONE", "name": "key", "values": values}
with pytest.raises(ValueError):
AttackPointIterator(input)
build_attack_points_iterator(input)


def test_attack_point_iterator_constants():
Expand All @@ -30,15 +32,15 @@ def test_single_key_in_iterator_constants():
values = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]]
input = {"operation": "constants", "name": "key", "values": values}
for constant in AttackPointIterator(input):
for constant in build_attack_points_iterator(input):
assert list(constant.keys()) == ["key"]


def test_attack_point_iterator_constants_no_values():
input = {"operation": "constants", "name": "key"}
output = []
with pytest.raises(KeyError):
AttackPointIterator(input)
with pytest.raises(TypeError):
build_attack_points_iterator(input)


def test_attack_point_iterator_constant_lengths():
Expand All @@ -48,7 +50,7 @@ def test_attack_point_iterator_constant_lengths():


def repeated_iteration(config):
rep_iterator = AttackPointIterator(config)
rep_iterator = build_attack_points_iterator(config)
assert list(iter(rep_iterator)) == list(iter(rep_iterator))


Expand All @@ -61,7 +63,7 @@ def test_repeated_iteration_constants():

def test_attack_point_iterator_balanced_generator():
config = {"operation": "balanced_generator", "name": "key", "length": 16}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == 256


Expand All @@ -73,7 +75,7 @@ def test_attack_point_iterator_balanced_generator_all_kwargs():
"bunches": 2,
"elements": 3
}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == config["bunches"] * config["elements"]


Expand All @@ -83,7 +85,7 @@ def test_attack_point_iterator_unrestricted_generator():
"name": "key",
"length": 16
}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == 256


Expand All @@ -95,13 +97,13 @@ def test_attack_point_iterator_balanced_generator_all_args():
"bunches": 2,
"elements": 3
}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == config["bunches"] * config["elements"]


def test_attack_point_iterator_balanced_generator_len():
config = {"operation": "balanced_generator", "name": "key", "length": 16}
output = AttackPointIterator(config)
output = build_attack_points_iterator(config)
assert len(output) == 256


Expand All @@ -113,9 +115,9 @@ def test_attack_point_iterator_balanced_generator_all_args_len():
"bunches": 2,
"elements": 3
}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == config["bunches"] * config["elements"]
assert len(output) == len(AttackPointIterator(config))
assert len(output) == len(build_attack_points_iterator(config))


def test_attack_point_iterator_unrestricted_generator_all_args_len():
Expand All @@ -126,6 +128,6 @@ def test_attack_point_iterator_unrestricted_generator_all_args_len():
"bunches": 2,
"elements": 3
}
output = list(iter(AttackPointIterator(config)))
output = list(iter(build_attack_points_iterator(config)))
assert len(output) == config["bunches"] * config["elements"]
assert len(output) == len(AttackPointIterator(config))
assert len(output) == len(build_attack_points_iterator(config))

0 comments on commit 5eca14f

Please sign in to comment.