Skip to content

Commit

Permalink
feat(vhdl): implement clipped fixed point representation
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed Sep 29, 2022
1 parent 393d96b commit 8e53506
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 61 deletions.
58 changes: 58 additions & 0 deletions elasticai/creator/tests/vhdl/test_number_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import TestCase

from elasticai.creator.vhdl.number_representations import (
ClippedFixedPoint,
FixedPoint,
ToLogicEncoder,
float_values_to_fixed_point,
Expand Down Expand Up @@ -239,6 +240,63 @@ def test_to_signed_int_negative_value(self):
self.assertEqual(-80, fp.to_signed_int())


class ClippedFixedPointTest(TestCase):
def test_conversion_value_in_bounds(self) -> None:
fp = ClippedFixedPoint(5.251, total_bits=8, frac_bits=4)
self.assertEqual(84, int(fp))
self.assertEqual(5.25, float(fp))

def test_conversion_value_out_of_lower_bound(self) -> None:
fp = ClippedFixedPoint(-9.25, total_bits=8, frac_bits=4)
self.assertEqual(128, int(fp))
self.assertEqual(-8, float(fp))

def test_conversion_value_out_of_upper_bound(self) -> None:
fp = ClippedFixedPoint(8, total_bits=8, frac_bits=4)
self.assertEqual(127, int(fp))
self.assertEqual(7.9375, float(fp))

def test_repr_value_out_of_bounds(self) -> None:
fp = ClippedFixedPoint(10, total_bits=8, frac_bits=4)
self.assertEqual(
"ClippedFixedPoint(value=7.9375, total_bits=8, frac_bits=4)", repr(fp)
)

def test_repr_value_in_bounds(self) -> None:
fp = ClippedFixedPoint(5.21, total_bits=8, frac_bits=4)
self.assertEqual(
"ClippedFixedPoint(value=5.21, total_bits=8, frac_bits=4)", repr(fp)
)

def test_from_unsigned_int_value_in_bounds(self) -> None:
fp = ClippedFixedPoint.from_unsigned_int(62, total_bits=8, frac_bits=4)
self.assertEqual(62, int(fp))
self.assertEqual(3.875, float(fp))

def test_from_unsigned_int_value_out_of_bounds(self) -> None:
fp = ClippedFixedPoint.from_unsigned_int(830, total_bits=8, frac_bits=4)
self.assertEqual(62, int(fp))
self.assertEqual(3.875, float(fp))

def test_from_signed_int_value_in_bounds(self) -> None:
fp = ClippedFixedPoint.from_signed_int(-100, total_bits=8, frac_bits=4)
self.assertEqual(156, int(fp))
self.assertEqual(-6.25, float(fp))

def test_from_signed_int_value_out_of_bounds(self) -> None:
fp = ClippedFixedPoint.from_signed_int(-255, total_bits=8, frac_bits=4)
self.assertEqual(128, int(fp))
self.assertEqual(-8, float(fp))

def test_get_factory(self) -> None:
factory = ClippedFixedPoint.get_factory(total_bits=8, frac_bits=4)
fp = factory(1)
self.assertEqual(ClippedFixedPoint, type(fp))
self.assertEqual(8, fp.total_bits)
self.assertEqual(4, fp.frac_bits)
self.assertEqual(1, float(fp))


class InferTotalAndFracBits(TestCase):
def test_infer_empty_list(self):
with self.assertRaises(ValueError):
Expand Down
172 changes: 111 additions & 61 deletions elasticai/creator/vhdl/number_representations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
import math
from collections.abc import Sequence
from collections.abc import Iterable, Iterator, Sequence
from functools import partial
from itertools import chain
from typing import Any, Callable, Iterable, Iterator
from typing import Any, Callable


def _assert_range(value: float, total_bits: int, frac_bits: int) -> None:
max_value = 2 ** (total_bits - frac_bits - 1)
min_value = max_value * (-1)
if not min_value <= value < max_value:
raise ValueError(
(
f"Value {value} cannot represented as a fixed point value with {total_bits} total bits "
f"and {frac_bits} fraction bits (value range: [{min_value}, {max_value}))."
)
)


def _assert_is_compatible(fp1: "FixedPoint", fp2: "FixedPoint") -> None:
if not (fp1.total_bits == fp2.total_bits and fp1.frac_bits == fp2.frac_bits):
raise ValueError(
(
f"FixedPoint objects not compatible (total_bits: {fp1.total_bits} != {fp2.total_bits}); "
f"frac_bits: {fp1.frac_bits} != {fp2.frac_bits})."
)
)


def _invert_int(value: int, num_bits: int) -> int:
return value ^ int("1" * num_bits, 2)


def _discard_leading_bits(value: int, num_bits: int) -> int:
return value & int("1" * num_bits, 2)


def _calculate_two_complement(value: int, num_bits: int) -> int:
return _invert_int(abs(value), num_bits) + 1


class FixedPoint:
"""
A data type that converts a given number to the corresponding fixed-point representation.
A fixed-point value is an unsigned integer in two's complement.
Parameters:
value (float | int): Value to be represented as fixed-point value.
value (float): Value to be represented as fixed-point value.
total_bits (int): Total number of bits of the fixed-point representation (including number of fractional bits).
frac_bits (int): Number of bits to represent the fractional part of the number.
Examples:
Expand All @@ -29,21 +63,16 @@ class FixedPoint:

__slots__ = ["_value", "_frac_bits", "_total_bits"]

def __init__(
self,
value: float | int,
total_bits: int,
frac_bits: int,
) -> None:
def __init__(self, value: float, total_bits: int, frac_bits: int) -> None:
self._value = float(value)
self._total_bits = total_bits
self._frac_bits = frac_bits
FixedPoint._assert_range(self._value, self._total_bits, self._frac_bits)
_assert_range(self._value, self._total_bits, self._frac_bits)

def __int__(self) -> int:
fp_int = int(self._value * (1 << self._frac_bits))
if fp_int < 0:
fp_int = FixedPoint._calculate_two_complement(fp_int, self._total_bits)
fp_int = _calculate_two_complement(fp_int, self._total_bits)
return fp_int

def __float__(self) -> float:
Expand All @@ -70,31 +99,29 @@ def __ge__(self, other: Any) -> bool:
return float(self) >= float(other)

def __add__(self, other: "FixedPoint") -> "FixedPoint":
FixedPoint._assert_is_compatible(self, other)
_assert_is_compatible(self, other)
return self._identical_fixed_point_from_int(
FixedPoint._discard_leading_bits(
int(self) + int(other), num_bits=self._total_bits
)
_discard_leading_bits(int(self) + int(other), num_bits=self._total_bits)
)

def __sub__(self, other: "FixedPoint") -> "FixedPoint":
return self + (-other)

def __and__(self, other: "FixedPoint") -> "FixedPoint":
FixedPoint._assert_is_compatible(self, other)
_assert_is_compatible(self, other)
return self._identical_fixed_point_from_int(int(self) & int(other))

def __or__(self, other: "FixedPoint") -> "FixedPoint":
FixedPoint._assert_is_compatible(self, other)
_assert_is_compatible(self, other)
return self._identical_fixed_point_from_int(int(self) | int(other))

def __xor__(self, other: "FixedPoint") -> "FixedPoint":
FixedPoint._assert_is_compatible(self, other)
_assert_is_compatible(self, other)
return self._identical_fixed_point_from_int(int(self) ^ int(other))

def __invert__(self) -> "FixedPoint":
return self._identical_fixed_point_from_int(
FixedPoint._invert_int(int(self), num_bits=self._total_bits)
_invert_int(int(self), num_bits=self._total_bits)
)

def __neg__(self) -> "FixedPoint":
Expand All @@ -109,7 +136,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"FixedPoint(value={self._value}, total_bits={self._total_bits}, frac_bits={self._frac_bits})"

def _identical_fixed_point(self, value: float | int) -> "FixedPoint":
def _identical_fixed_point(self, value: float) -> "FixedPoint":
return FixedPoint(
value=value, total_bits=self._total_bits, frac_bits=self._frac_bits
)
Expand All @@ -119,41 +146,6 @@ def _identical_fixed_point_from_int(self, value: int) -> "FixedPoint":
value=value, total_bits=self._total_bits, frac_bits=self._frac_bits
)

@staticmethod
def _assert_range(value: float | int, total_bits: int, frac_bits: int) -> None:
max_value = 2 ** (total_bits - frac_bits - 1)
min_value = max_value * (-1)

if not min_value <= value < max_value:
raise ValueError(
(
f"Value {value} cannot represented as a fixed point value with {total_bits} total bits "
f"and {frac_bits} fraction bits (value range: [{min_value}, {max_value}))."
)
)

@staticmethod
def _assert_is_compatible(fp1: "FixedPoint", fp2: "FixedPoint") -> None:
if not (fp1.total_bits == fp2.total_bits and fp1.frac_bits == fp2.frac_bits):
raise ValueError(
(
f"FixedPoint objects not compatible (total_bits: {fp1.total_bits} != {fp2.total_bits}); "
f"frac_bits: {fp1.frac_bits} != {fp2.frac_bits})."
)
)

@staticmethod
def _invert_int(value: int, num_bits: int) -> int:
return value ^ int("1" * num_bits, 2)

@staticmethod
def _discard_leading_bits(value: int, num_bits: int) -> int:
return value & int("1" * num_bits, 2)

@staticmethod
def _calculate_two_complement(value: int, num_bits: int) -> int:
return FixedPoint._invert_int(abs(value), num_bits) + 1

@staticmethod
def from_unsigned_int(value: int, total_bits: int, frac_bits: int) -> "FixedPoint":
if value > 2**total_bits - 1:
Expand All @@ -162,21 +154,18 @@ def from_unsigned_int(value: int, total_bits: int, frac_bits: int) -> "FixedPoin
)
is_negative = value & (1 << total_bits - 1) > 0
if is_negative:
value = FixedPoint._calculate_two_complement(value, total_bits)
value = _calculate_two_complement(value, total_bits)
value *= -1
float_value = value / (1 << frac_bits)
return FixedPoint(float_value, total_bits=total_bits, frac_bits=frac_bits)

@staticmethod
def from_signed_int(value: int, total_bits: int, frac_bits: int) -> "FixedPoint":
float_value = value / (1 << frac_bits)
FixedPoint._assert_range(float_value, total_bits, frac_bits)
return FixedPoint(float_value, total_bits=total_bits, frac_bits=frac_bits)

@staticmethod
def get_factory(
total_bits: int, frac_bits: int
) -> Callable[[float | int], "FixedPoint"]:
def get_factory(total_bits: int, frac_bits: int) -> Callable[[float], "FixedPoint"]:
return partial(FixedPoint, total_bits=total_bits, frac_bits=frac_bits)

@property
Expand All @@ -200,6 +189,67 @@ def to_hex(self) -> str:
return f"{int(self):0{math.ceil(self._total_bits / 4)}x}"


class ClippedFixedPoint(FixedPoint):
def __init__(self, value: float, total_bits: int, frac_bits: int) -> None:
max_value = (2 ** (total_bits - 1) - 1) / (1 << frac_bits)
min_value = 2 ** (total_bits - frac_bits - 1) * (-1)
if min_value <= value <= max_value:
super().__init__(value=value, total_bits=total_bits, frac_bits=frac_bits)
else:
super().__init__(
value=max_value if value > max_value else min_value,
total_bits=total_bits,
frac_bits=frac_bits,
)

def __float__(self) -> float:
return ClippedFixedPoint.from_unsigned_int(
int(self), self._total_bits, self._frac_bits
)._value

def __repr__(self) -> str:
return f"ClippedFixedPoint(value={self._value}, total_bits={self._total_bits}, frac_bits={self._frac_bits})"

def _identical_fixed_point(self, value: float) -> "ClippedFixedPoint":
return ClippedFixedPoint(
value=value, total_bits=self._total_bits, frac_bits=self._frac_bits
)

def _identical_fixed_point_from_int(self, value: int) -> "ClippedFixedPoint":
return ClippedFixedPoint.from_unsigned_int(
value=value, total_bits=self._total_bits, frac_bits=self._frac_bits
)

@staticmethod
def from_unsigned_int(
value: int, total_bits: int, frac_bits: int
) -> "ClippedFixedPoint":
value = _discard_leading_bits(value, num_bits=total_bits)
is_negative = value & (1 << total_bits - 1) > 0
if is_negative:
value = _calculate_two_complement(value, total_bits)
value *= -1
float_value = value / (1 << frac_bits)
return ClippedFixedPoint(
float_value, total_bits=total_bits, frac_bits=frac_bits
)

@staticmethod
def from_signed_int(
value: int, total_bits: int, frac_bits: int
) -> "ClippedFixedPoint":
float_value = value / (1 << frac_bits)
return ClippedFixedPoint(
float_value, total_bits=total_bits, frac_bits=frac_bits
)

@staticmethod
def get_factory(
total_bits: int, frac_bits: int
) -> Callable[[float], "ClippedFixedPoint"]:
return partial(ClippedFixedPoint, total_bits=total_bits, frac_bits=frac_bits)


def infer_total_and_frac_bits(*values: Sequence[FixedPoint]) -> tuple[int, int]:
if sum(len(value_list) == 0 for value_list in values) > 0:
raise ValueError("Cannot infer total bits and frac bits from an empty list.")
Expand All @@ -213,7 +263,7 @@ def infer_total_and_frac_bits(*values: Sequence[FixedPoint]) -> tuple[int, int]:


def float_values_to_fixed_point(
values: list[float | int], total_bits: int, frac_bits: int
values: list[float], total_bits: int, frac_bits: int
) -> list[FixedPoint]:
return list(map(lambda x: FixedPoint(x, total_bits, frac_bits), values))

Expand Down

0 comments on commit 8e53506

Please sign in to comment.