xpu.py

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""General utilities to unify CPU/IPU programming."""

from dataclasses import dataclass
from types import TracebackType
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Type, Union

import numpy as np
import tensorflow as tf
from tensorflow import keras

try:
    from tensorflow.python import ipu

    IPU = True
except ImportError:  # pragma: no cover
    IPU = False

Function = Callable[..., Any]
Operation = Callable[..., Dict[str, tf.Tensor]]
Batch = Dict[str, np.ndarray]
FunctionCache = Callable[[Any], Callable[[Function], Function]]


def _make_cache(**function_args: Any) -> FunctionCache:
    """Make a decorator that calls tf.function, with a user-keyed cache.

    E.g.

        cache = _make_cache(experimental_compile=True)
        body = ...

        @cache(key=("model", body))
        def model(x: tf.Tensor) -> tf.Tensor:
            return 2 * body(x)
    """
    _cache: Dict[Any, Function] = {}

    def wrap(key: Any) -> Callable[[Function], Function]:
        def wrapper(fn: Operation) -> Operation:
            if key not in _cache:
                _cache[key] = tf.function(**function_args)(fn)
            return _cache[key]

        return wrapper

    return wrap
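

# A minimal sketch of the keyed cache in use (comment-only, not part of the
# module's API): decorating two functions with the same key returns the same
# tf.function object, so tracing and compilation happen once per key.
#
#     cache = _make_cache(experimental_compile=True)
#
#     @cache(key="double")
#     def double(x: tf.Tensor) -> tf.Tensor:
#         return 2 * x
#
#     @cache(key="double")  # same key, so the cached function is reused
#     def double_again(x: tf.Tensor) -> tf.Tensor:
#         return 2 * x
#
#     assert double is double_again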


@dataclass
class CpuSettings:
    """CPU-specific settings."""

    compile: bool = False
    type: str = "cpu"


@dataclass
class IpuSettings:
    """IPU-specific settings."""

    iterations_per_loop: int
    available_memory_proportion: Optional[float] = None
    stochastic_rounding: bool = False
    type: str = "ipu"


Settings = Union[CpuSettings, IpuSettings]


class Context:
    """Manages target setup and a cache for compiled functions."""

    _CURRENT: Optional["Context"] = None

    def __init__(self, strategy: tf.distribute.Strategy, compile: bool):
        self.strategy = strategy
        self._scope = self.strategy.scope()
        self._cache = (
            _make_cache(experimental_compile=True)
            if compile
            else (lambda key: lambda fn: fn)
        )

    def __enter__(self) -> "Context":
        assert Context._CURRENT is None, "xpu.context scopes cannot be nested"
        Context._CURRENT = self
        self._scope.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self._scope.__exit__(exc_type, exc_val, exc_tb)
        assert Context._CURRENT is self, "exiting a scope with the wrong context"
        Context._CURRENT = None

    def loop(self, operation: Operation, inputs: Iterable[Batch]) -> Iterable[Batch]:
        """Stream inputs into an operation and return all outputs.

        operation -- callable as `result = operation(**input)`,
                     where `result` is a `dict`
        """
        return loop_cpu(operation, inputs, strategy=self.strategy, cache=self._cache)

    @staticmethod
    def outline(layer: keras.layers.Layer) -> None:
        """Mark a layer for outlining on IPU; do nothing on CPU."""


def context(settings: Settings) -> Context:
    """Create an execution context with the given settings.

    Should generally be used in an immediate `with` scope, e.g.

        with xpu.context(xpu.CpuSettings(compile=False)) as context:
            ...
            # also accessible as xpu.current_context()
    """
    if isinstance(settings, CpuSettings):
        return Context(tf.distribute.OneDeviceStrategy(""), compile=settings.compile)
    if isinstance(settings, IpuSettings):
        if not IPU:  # pragma: no cover
            raise ValueError(
                "Cannot create IPU context - tensorflow.python.ipu could not be imported"
            )
        return _create_ipu_context(settings)
    assert False, f"Unexpected Context settings type {settings}"


def current_context() -> Context:
    """Get the currently in-scope Context."""
    # pylint:disable=protected-access
    assert Context._CURRENT is not None, "there is no context in scope"
    return Context._CURRENT
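

# How the pieces fit together, as a comment-only sketch (the `scale` operation
# and its inputs are hypothetical, not part of this module):
#
#     def scale(x: tf.Tensor) -> Dict[str, tf.Tensor]:
#         return dict(y=2 * x)
#
#     with context(CpuSettings(compile=True)):
#         batches = [dict(x=np.arange(4, dtype=np.float32))]
#         for output in current_context().loop(scale, batches):
#             print(output["y"])  # [0. 2. 4. 6.]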


def loop_cpu(
    operation: Operation,
    inputs: Iterable[Batch],
    strategy: tf.distribute.Strategy,
    cache: FunctionCache,
) -> Iterable[Batch]:
    """Stream inputs into an operation and return all outputs.

    operation -- callable as `result = operation(**input)`,
                 where `result` is a `dict`
    """
    fn = cache(key=operation)(operation)  # type:ignore[call-arg]
    for input_ in inputs:
        yield {k: np.array(v) for k, v in strategy.run(fn, kwargs=input_).items()}


if IPU:

    class _IpuContext(Context):
        def __init__(self, settings: IpuSettings):
            super().__init__(ipu.ipu_strategy.IPUStrategy(), compile=True)
            self.settings = settings

        def loop(
            self, operation: Operation, inputs: Iterable[Batch]
        ) -> Iterable[Batch]:
            return loop_ipu(
                operation,
                inputs,
                strategy=self.strategy,
                cache=self._cache,
                iterations_per_loop=self.settings.iterations_per_loop,
            )

        @staticmethod
        def outline(layer: keras.layers.Layer) -> None:
            inner_call = layer.call

            def outlined_call(*args: Any, **kwargs: Any) -> Any:
                @ipu.outlined_function  # type:ignore[misc]
                def call() -> Any:
                    return inner_call(*args, **kwargs)

                return call()

            layer.call = outlined_call

    def _create_ipu_context(settings: IpuSettings) -> Context:
        config = ipu.config.IPUConfig()
        config.auto_select_ipus = 1
        config.floating_point_behaviour.esr = (
            ipu.config.StochasticRoundingBehaviour.from_bool(
                settings.stochastic_rounding
            )
        )
        config.device_connection.type = ipu.config.DeviceConnectionType.ON_DEMAND
        if settings.available_memory_proportion is not None:
            config.matmuls.poplar_options["availableMemoryProportion"] = str(
                settings.available_memory_proportion
            )
        ipu.utils.configure_ipu_system(config)
        return _IpuContext(settings)

    def _padded_dataset(inputs: Iterable[Batch]) -> tf.data.Dataset:
        iterator = iter(inputs)
        head = next(iterator)

        def generator() -> Iterable[Dict[str, np.ndarray]]:  # pragma: no cover
            yield dict(**head, _pad=np.array(False))
            for item in iterator:
                yield dict(**item, _pad=np.array(False))
            while True:  # padding
                yield dict(**head, _pad=np.array(True))

        signature = {
            k: tf.TensorSpec(shape=v.shape, dtype=v.dtype) for k, v in head.items()
        }
        signature["_pad"] = tf.TensorSpec(shape=(), dtype=bool)
        return tf.data.Dataset.from_generator(generator, output_signature=signature)
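
    # Sketch of the padded stream above (illustrative, comment-only): given
    # inputs A, B, the dataset yields {**A, _pad: False}, {**B, _pad: False},
    # then {**A, _pad: True} forever. The infinite tail lets `loop_ipu` always
    # run a full `iterations_per_loop` device loop; the first `_pad=True` item
    # dequeued on the host marks end-of-stream, so padding never reaches the
    # caller.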

    def loop_ipu(
        operation: Operation,
        inputs: Iterable[Batch],
        strategy: tf.distribute.Strategy,
        cache: FunctionCache,
        iterations_per_loop: int,
    ) -> Iterable[Dict[str, np.ndarray]]:
        """Stream inputs into an operation and return all outputs.

        operation -- callable as `result = operation(**input)`,
                     where `result` is a `dict`
        """

        @cache(  # type:ignore[call-arg]
            key=("loop_ipu", operation, iterations_per_loop)
        )
        def _loop(
            iterator: Iterator[Dict[str, tf.Tensor]],
            outfeed: ipu.ipu_outfeed_queue.IPUOutfeedQueue,
        ) -> None:  # pragma: no cover
            for _ in tf.range(iterations_per_loop):
                batch = next(iterator)
                pad = batch.pop("_pad")
                results = operation(**batch)
                results["_pad"] = pad
                outfeed.enqueue(results)

        iterator = iter(_padded_dataset(inputs))
        outfeed = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        while True:
            strategy.run(_loop, (iterator, outfeed))
            for item in outfeed:
                if item.pop("_pad"):
                    # Prevent: "Error occurred when finalizing GeneratorDataset iterator"
                    del iterator
                    del outfeed
                    return
                yield {k: np.array(v) for k, v in item.items()}
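

# A CPU-only smoke test, as a hedged usage sketch (the `double` operation and
# these batches are illustrative, not part of the module's API):
if __name__ == "__main__":

    def double(x: tf.Tensor) -> Dict[str, tf.Tensor]:
        # Any callable returning a dict of tensors works as an Operation
        return dict(y=2 * x)

    with context(CpuSettings(compile=False)) as ctx:
        batches = [dict(x=np.full((2,), i, dtype=np.float32)) for i in range(3)]
        for output in ctx.loop(double, batches):
            print(output["y"])  # [0. 0.], then [2. 2.], then [4. 4.]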