-
Notifications
You must be signed in to change notification settings - Fork 0
/
cellular_automaton.py
98 lines (77 loc) · 3.11 KB
/
cellular_automaton.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import jax
import jax.numpy as np
from typing import Callable, List
class CellularAutomaton2D:
    """A 2D cellular automaton on a cyclic (wrap-around) grid, stepped with JAX.

    Args:
        rule: Function mapping one cell's loaded neighborhood to that cell's
            next value. It is applied with ``jax.vmap`` inside ``jax.jit`` to
            an array of shape ``(n_neighbors, -1)`` whose row ``k`` holds the
            ``k``-th ``True`` cell of ``neighborhood`` in row-major order. It
            must be JAX-traceable and hashable, because it is passed to
            ``jax.jit`` as a static argument.
        neighborhood: Rectangular boolean mask with an odd number of rows,
            centred (at row ``height // 2``, column ``width // 2``) on the
            cell being updated; ``True`` entries are included.
        grid: Array-like with at least 3 dimensions ``(height, width, ...)``;
            the trailing dimensions hold each cell's value vector. Nested
            Python lists are accepted and converted with ``np.array``.

    Raises:
        ValueError: if ``rule``, ``grid`` or ``neighborhood`` is invalid
            (a subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it).
    """

    def __init__(
        self, rule: Callable, neighborhood: List[List[bool]], grid,
    ):
        if not self._is_valid_rule(rule):
            raise ValueError("Invalid rule")
        # Convert before validating: the grid check inspects ``.ndim`` and
        # ``.shape``, which plain nested lists do not have.
        grid = np.array(grid)
        if not self._is_valid_grid(grid):
            raise ValueError("Invalid grid")
        if not self._is_valid_neighborhood(neighborhood):
            raise ValueError("Invalid neighborhood")
        self.rule = rule
        self.neighborhood = neighborhood
        self.grid = grid
        # One np.roll amount per neighbor cell included by the neighborhood;
        # each produces a shifted copy of the grid with the same shape.
        self._shifts = self._generate_shifts()
        # ``shifts`` (arg 1) and ``rule`` (arg 2) determine the traced program,
        # so they are static; both must therefore be hashable.
        self._generate_next_grid_jit = jax.jit(
            self._generate_next_grid, static_argnums=(1, 2)
        )

    @staticmethod
    def _generate_next_grid(grid, shifts, rule):
        """Return the grid after one application of ``rule`` to every cell.

        ``shifts`` is a sequence of ``(shift_x, shift_y)`` roll amounts, one
        per neighbor. The result has shape ``(height, width, -1)``.
        """
        # TODO: handle non-cyclic boundaries and pad cells; np.roll wraps.
        height, width = grid.shape[0], grid.shape[1]
        # Roll one copy of the grid per neighbor and stack the copies on
        # axis 2: shape (height, width, n_neighbors, ...), so every cell's
        # neighborhood is contiguous.
        shifted = np.stack(
            [
                np.roll(grid, (shift_x, shift_y), axis=(1, 0))
                for shift_x, shift_y in shifts
            ],
            axis=2,
        )
        # One row per cell, each holding that cell's loaded neighborhood.
        loaded = np.reshape(shifted, (height * width, len(shifts), -1))
        # Vectorize the rule over all cells at once.
        vrule = jax.vmap(rule, in_axes=0, out_axes=0)
        return np.reshape(vrule(loaded), (height, width, -1))

    def step(self, n=1):
        """Advance the automaton ``n`` generations, updating ``self.grid``."""
        for _ in range(n):
            self.grid = self._generate_next_grid_jit(
                self.grid, self._shifts, self.rule
            )

    def _generate_shifts(self):
        """Return a tuple of ``(shift_x, shift_y)`` np.roll amounts.

        Mask cell ``(row, col)`` is the neighbor at offset
        ``(dx, dy) = (col - center_x, row - center_y)``. Because
        ``np.roll`` moves data forward (``rolled[y, x] == grid[y - sy, x - sx]``),
        bringing that neighbor onto ``(x, y)`` needs a roll of ``(-dx, -dy)``.
        A tuple (not a list) is returned because it is a static ``jax.jit``
        argument and must be hashable.
        """
        center_y = len(self.neighborhood) // 2
        center_x = len(self.neighborhood[0]) // 2
        return tuple(
            (center_x - col, center_y - row)
            for row, mask_row in enumerate(self.neighborhood)
            for col, included in enumerate(mask_row)
            if included
        )

    def _is_valid_rule(self, rule):
        """Return True if ``rule`` can be called on a neighborhood."""
        return callable(rule)

    def _is_valid_neighborhood(self, neighborhood):
        """Return True for a rectangular, odd-height mask with a True cell."""
        # An odd height gives a well-defined centre row; this also rejects
        # an empty neighborhood (0 rows is even).
        if len(neighborhood) % 2 != 1:
            return False
        width = len(neighborhood[0])
        if width < 1:
            return False
        # _generate_shifts assumes every row is as wide as the first one.
        if not all(len(row) == width for row in neighborhood):
            return False
        # At least one neighbor is needed, or there is nothing to stack.
        return any(bool(cell) for row in neighborhood for cell in row)

    def _is_valid_grid(self, grid):
        """Return True if ``grid`` is an array shaped ``(height, width, ...)``
        with at least 3 dimensions and non-empty height and width."""
        if grid.ndim < 3:
            # TODO: give a clearer error message when cell values are not
            # vectors.
            return False
        return grid.shape[0] >= 1 and grid.shape[1] >= 1