Skip to content

Commit

Permalink
Merge pull request #676 from more-itertools/transpose-matmul-factor
Browse files Browse the repository at this point in the history
Sync with itertool recipes by adding transpose(), matmul(), and facto…
  • Loading branch information
bbayles authored Feb 21, 2023
2 parents e82e7dc + aae6d92 commit 74c920a
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ These tools yield groups of items from a source iterable.
.. autofunction:: batched
.. autofunction:: grouper
.. autofunction:: partition
.. autofunction:: transpose


Lookahead and lookback
Expand Down Expand Up @@ -293,3 +294,5 @@ Others
.. autofunction:: repeatfunc
.. autofunction:: polynomial_from_roots
.. autofunction:: sieve
.. autofunction:: factor
.. autofunction:: matmul
4 changes: 2 additions & 2 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def distinct_permutations(iterable, r=None):
[(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
"""

# Algorithm: https://w.wiki/Qai
def _full(A):
while True:
Expand Down Expand Up @@ -2917,6 +2918,7 @@ def make_decorator(wrapping_func, result_index=0):
'7'
"""

# See https://sites.google.com/site/bbayles/index/decorator_factory for
# notes on how this works.
def decorator(*wrapping_args, **wrapping_kwargs):
Expand Down Expand Up @@ -3467,7 +3469,6 @@ def _sample_unweighted(iterable, k):
next_index = k + floor(log(random()) / log(1 - W))

for index, element in enumerate(iterable, k):

if index == next_index:
reservoir[randrange(k)] = element
# The new W is the largest in a sample of k U(0, `old_W`) numbers
Expand Down Expand Up @@ -4287,7 +4288,6 @@ def minmax(iterable_or_value, *others, key=None, default=_marker):
lo_key = hi_key = key(lo)

for x, y in zip_longest(it, it, fillvalue=lo):

x_key, y_key = key(x), key(y)

if y_key < x_key:
Expand Down
47 changes: 47 additions & 0 deletions more_itertools/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
cycle,
groupby,
islice,
product,
repeat,
starmap,
tee,
Expand All @@ -38,10 +39,12 @@
'convolve',
'dotproduct',
'first_true',
'factor',
'flatten',
'grouper',
'iter_except',
'iter_index',
'matmul',
'ncycles',
'nth',
'nth_combination',
Expand All @@ -65,6 +68,7 @@
'tabulate',
'tail',
'take',
'transpose',
'triplewise',
'unique_everseen',
'unique_justseen',
Expand Down Expand Up @@ -881,3 +885,46 @@ def batched(iterable, n):
if not batch:
break
yield batch


def transpose(it):
"""Swap the rows and columns of the input.
>>> list(transpose([(1, 2, 3), (11, 22, 33)]))
[(1, 11), (2, 22), (3, 33)]
The caller should ensure that the dimensions of the input are compatible.
"""
# TODO: when 3.9 goes end-of-life, add stric=True to this.
return zip(*it)


def matmul(m1, m2):
"""Multiply two matrices.
>>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
[[49, 80], [41, 60]]
The caller should ensure that the dimensions of the input matrices are
compatible with each other.
"""
n = len(m2[0])
return batched(starmap(dotproduct, product(m1, transpose(m2))), n)


def factor(n):
"""Yield the prime factors of n.
>>> list(factor(360))
[2, 2, 2, 3, 3, 5]
"""
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
for prime in sieve(isqrt(n) + 1):
while True:
quotient, remainder = divmod(n, prime)
if remainder:
break
yield prime
n = quotient
if n == 1:
return
if n >= 2:
yield n
5 changes: 5 additions & 0 deletions more_itertools/recipes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,8 @@ def batched(
iterable: Iterable[_T],
n: int,
) -> Iterator[list[_T]]: ...
def transpose(
it: Iterable[Iterable[_T]],
) -> tuple[Iterator[_T], ...]: ...
def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[list[_T]]: ...
def factor(n: int) -> Iterator[int]: ...
70 changes: 70 additions & 0 deletions tests/test_recipes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from doctest import DocTestSuite
from functools import reduce
from itertools import combinations, count, permutations
from operator import mul
from math import factorial
from unittest import TestCase

Expand Down Expand Up @@ -971,3 +972,72 @@ def test_basic(self):
with self.subTest(n=n):
actual = list(mi.batched(iterable, n))
self.assertEqual(actual, expected)


class TransposeTests(TestCase):
def test_empty(self):
it = []
actual = list(mi.transpose(it))
expected = []
self.assertEqual(actual, expected)

def test_basic(self):
it = [(10, 11, 12), (20, 21, 22), (30, 31, 32)]
actual = list(mi.transpose(it))
expected = [(10, 20, 30), (11, 21, 31), (12, 22, 32)]
self.assertEqual(actual, expected)

def test_incompatible(self):
it = [(10, 11, 12, 13), (20, 21, 22), (30, 31, 32)]
actual = list(mi.transpose(it))
expected = [(10, 20, 30), (11, 21, 31), (12, 22, 32)]
self.assertEqual(actual, expected)


class MatMulTests(TestCase):
def test_n_by_n(self):
actual = list(mi.matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]))
expected = [[49, 80], [41, 60]]
self.assertEqual(actual, expected)

def test_m_by_n(self):
m1 = [[2, 5], [7, 9], [3, 4]]
m2 = [[7, 11, 5, 4, 9], [3, 5, 2, 6, 3]]
actual = list(mi.matmul(m1, m2))
expected = [
[29, 47, 20, 38, 33],
[76, 122, 53, 82, 90],
[33, 53, 23, 36, 39],
]
self.assertEqual(actual, expected)


class FactorTests(TestCase):
def test_basic(self):
for n, expected in (
(0, []),
(1, []),
(2, [2]),
(3, [3]),
(4, [2, 2]),
(6, [2, 3]),
(360, [2, 2, 2, 3, 3, 5]),
(128_884_753_939, [128_884_753_939]),
(999953 * 999983, [999953, 999983]),
(909_909_090_909, [3, 3, 7, 13, 13, 751, 113797]),
):
with self.subTest(n=n):
actual = list(mi.factor(n))
self.assertEqual(actual, expected)

def test_cross_check(self):
prod = lambda x: reduce(mul, x, 1)
self.assertTrue(all(prod(mi.factor(n)) == n for n in range(1, 2000)))
self.assertTrue(
all(set(mi.factor(n)) <= set(mi.sieve(n + 1)) for n in range(2000))
)
self.assertTrue(
all(
list(mi.factor(n)) == sorted(mi.factor(n)) for n in range(2000)
)
)

0 comments on commit 74c920a

Please sign in to comment.