diff --git a/README.md b/README.md index 084ae16..96fe042 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,20 @@ custom `machine_id`, `start_time` etc. - `machine_id` should be an integer value upto 16-bits, callable or `None` (will be used random machine id). +If you need to generate ids at rate more than 256ids/10msec, you can use the `RoundRobin` wrapper over multiple `SonyFlake` instances: + +``` python +from timeit import timeit +from sonyflake import RoundRobin, SonyFlake, random_machine_ids +sf = RoundRobin([SonyFlake(machine_id=_id) for _id in random_machine_ids(10)]) +t = timeit(sf.next_id, number=100000) +print(f"generated 100000 ids in {t:.2f} seconds") +``` + +> :warning: This increases the chance of collisions, so be careful when using random machine IDs. + +For convenience, both `SonyFlake` and `RoundRobin` implement iterator protocol (`next(sf)`). + ## License The MIT License (MIT). diff --git a/sonyflake/__init__.py b/sonyflake/__init__.py index e1a6c02..9d9a99c 100644 --- a/sonyflake/__init__.py +++ b/sonyflake/__init__.py @@ -1,4 +1,17 @@ from .about import NAME, VERSION, __version__ -from .sonyflake import SonyFlake +from .round_robin import RoundRobin +from .sonyflake import ( + SONYFLAKE_EPOCH, + SonyFlake, + lower_16bit_private_ip, + random_machine_id, + random_machine_ids, +) -__all__ = ["SonyFlake"] +__all__ = [ + "RoundRobin", + "SonyFlake", + "random_machine_id", + "random_machine_ids", + "lower_16bit_private_ip", +] diff --git a/sonyflake/round_robin.py b/sonyflake/round_robin.py new file mode 100644 index 0000000..f198fc1 --- /dev/null +++ b/sonyflake/round_robin.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from itertools import cycle +from typing import Iterable, Iterator + + +class RoundRobin(Iterator[int]): + """Round-robin iterator for cycling through multiple ID generators. + + Used for generating ids at rate more than 256ids/10msec. + + Example: + >>> from sonyflake import RoundRobin, SonyFlake, random_machine_ids + >>> sf = RoundRobin([SonyFlake(machine_id=_id) for _id in random_machine_ids(10)]) + >>> %timeit next(sf) + """ + + _id_generators: cycle[Iterator[int]] + __slots__ = ("_id_generators",) + + def __init__(self, id_generators: Iterable[Iterator[int]]) -> None: + self._id_generators = cycle(id_generators) + + def __next__(self) -> int: + return next(next(self._id_generators)) + + next_id = __next__ diff --git a/sonyflake/sonyflake.py b/sonyflake/sonyflake.py index c950009..c2f4ea1 100644 --- a/sonyflake/sonyflake.py +++ b/sonyflake/sonyflake.py @@ -1,11 +1,11 @@ import datetime import ipaddress from functools import partial -from random import randrange +from random import randrange, sample from socket import gethostbyname, gethostname from threading import Lock from time import sleep -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union from warnings import warn BIT_LEN_TIME = 39 @@ -19,6 +19,21 @@ utc_now = partial(datetime.datetime.now, tz=UTC) +def random_machine_ids(n: int) -> List[int]: + """ + Returns a list of `n` random machine IDs. + + `n` must be in range (0, 0xFFFF]. + + Returned list is sorted in ascending order, without duplicates. + """ + + if not (0 < n <= MAX_MACHINE_ID): + raise ValueError(f"n must be in range (0, {MAX_MACHINE_ID}]") + + return sorted(sample(range(0, MAX_MACHINE_ID + 1), n)) + + def lower_16bit_private_ip() -> int: """ Returns the lower 16 bits of the private IP address. @@ -28,13 +43,17 @@ def lower_16bit_private_ip() -> int: return (ip_bytes[2] << 8) + ip_bytes[3] -class SonyFlake: +class SonyFlake(Iterator[int]): """ The distributed unique ID generator. """ + _now: Callable[[], datetime.datetime] + mutex: Lock _start_time: int _machine_id: int + elapsed_time: int + sequence: int __slots__ = ( "_now", @@ -148,6 +167,8 @@ def next_id(self) -> int: sleep(self.sleep_time(overtime, self._now())) return self.to_id() + __next__ = next_id + def to_id(self) -> int: if self.elapsed_time >= (1 << BIT_LEN_TIME): raise TimeoutError("Over the time limit!") diff --git a/tests/test_round_robin.py b/tests/test_round_robin.py new file mode 100644 index 0000000..25812ec --- /dev/null +++ b/tests/test_round_robin.py @@ -0,0 +1,18 @@ +from sonyflake.round_robin import RoundRobin +from sonyflake.sonyflake import BIT_LEN_MACHINE_ID, SonyFlake + + +def test_round_robin() -> None: + rr = RoundRobin( + [ + SonyFlake(machine_id=0x0000), + SonyFlake(machine_id=0x7F7F), + SonyFlake(machine_id=0xFFFF), + ] + ) + + assert [next(rr) & ((1 << BIT_LEN_MACHINE_ID) - 1) for _ in range(6)] == [ + 0x0000, + 0x7F7F, + 0xFFFF, + ] * 2 diff --git a/tests/test_sonyflake.py b/tests/test_sonyflake.py index b7dcf68..5e1ef65 100644 --- a/tests/test_sonyflake.py +++ b/tests/test_sonyflake.py @@ -4,7 +4,7 @@ from time import sleep from unittest import TestCase -from pytest import raises +from pytest import mark, raises from sonyflake.sonyflake import ( BIT_LEN_SEQUENCE, @@ -12,6 +12,7 @@ SonyFlake, lower_16bit_private_ip, random_machine_id, + random_machine_ids, ) @@ -110,5 +111,18 @@ def test_random_machine_id() -> None: assert random_machine_id() +@mark.parametrize("n", [1, 1024, 65535]) +def test_random_machine_ids(n: int) -> None: + machine_ids = random_machine_ids(n) + assert len(set(machine_ids)) == n + assert sorted(machine_ids) == machine_ids + + +@mark.parametrize("n", [0, 65536]) +def test_random_machine_ids_edges(n: int) -> None: + with raises(ValueError, match=r"n must be in range \(0, 65535\]"): + random_machine_ids(n) + + def test_lower_16bit_private_ip() -> None: assert lower_16bit_private_ip()