-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
303 lines (233 loc) · 11.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import re
import math
import functools
from collections import defaultdict, deque
from enum import IntEnum, auto
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
def get_input(path: str = './num17.txt') -> str:
    """ Reads and returns the entire text of the puzzle input file at *path*.

    The file is decoded as UTF-8 explicitly, so the result does not depend on the platform's
    locale encoding.
    """
    with open(path, encoding='utf-8') as f:
        return f.read()
def get_input_lines(path: str = './num17.txt') -> List[str]:
    """ Reads the input file at *path* and splits it on '\\n'.

    A newline at the very end of the file produces a trailing empty string in the result.
    """
    text = get_input(path)
    return text.split(sep='\n')
def ints(text: str, sign_prefixes: bool = True) -> Tuple[int, ...]:
    """ Extracts every integer that appears in *text*, in order of appearance.

    :param text: arbitrary text to scan.
    :param sign_prefixes: when True, a '-' or '+' directly before the digits belongs to the number
        (so '3-4' gives (3, -4)); when False, only digit runs are matched (so '3-4' gives (3, 4)).
    :return: a tuple of the parsed integers (empty if none are found).
    """
    # Raw strings: '\d' and '\-' are invalid escape sequences in a normal string literal and
    # trigger DeprecationWarning (SyntaxWarning from Python 3.12). The pattern itself is unchanged.
    regex = r'([\-+]?\d+)' if sign_prefixes else r'(\d+)'
    return tuple(map(int, re.findall(regex, text)))
def floats(text: str) -> Tuple[float, ...]:
    """ Extracts every decimal number in *text* (optionally signed, e.g. '-2', '1.5', '.5') as floats, in order.

    Exponent notation is not recognised.
    """
    # Raw string: '\d', '\-' and '\.' are invalid escape sequences in a normal string literal.
    # The compiled pattern is identical to the previous non-raw spelling.
    return tuple(map(float, re.findall(r'([\-+]?\d*(?:\d|\d\.|\.\d)\d*)', text)))
def sum_iter(x: Iterable[int], y: Iterable[int], s: int = 1) -> Tuple[int, ...]:
    """ Computes x[i] + s * y[i] element-wise; the result is as long as the shorter input. """
    return tuple(map(lambda left, right: left + s * right, x, y))
def dot_iter(x: Iterable[int], y: Iterable[int]) -> int:
    """ Dot product of x and y: the sum of pairwise products, stopping at the shorter input (0 if either is empty). """
    total = 0
    for left, right in zip(x, y):
        total = total + left * right
    return total
def differences(x: Sequence[int]) -> Tuple[int, ...]:
    """ Returns (x[1] - x[0], x[2] - x[1], ...); empty for sequences shorter than two. """
    return tuple(x[k] - x[k - 1] for k in range(1, len(x)))
def prod(iterable: Iterable[int]) -> int:
    """ Calculates the product of an iterable. In Python 3.8 this can be replaced with a call to math.prod.

    As with math.prod, the product of an empty iterable is 1.
    """
    # The initializer 1 makes the empty case well-defined; without it, reduce raises TypeError.
    return functools.reduce(lambda acc, item: acc * item, iterable, 1)
def min_max(x: Iterable[int]) -> Tuple[int, int]:
    """ Returns both the min and max of an iterable, as (min, max).

    The input is materialised first, so one-shot iterators such as generators work.
    Raises ValueError if x is empty.
    """
    # min() would exhaust a generator and leave max() an empty iterator, so copy once.
    values = list(x)
    return min(values), max(values)
def sign(a: int) -> int:
    """ Returns -1 for negative a, 0 for zero, and 1 otherwise. """
    if a < 0:
        return -1
    if a == 0:
        return 0
    return 1
