diff --git a/peak/bitfield.py b/peak/bitfield.py index 7f1ed2a3..c866ebaa 100644 --- a/peak/bitfield.py +++ b/peak/bitfield.py @@ -1,6 +1,22 @@ +import typing as tp + from hwtypes.adt import Enum, Sum, Product from hwtypes import AbstractBitVector +def tag(tags: tp.Mapping[type, int]): + def wrapper(sum: Sum): + if not issubclass(sum, Sum): + raise TypeError('tag can only be applied Sum') + if tags.keys() != sum.fields: + raise ValueError('tag must specificy an Option for each Sum option') + if not all(isinstance(t, int) for t in tags.values()): + raise TypeError('tags must be int') + + setattr(sum, 'tags', tags) + return sum + return wrapper + + def bitfield(i): def wrap(klass): klass.bitfield = i diff --git a/tests/test_tag.py b/tests/test_tag.py new file mode 100644 index 00000000..dc4d1c22 --- /dev/null +++ b/tests/test_tag.py @@ -0,0 +1,31 @@ +import pytest +from hwtypes.adt import Sum + +from peak.bitfield import tag + +def test_tag(): + @tag({int : 0, str : 1}) + class S(Sum[int ,str]): pass + + assert S.tags[int] == 0 + assert S.tags[str] == 1 + + with pytest.raises(TypeError): + @tag() + class S: pass + + with pytest.raises(ValueError): + @tag({int : 0, str : 1, object : 2}) + class S(Sum[int ,str]): pass + + with pytest.raises(ValueError): + @tag({int : 0}) + class S(Sum[int ,str]): pass + + with pytest.raises(TypeError): + @tag({int : 'a', str : 1}) + class S(Sum[int ,str]): pass + + + +