Skip to content

Commit

Permalink
[squash] fix mypy and yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxrdv committed Feb 3, 2025
1 parent 75586fb commit 58de3ab
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 36 deletions.
13 changes: 10 additions & 3 deletions scaaml/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""
from scaaml.stats.ap_checker import APChecker
from scaaml.stats.ap_counter import APCounter
from scaaml.stats.cpa import CPA
from scaaml.stats.example_iterator import ExampleIterator
from scaaml.stats.print_stats import PrintStats
from scaaml.stats.trace_stddev_of_stat import STDDEVofAVGofTraces
Expand All @@ -28,7 +29,13 @@
from scaaml.stats.trace_stddev_of_stat import STDDEVofSTATofTraces

__all__ = [
"APChecker", "APCounter", "ExampleIterator", "PrintStats",
"STDDEVofAVGofTraces", "STDDEVofMAXofTraces", "STDDEVofMINofTraces",
"STDDEVofSTATofTraces"
"APChecker",
"APCounter",
"CPA",
"ExampleIterator",
"PrintStats",
"STDDEVofAVGofTraces",
"STDDEVofMAXofTraces",
"STDDEVofMINofTraces",
"STDDEVofSTATofTraces",
]
2 changes: 0 additions & 2 deletions scaaml/stats/attack_points/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,3 @@
"""Attack points.
"""
from scaaml.stats.attack_points import aes_128

__all__ = []
24 changes: 12 additions & 12 deletions scaaml/stats/attack_points/aes_128/attack_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,19 @@ class Plaintext(AttackPointAES128):
def leakage_knowing_secrets(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8],
byte_index: int) -> int:
return plaintext[byte_index]
return int(plaintext[byte_index])

@staticmethod
def leakage_from_guess(plaintext: npt.NDArray[np.uint8],
ciphertext: npt.NDArray[np.uint8], guess: int,
byte_index: int) -> int:
assert 0 <= guess < 256
return guess
return int(guess)

@staticmethod
def target_secret(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8], byte_index: int) -> int:
return plaintext[byte_index]
return int(plaintext[byte_index])


class SubBytesIn(AttackPointAES128):
Expand All @@ -117,19 +117,19 @@ class SubBytesIn(AttackPointAES128):
def leakage_knowing_secrets(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8],
byte_index: int) -> int:
return key[byte_index] ^ plaintext[byte_index]
return int(key[byte_index] ^ plaintext[byte_index])

@staticmethod
def leakage_from_guess(plaintext: npt.NDArray[np.uint8],
ciphertext: npt.NDArray[np.uint8], guess: int,
byte_index: int) -> int:
assert 0 <= guess < 256
return guess ^ plaintext[byte_index]
return int(guess ^ plaintext[byte_index])

@staticmethod
def target_secret(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8], byte_index: int) -> int:
return key[byte_index]
return int(key[byte_index])


class SubBytesOut(AttackPointAES128):
Expand All @@ -140,19 +140,19 @@ class SubBytesOut(AttackPointAES128):
def leakage_knowing_secrets(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8],
byte_index: int) -> int:
return SBOX[key[byte_index] ^ plaintext[byte_index]]
return int(SBOX[key[byte_index] ^ plaintext[byte_index]])

@staticmethod
def leakage_from_guess(plaintext: npt.NDArray[np.uint8],
ciphertext: npt.NDArray[np.uint8], guess: int,
byte_index: int) -> int:
assert 0 <= guess < 256
return SBOX[guess ^ plaintext[byte_index]]
return int(SBOX[guess ^ plaintext[byte_index]])

@staticmethod
def target_secret(key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8], byte_index: int) -> int:
return key[byte_index]
return int(key[byte_index])