def lcm(a: int, b: int) -> int:
    """ Lowest common multiple of a and b.

    The result is always non-negative. lcm(a, 0) == lcm(0, b) == 0, matching math.lcm.
    """
    if a == 0 or b == 0:
        # gcd(0, 0) == 0, so the general formula would divide by zero here.
        return 0
    # Divide before multiplying to keep the intermediate small; a // gcd is exact.
    return abs(a // math.gcd(a, b) * b)
def gcd_iter(sequence: Iterable[int]) -> int:
    """ Greatest common divisor of a sequence.

    The result is always non-negative. The gcd of an empty sequence is 0 (the identity for gcd),
    matching math.gcd() called with no arguments.
    """
    # Starting from 0 handles the empty case, and also makes a single negative element come
    # back as its absolute value, since math.gcd(0, x) == abs(x).
    return functools.reduce(math.gcd, sequence, 0)
def lcm_iter(sequence: Iterable[int]) -> int:
    """ Finds the lowest common multiple of a sequence, folding the module-level lcm() over it.

    The lcm of an empty sequence is 1 (the identity for lcm), matching math.lcm() with no arguments.
    """
    return functools.reduce(lcm, sequence, 1)
def extended_gcd(a: int, b: int) -> Tuple[int, int, int]:
    """ Extended Euclidean Algorithm.

    For any a, b returns (g, x, y) with g = gcd(a, b) and a*x + b*y == g.
    Implemented iteratively: the quotients are recorded on the way down the Euclidean chain and
    then replayed in reverse to build the Bezout coefficients.
    """
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a
    # Base case of the chain: gcd(0, b) == b == 0*0 + b*1.
    coef_a, coef_b = 0, 1
    for q in reversed(quotients):
        coef_a, coef_b = coef_b - q * coef_a, coef_a
    return b, coef_a, coef_b
def crt(a1: int, m1: int, a2: int, m2: int) -> Tuple[int, int]:
    """ Generalized Chinese Remainder Theorem (CRT).

    Solves x = a1 mod m1 and x = a2 mod m2 simultaneously. A solution exists when m1 and m2 are
    coprime or, more generally, when a1 == a2 mod gcd(m1, m2).

    Returns (x, modulus), where modulus = m1 * m2 // gcd(m1, m2) (i.e. m1 * m2 in the coprime case)
    and 0 <= x < modulus.
    Raises ValueError if the system has no solution.
    """
    g, x, y = extended_gcd(m1, m2)
    if a1 % g != a2 % g:
        raise ValueError('The system x = %d mod %d, x = %d mod %d has no solution' % (a1, m1, a2, m2))
    # Both terms are divisible by g because g divides m1 and m2, so the floor division is exact.
    residue = (a2 * x * m1 + a1 * y * m2) // g
    modulus = m1 * m2 // g
    return residue % modulus, modulus
def mod_inv(a: int, m: int) -> int:
    """ Finds x such that a*x == 1 mod m, using the extended Euclidean algorithm.

    In Python 3.8+ this could be pow(a, -1, m); it is written out explicitly to stay PyPy compliant.
    Raises ValueError when a and m are not coprime.
    """
    g, coef, _ = extended_gcd(a, m)
    if g != 1:
        raise ValueError('Modular inverse for a=%d, m=%d does not exist' % (a, m))
    # Python's % already returns a value with the sign of m, so one reduction normalises it.
    return coef % m
def ray_int(start: Iterable[int], end: Iterable[int]) -> list:
    """ Returns the integer lattice points strictly between start and end, as a list of tuples.

    Points are ordered from start towards end; neither endpoint is included. Works in any number of
    dimensions, and start/end may be any iterables (they are materialised first, so generators are fine).
    """
    start = tuple(start)
    end = tuple(end)
    deltas = tuple(e - s for s, e in zip(start, end))
    # gcd over all components, seeded with 0 so it is always non-negative. A lone negative
    # component (a 1-D ray pointing backwards) must not produce a negative step count.
    delta_gcd = functools.reduce(math.gcd, deltas, 0)
    # For delta_gcd in {0, 1} the range is empty, so no division by zero can occur.
    return [tuple(s + d * g // delta_gcd for s, d in zip(start, deltas)) for g in range(1, delta_gcd)]
class Grid:
    """ A mutable two-dimensional grid of cells addressed by (x, y) tuples.

    Rows are stored in ``self.grid``, so location (x, y) is ``self.grid[y][x]``, with x in [0, width)
    and y in [0, height). The width is taken from the first row.

    If ``default_value`` is not None, the grid behaves as if padded with that value: reads outside the
    bounds return it, and writes outside the bounds are silently ignored. Otherwise out-of-bounds reads
    and writes raise ValueError. Non-(int, int) keys raise TypeError.
    """

    @staticmethod
    def from_text(text: str, default_value: Optional[str] = None) -> 'Grid':
        """ Builds a grid from text, one row per line; the text and each line are whitespace-stripped. """
        return Grid([list(line.strip()) for line in text.strip().split('\n')], default_value)

    @staticmethod
    def from_lines(lines: List[str], default_value: Optional[str] = None) -> 'Grid':
        """ Builds a grid from a list of row strings, used verbatim (no stripping). """
        return Grid([list(line) for line in lines], default_value)

    def __init__(self, grid: List[List[str]], default_value: Optional[str] = None):
        self.grid = grid
        self.height = len(grid)
        # An empty grid has width 0 instead of failing on grid[0].
        self.width = len(grid[0]) if grid else 0
        self.default_value = default_value

    @staticmethod
    def _is_location(item: Any) -> bool:
        """ True if item is an (int, int) tuple, i.e. a valid key shape for the grid. """
        return isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], int) and isinstance(item[1], int)

    def _in_bounds(self, x: int, y: int) -> bool:
        """ True if (x, y) addresses a stored cell. """
        return 0 <= x < self.width and 0 <= y < self.height

    def copy(self) -> 'Grid':
        """ Returns a copy whose rows can be mutated independently of this grid. """
        return Grid([row.copy() for row in self.grid], self.default_value)

    def count(self, value: str) -> int:
        """ Number of cells equal to value. """
        return sum(row.count(value) for row in self.grid)

    def locations(self) -> Iterator[Tuple[int, int]]:
        """ Yields every in-bounds (x, y), column by column (x in the outer loop, y in the inner loop). """
        for x in range(self.width):
            for y in range(self.height):
                yield x, y

    def map_create(self, f: Callable[[int, int], str]) -> 'Grid':
        """ Returns a new grid of the same size whose cell (x, y) is f(x, y). """
        return Grid([[f(x, y) for x in range(self.width)] for y in range(self.height)], self.default_value)

    def __getitem__(self, item):
        if not self._is_location(item):
            raise TypeError('Provided index is not an (x, y) tuple: %s' % str(item))
        x, y = item
        if self._in_bounds(x, y):
            return self.grid[y][x]
        if self.default_value is not None:
            return self.default_value
        raise ValueError('Provided location is out of bounds: %s not in [0, %d) x [0, %d)' % (str(item), self.width, self.height))

    def __setitem__(self, key, value):
        if not self._is_location(key):
            raise TypeError('Provided index is not an (x, y) tuple: %s' % str(key))
        x, y = key
        if self._in_bounds(x, y):
            self.grid[y][x] = value
        elif self.default_value is None:
            raise ValueError('Provided index is out of bounds: %s not in [0, %d) x [0, %d)' % (str(key), self.width, self.height))
        # Otherwise: an out-of-bounds write on a padded grid is deliberately dropped.

    def __contains__(self, item):
        """ (x, y) tuples test whether the location is in bounds; strings test whether any cell holds that value. """
        if self._is_location(item):
            return self._in_bounds(*item)
        if isinstance(item, str):
            return any(item in row for row in self.grid)
        raise TypeError('Provided item is not a key (x, y) pair, or value (str): %s' % str(item))

    def __eq__(self, other):
        # Grids compare equal by cell contents only; default_value is ignored. Comparing with a
        # non-Grid defers to Python instead of raising AttributeError on other.grid.
        if not isinstance(other, Grid):
            return NotImplemented
        return self.grid == other.grid

    def __str__(self):
        return '\n'.join(''.join(row) for row in self.grid)
def grid_bfs(grid: Grid, passable: Set[str], start: Tuple[int, int]) -> Dict[Tuple[int, int], int]:
    """ A template for a grid-based breadth-first search over the four orthogonal neighbours.

    Returns a map from every reachable location to its step distance from start (start maps to 0).
    A neighbour is entered when it is `in grid` and its cell value is in passable; start itself is
    always included regardless of its cell value.
    """
    distances = {start: 0}
    frontier = deque([start])
    while frontier:
        here = frontier.popleft()
        hx, hy = here
        next_step = distances[here] + 1
        for neighbour in ((hx, hy + 1), (hx, hy - 1), (hx + 1, hy), (hx - 1, hy)):
            # distances doubles as the visited set.
            if neighbour in distances:
                continue
            if neighbour in grid and grid[neighbour] in passable:
                distances[neighbour] = next_step
                frontier.append(neighbour)
    return distances
class Cycle:
    """
    Computes a cycle, identified by a state transition map f: X -> X.
    Allows for optimized operations using that cycle, such as computing the state at a far future location
    or slice, or computing the number of states between two points.

    The sequence start, f(start), f(f(start)), ... is split into a non-repeating ``prefix`` followed by a
    ``cycle`` that repeats forever. States must be hashable, and the sequence must eventually revisit a
    state; otherwise construction never terminates.
    """

    def __init__(self, start: Any, generator: Callable[[Any], Any]):
        self.generator = generator
        # Position at which each state first appeared; states in the walked sequence are unique.
        first_seen: Dict[Any, int] = {}
        sequence: List[Any] = []
        state = start
        while state not in first_seen:
            first_seen[state] = len(sequence)
            sequence.append(state)
            state = generator(state)
        # `state` is the first repeated state, so the cycle begins where it was first seen.
        index = first_seen[state]
        self.prefix = sequence[:index]
        self.cycle = sequence[index:]
        self.period = len(self.cycle)
        self.prefix_len = len(self.prefix)

    def values(self, min_inclusive: int, max_exclusive: int) -> List[Tuple[Any, int]]:
        """ Get all the states and the number of occurrences of each in the slice [min_inclusive, max_exclusive).

        Runs in O(prefix_len + period) rather than O(max_exclusive - min_inclusive).
        """
        values = defaultdict(int)
        i = min_inclusive
        while i < self.prefix_len and i < max_exclusive:  # iterate prefix values
            values[self[i]] += 1
            i += 1
        while (i - self.prefix_len) % self.period != 0 and i < max_exclusive:  # iterate until we reach the start of a cycle
            values[self[i]] += 1
            i += 1
        # i is now aligned with cycle[0]. Each cycle element at offset k occurs at indices
        # i + k, i + k + period, ... below max_exclusive.
        for item in self.cycle:
            counts = 1 + (max_exclusive - 1 - i) // self.period
            if counts > 0:
                values[item] += counts
            i += 1
        return list(values.items())

    def __getitem__(self, item):
        """ Access an individual item of the cycle as if it was a list from [0, infinity). Supports bounded slicing.

        Raises TypeError for negative indices, unbounded slices, and index types other than int or slice.
        """
        if isinstance(item, int):
            if item < 0:
                raise TypeError('Cannot index a cycle with a negative value: %d' % item)
            if item < len(self.prefix):
                return self.prefix[item]
            return self.cycle[(item - len(self.prefix)) % len(self.cycle)]
        if isinstance(item, slice):
            if item.stop is None:
                raise TypeError('Slice of Cycle must be bounded')
            first = 0 if item.start is None else item.start
            step = 1 if item.step is None else item.step
            return [self[i] for i in range(first, item.stop, step)]
        # Previously fell through and silently returned None.
        raise TypeError('Cycle indices must be non-negative ints or bounded slices, not %s' % type(item).__name__)

    def __str__(self):
        return 'Cycle{prefix=%s, cycle=%s}' % (str(self.prefix), str(self.cycle))
class Opcode(IntEnum):
    """ Instruction mnemonics of the Day 8 assembly. The values are those auto() would assign: 1, 2, 3. """
    nop = 1  # no operation
    acc = 2  # add argument to the accumulator
    jmp = 3  # relative jump by argument
class Asm:
    """ An abstraction for the yet-unnamed assembly code used in Day 8.

    ``code`` is a list of instructions, each a tuple ``(opcode, *int_args)``. Execution starts at
    instruction 0 with the accumulator at 0, and the machine stops once the instruction pointer
    leaves [0, len(code)).
    """

    @staticmethod
    def parse(lines: List[str]) -> List[Tuple[Union[Opcode, int], ...]]:
        """ Parses lines such as 'acc +3' into instruction tuples such as (Opcode.acc, 3).

        Blank lines (e.g. the trailing '' left by splitting a newline-terminated file) are skipped,
        and tokens may be separated by any run of whitespace.
        Raises KeyError for an unknown mnemonic and ValueError for a non-integer argument.
        """
        code = []
        for line in lines:
            tokens = line.split()
            if not tokens:
                continue
            opcode, *args = tokens
            code.append((Opcode[opcode], *map(int, args)))
        return code

    def __init__(self, code: List[Tuple[Union['Opcode', int], ...]]):
        self.code: List[Tuple[Union['Opcode', int], ...]] = code
        self.pointer: int = 0  # index of the next instruction to execute
        self.accumulator: int = 0
        self.running: bool = False

    def run(self) -> 'Asm':
        """ Executes instructions until the pointer leaves the program, then returns self.

        This never returns for a program that loops forever; callers needing loop detection
        should drive tick() themselves.
        """
        self.running = True
        while self.running:
            self.tick()
        return self

    def tick(self):
        """ Executes one instruction, or sets running to False if the pointer is outside the program.

        Raises ValueError for an unrecognised opcode.
        """
        if not self.valid():
            self.running = False
            return
        opcode, *args = self.code[self.pointer]
        if opcode == Opcode.nop:  # nop -> do nothing (any argument is ignored)
            self.pointer += 1
        elif opcode == Opcode.acc:  # acc [value] -> increment the accumulator by [value]
            self.accumulator += args[0]
            self.pointer += 1
        elif opcode == Opcode.jmp:  # jmp [offset] -> unconditional branch by [offset]
            self.pointer += args[0]
        else:
            # An unknown opcode never advances the pointer, so run() would otherwise spin forever.
            raise ValueError('Unknown opcode %r at instruction %d' % (opcode, self.pointer))

    def valid(self):
        """ True if the instruction pointer addresses an instruction in the program. """
        return 0 <= self.pointer < len(self.code)

    def __str__(self):
        return 'Asm{p=%d, code[p]=%s, acc=%d}' % (self.pointer, str(self.code[self.pointer]) if self.valid() else '???', self.accumulator)