class LastRoundStateDiff(AttackPointAES128):
Expand All @@ -179,7 +179,7 @@ def leakage_knowing_secrets(key: npt.NDArray[np.uint8],
st9 = SBOX_INV[ciphertext[byte_index] ^ guess]
byte_value = st9 ^ st10

return byte_value
return int(byte_value)

@staticmethod
def leakage_from_guess(plaintext: npt.NDArray[np.uint8],
Expand All @@ -193,7 +193,7 @@ def leakage_from_guess(plaintext: npt.NDArray[np.uint8],
st9 = SBOX_INV[ciphertext[byte_index] ^ guess]
byte_value = st9 ^ st10

return byte_value
return int(byte_value)

@staticmethod
def target_secret(key: npt.NDArray[np.uint8],
Expand All @@ -202,7 +202,7 @@ def target_secret(key: npt.NDArray[np.uint8],
correct_k = last_key_schedule[-4:].reshape(-1)
guess = correct_k[byte_index]

return guess
return int(guess)


class LeakageModelAES128:
Expand Down
42 changes: 23 additions & 19 deletions scaaml/stats/cpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""

import math
from typing import Optional
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from tabulate import tabulate

from scaaml.stats.attack_points.aes import LeakageModelAES128


class R:
"""Holds and updates intermediate values.
Expand All @@ -41,7 +43,7 @@ def __init__(self) -> None:
self.sum_tt: npt.NDArray[np.float64]

def update(self, trace: npt.NDArray[np.float64],
hypothesis: list[int]) -> None:
hypothesis: npt.NDArray[np.int32]) -> None:
"""Update with another trace.
Args:
Expand All @@ -51,7 +53,7 @@ def update(self, trace: npt.NDArray[np.float64],
hypothesis (list[int]): Hypothetical leakage for each possible secret
value.
"""
trace = np.array(trace)
trace = np.array(trace, dtype=np.float64)
hypothesis = np.array(hypothesis)
assert len(trace.shape) == 1
assert len(hypothesis.shape) == 1
Expand All @@ -74,7 +76,7 @@ def update(self, trace: npt.NDArray[np.float64],
self.sum_hh += hypothesis**2
self.sum_tt += trace**2

def guess(self):
def guess(self) -> npt.NDArray[np.float64]:
"""Return how much each possible guess value corresponds to the
observed values.
"""
Expand All @@ -87,7 +89,7 @@ def guess(self):
den_b = (self.sum_t**2) - (self.d * self.sum_tt) # j

r = nom / np.sqrt(np.einsum("i,j->ij", den_a, den_b))
return np.abs(r)
return np.array(np.abs(r), dtype=np.float64)


class CPA:
Expand All @@ -98,10 +100,12 @@ class CPA:
good idea to use one of the well established implementations.
"""

def __init__(self, get_model) -> None:
self.models = [
get_model(byte_index=byte_index) for byte_index in range(16)
]
def __init__( # type: ignore[no-any-unimported]
self, get_model: Callable[[int], LeakageModelAES128]) -> None:
self.models: list[ # type: ignore[no-any-unimported]
LeakageModelAES128] = [
get_model(byte_index) for byte_index in range(16)
]
self.result: dict[int, list[list[float]]] = {
# So that combined traces index from 1
i: [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] for _ in range(256)]
Expand Down Expand Up @@ -135,16 +139,16 @@ def update(self,
assert all(self.real_key == real_key)

for byte in range(16):
hypothesis = [
hypothesis: list[int] = [
self.models[byte].leakage_from_guess(
plaintext=plaintext,
ciphertext=ciphertext,
guess=i,
) for i in range(self.models[byte].different_target_secrets)
]
self.r[byte].update(
trace=trace,
hypothesis=hypothesis,
trace=trace.astype(np.float64),
hypothesis=np.array(hypothesis, dtype=np.int32),
)

res = self.r[byte].guess()
Expand All @@ -168,7 +172,7 @@ def print_predictions(self, real_key: npt.NDArray[np.uint8],
plaintext (npt.NDArray[np.uint8]): The input of AES.
"""
statistics = [["byte"], ["real"], ["guessed"], ["rank"]]
statistics: list[list[int]] = [[], [], [], []]
iteration = len(next(iter(self.result.values()))[0])
for byte in range(16):
target_value = self.models[byte].target_secret(
Expand All @@ -179,7 +183,7 @@ def print_predictions(self, real_key: npt.NDArray[np.uint8],
statistics[1].append(target_value)
res = np.max(self.r[byte].guess(), axis=1)
assert res.shape == (self.models[byte].different_target_secrets,)
statistics[2].append(np.argmax(res))
statistics[2].append(int(np.argmax(res)))
# Compute rank
statistics[3].append(int(np.sum(res >= res[target_value])))

Expand All @@ -191,13 +195,13 @@ def print_predictions(self, real_key: npt.NDArray[np.uint8],
security = math.log2(math.prod(current_ranks))
print(f"Traces: {iteration + 1} mean_rank {np.mean(current_ranks)} "
f"{security = }")
print(tabulate(statistics))
print(tabulate(statistics, headers=["byte", "real", "guessed", "rank"]))

def plot_cpa(self,
real_key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8],
experiment_name: str = "cpa.png",
logscale: bool = True) -> None:
real_key: npt.NDArray[np.uint8],
plaintext: npt.NDArray[np.uint8],
experiment_name: str = "cpa.png",
logscale: bool = True) -> None:
"""Plot how does the real secret value change position among
predictions when adding more examples.
Expand Down

0 comments on commit 58de3ab

Please sign in to comment.