# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
import dataclasses
from functools import partial
import inspect
import logging
import operator as op
import weakref
from typing import NamedTuple, Any, Union, cast
import warnings
import numpy as np
from jax._src import api
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
hoist_obj_attrs)
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves,
treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
PyTreeDef, none_leaf_registry as none_lr)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, subs_list, fun_name, fun_qual_name)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
traceback_util.register_exclusion(__file__)
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]
logger = logging.getLogger(__name__)
def _find_arg_mismatch(arg_list, fails, fun_name):
mismatched_args_msg = []
def mismatch(err):
for name, inp_da, aval in arg_list:
if err.m_type == pxla.MismatchType.ARG_SHARDING and err.da == inp_da:
mismatched_args_msg.append(
f"argument {name} of {fun_name} with shape {aval.str_short()} and "
f"{err._dev_ids_plat_str}")
break
first_err, second_err = fails
mismatch(first_err)
mismatch(second_err)
return mismatched_args_msg
def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
arg_names):
arg_list = []
if arg_names is None:
arg_names = [''] * len(args_flat)
for a, n in zip(args_flat, arg_names):
da = (a.sharding._device_assignment
if getattr(a, 'sharding', None) is not None else None)
arg_list.append((n, da, shaped_abstractify(a)))
mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name)
if len(mismatched_args_msg) == 2:
first, second = mismatched_args_msg # pytype: disable=bad-unpacking
extra_msg = f" Got {first} and {second}"
elif len(mismatched_args_msg) == 1:
first, second = fails
# Choose the remaining failure that is not already covered by ARG_SHARDING.
left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first
extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}"
else:
first, second = fails
extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}"
msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}")
return msg
class PjitInfo(NamedTuple):
"""Things that we know about a jit instance before it is called.
In other words, this structure contains arguments to jit()/pjit(),
preprocessed and validated.
"""
fun_sourceinfo: str | None
fun_signature: inspect.Signature | None
# Shardings, as specified by the user. These can either be UNSPECIFIED or they
# can be a tree (prefix) of shardings or None.
user_specified_in_shardings: bool
in_shardings_treedef: PyTreeDef
in_shardings_leaves: tuple[Any, ...]
out_shardings_treedef: PyTreeDef
out_shardings_leaves: tuple[Any, ...]
in_layouts_treedef: PyTreeDef
in_layouts_leaves: tuple[Any, ...]
out_layouts_treedef: PyTreeDef
out_layouts_leaves: tuple[Any, ...]
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
donate_argnames: tuple[str, ...]
device: xc.Device | None
backend: str | None
keep_unused: bool
inline: bool
abstracted_axes: Any | None
use_resource_env: bool # False for jit, True for pjit
compiler_options_kvs: tuple[tuple[str, Any], ...]
# Hash and compare PjitInfo by identity when used as a cache key.
def __hash__(self):
return id(self)
def __eq__(self, other):
return self is other
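# Illustrative sketch only (hypothetical call; not part of the implementation):
# a call such as jax.jit(f, static_argnums=(1,), donate_argnums=(0,)) is
# preprocessed by _parse_jit_arguments below into a PjitInfo whose
# static_argnums=(1,) and donate_argnums=(0,), with the sharding and layout
# trees flattened from their UNSPECIFIED/None defaults. Because PjitInfo hashes
# and compares by identity, the same PjitInfo object acts as a stable cache key
# for every call of that jitted function.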
def _python_pjit_helper(fun, jit_info, *args, **kwargs):
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
if p.attrs_tracked:
init_states = _get_states(p.attrs_tracked)
args_flat = [*init_states, *args_flat]
try:
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
# and set the context manager down the stack?
with p.abstract_mesh:
if (core.trace_state_clean() and
not config.debug_key_reuse.value and
not config.data_dependent_tracing_fallback.value):
args_flat = map(core.full_lower, args_flat)
core.check_eval_args(args_flat)
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
else:
out_flat = pjit_p.bind(*args_flat, **p.params)
compiled = None
profiler = None
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, p.arg_names)
raise ValueError(msg) from None
except xla.InvalidInputException as e:
arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
# Run canonicalization again to figure out which arg failed.
if p.params['jaxpr'].consts:
raise TypeError(e.args[0]) from e
else:
for arg, name, aval in zip(args_flat, arg_names, p.in_avals):
try:
xla.canonicalize_dtype(arg)
except xla.InvalidInputException as _:
# Reraise as TypeError with the new message.
raise TypeError(
f"Argument '{name}' of shape {aval.str_short()} of type"
f' {type(arg)} is not a valid JAX type.') from e
raise AssertionError("Unreachable") from e
if p.attrs_tracked:
num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked)
final_states, out_flat = split_list(out_flat, [num_states_out])
_set_states(p.attrs_tracked, final_states)
outs = tree_unflatten(p.out_tree, out_flat)
return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'],
p.attrs_tracked, compiled, profiler)
def _set_states(attrs_tracked, vals):
from jax.experimental.attrs import jax_setattr
valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]])
for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss):
val = tree_unflatten(treedef, leaves)
jax_setattr(obj, attr, val)
def _get_states(attrs_tracked):
from jax.experimental.attrs import jax_getattr
vals = []
for treedef, _, (obj, attr) in attrs_tracked:
tree = jax_getattr(obj, attr)
leaves, treedef_ = tree_flatten(tree)
assert treedef == treedef_
vals.extend(leaves)
return vals
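# Illustrative note (not part of the implementation): each attrs_tracked entry
# has the form (init_tree, end_tree, (obj, attr)). _get_states reads the current
# value of obj.attr via jax.experimental.attrs.jax_getattr and flattens it into
# leaves that _python_pjit_helper prepends to the traced arguments; _set_states
# unflattens the post-call leaves and writes them back with jax_setattr.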
def _need_to_rebuild_with_fdo(pgle_profiler):
return (pgle_profiler is not None and pgle_profiler.is_enabled()
and not pgle_profiler.is_fdo_consumed())
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
consts, abstracted_axes, pgle_profiler
) -> pxla.MeshExecutableFastpathData | None:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = (
executable is not None
and isinstance(executable, pxla.MeshExecutable)
and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
# No effects in computation
and not executable.unsafe_call.ordered_effects
and not executable.unsafe_call.has_unordered_effects
and not executable.unsafe_call.has_host_callbacks
and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
and abstracted_axes is None
# no attr state effects
and not attrs_tracked
# no ref state effects
and not any(isinstance(e, RefEffect) for e in effects)
# no prng reuse checking
and not (config.debug_key_reuse.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat, *consts)))
and not _need_to_rebuild_with_fdo(pgle_profiler)
)
if use_fastpath:
out_avals = [o.aval for o in out_reflattened]
out_committed = [o._committed for o in out_reflattened]
kept_var_bitvec = [i in executable._kept_var_idx
for i in range(len(args_flat))]
in_shardings = [
sharding_impls.physical_sharding(a, s)
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else s
for s, a in zip(executable._in_shardings, executable.in_avals)
]
fastpath_data = pxla.MeshExecutableFastpathData(
executable.xla_executable, out_tree, in_shardings,
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
executable._dispatch_in_layouts)
else:
fastpath_data = None
return fastpath_data
def _cpp_pjit_evict_fn(self):
self._clear_cache()
_create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error
_infer_params_cached.cache_clear()
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
# This cache is only used for jit calls that pass only `fun`. For example: jax.jit(f)
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is used for jit calls that pass extra arguments besides the
# fun. For example: jax.jit(f, donate_argnums=...) OR
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because its
# capacity might fill up very quickly with all of the jitted functions in JAX,
# which might evict train_step for example.
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
def _get_cpp_global_cache(contains_explicit_attributes: bool):
if contains_explicit_attributes:
return _cpp_pjit_cache_explicit_attributes
else:
return _cpp_pjit_cache_fun_only
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
@api_boundary
def cache_miss(*args, **kwargs):
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
"`jit`, but 'no_tracing' is set")
(outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,
pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jaxpr.consts, jit_info.abstracted_axes,
pgle_profiler)
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=jit_info.donate_argnums,
donate_argnames=jit_info.donate_argnames,
device=jit_info.device, backend=jit_info.backend,
in_shardings_treedef=jit_info.in_shardings_treedef,
in_shardings_leaves=jit_info.in_shardings_leaves,
out_shardings_treedef=jit_info.out_shardings_treedef,
out_shardings_leaves=jit_info.out_shardings_leaves,
in_layouts_treedef=jit_info.in_layouts_treedef,
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env,
compiler_options_kvs=jit_info.compiler_options_kvs)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry,
pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
return cpp_pjitted_f
def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
for e in entries_flat:
if isinstance(e, Layout):
layouts.append(e.device_local_layout)
shardings.append(e.sharding)
elif isinstance(e, (DeviceLocalLayout, AutoLayout)):
raise ValueError(
'`jax.jit` does not accept device-local layouts directly. Create '
'a `Layout` instance wrapping this device-local layout and pass '
f'that to `jit` instead. Got {e}')
else:
layouts.append(None)
shardings.append(e)
assert len(layouts) == len(shardings)
return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings)
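# Illustrative example with hypothetical values: given
#   entries = (Layout(dll, NamedSharding(mesh, P('x'))), NamedSharding(mesh, P('y')))
# _split_layout_and_sharding returns the layout tree (dll, None) and the
# sharding tree (NamedSharding(mesh, P('x')), NamedSharding(mesh, P('y'))),
# i.e. layouts and shardings are split leaf-by-leaf while preserving the
# pytree structure of the input.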
def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
static_argnums: int | Sequence[int] | None,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> PjitInfo:
"""Parses the arguments to jit/pjit.
Performs any preprocessing and validation of the arguments that we can do
ahead of time before the jit()-ed function is invoked.
"""
if abstracted_axes and not config.dynamic_shapes.value:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
check_callable(fun)
if backend is not None or device is not None:
warnings.warn(
'backend and device arguments on jit are deprecated. You can use'
' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to'
' the jitted function to get the same behavior.', DeprecationWarning)
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue):
raise ValueError('If backend or device is specified on jit, then '
'in_shardings should not be specified.')
if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue):
raise ValueError('If backend or device is specified on jit, then '
'out_shardings should not be specified.')
if isinstance(in_shardings, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_shardings = tuple(in_shardings)
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
out_layouts, out_shardings = _split_layout_and_sharding(out_shardings)
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
user_specified_in_shardings = (in_shardings is not None and
not isinstance(in_shardings, UnspecifiedValue))
in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)
compiler_options_kvs = (() if compiler_options is None else
tuple(compiler_options.items()))
return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
user_specified_in_shardings=user_specified_in_shardings,
in_shardings_treedef=in_shardings_treedef,
in_shardings_leaves=tuple(in_shardings_leaves),
out_shardings_treedef=out_shardings_treedef,
out_shardings_leaves=tuple(out_shardings_leaves),
in_layouts_treedef=in_layouts_treedef,
in_layouts_leaves=tuple(in_layouts_leaves),
out_layouts_treedef=out_layouts_treedef,
out_layouts_leaves=tuple(out_layouts_leaves),
static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
use_resource_env=use_resource_env,
compiler_options_kvs=compiler_options_kvs)
def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@api_boundary
def lower(*args, **kwargs):
return trace(*args, **kwargs).lower()
@api_boundary
def eval_shape(*args, **kwargs):
p, _ = _infer_params(fun, jit_info, args, kwargs)
out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
weak_type=x.weak_type)
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
return tree_unflatten(p.out_tree, out)
@api_boundary
def trace(*args, **kwargs) -> stages.Traced:
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
pgle_profiler=None)
return stages.Traced(
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
wrapped = _cpp_pjit(fun, jit_info)
wrapped.lower = lower
wrapped.eval_shape = eval_shape
wrapped.trace = trace
return wrapped
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
static_argnums: int | Sequence[int] | None,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> Any:
"""jit() and pjit() are thin wrappers around this function."""
jit_info = _parse_jit_arguments(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, compiler_options, use_resource_env)
return _make_jit_wrapper(fun, jit_info)
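# Illustrative usage sketch (hypothetical `f` and `x`; not part of the
# implementation): per the docstring above, the public jit()/pjit() entry
# points funnel into make_jit, and the returned wrapper supports both eager
# dispatch and the AOT-style methods attached in _make_jit_wrapper, e.g.
#   jitted = jax.jit(f)            # builds a wrapper via make_jit internally
#   y = jitted(x)                  # trace/compile (or hit the C++ cache), then run
#   traced = jitted.trace(x)       # returns a stages.Traced object
#   lowered = jitted.lower(x)      # same as jitted.trace(x).lower()
#   shapes = jitted.eval_shape(x)  # abstract evaluation only; no compilation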
class PjitParams(NamedTuple):
consts: list[Any] # Only jaxpr constants, we can't keep other arguments alive
params: dict[str, Any]
in_avals: tuple[core.AbstractValue, ...]
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str, ...] | None
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
def _infer_params_impl(
fun: Callable,
ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
args: tuple[Any, ...],
kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None,
) -> tuple[PjitParams, list[Any]]:
have_kwargs = bool(kwargs)
if have_kwargs and ji.user_specified_in_shardings:
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")
if pjit_mesh is not None:
jit_name = 'pjit'
if (ji.backend or ji.device) and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
else:
jit_name = 'jit'
axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args
f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs)
explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)
if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree)
else:
donated_invars = (False,) * len(explicit_args)
# If backend or device is set as an arg on jit, then resolve them to
# in_shardings and out_shardings as if user passed in in_shardings
# and out_shardings.
device_or_backend_set = bool(ji.backend or ji.device)
if device_or_backend_set:
sharding = _create_sharding_with_device_backend(ji.device, ji.backend)
leaves, treedef = tree_flatten(sharding)
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
in_shardings_treedef = out_shardings_treedef = treedef
else:
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
for x in ji.in_shardings_leaves)
in_shardings_treedef = ji.in_shardings_treedef
out_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
for x in ji.out_shardings_leaves)
out_shardings_treedef = ji.out_shardings_treedef
assert None not in in_shardings_leaves
assert None not in out_shardings_leaves
in_type: core.InputType | tuple[core.AbstractValue, ...]
if config.dynamic_shapes.value:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)
elif in_avals is None:
avals = []
for i, a in enumerate(explicit_args):
try:
avals.append(shaped_abstractify(a))
except OverflowError as e:
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
else f"flattened argument number is {i}")
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from e
except TypeError as e:
arg_description = (f"path {dbg.arg_names[i]}" if dbg
else f"flattened argument number {i}")
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(a)} and was passed to"
f" the function at {arg_description}.\n"
"This typically means that a jit-wrapped function was called with a non-array"
" argument, and this argument was not marked as static using the"
" static_argnums or static_argnames parameters of jax.jit."
) from e
in_type = in_avals = tuple(avals)
else:
in_type = in_avals
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
ji.in_layouts_treedef, ji.in_layouts_leaves,
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
attr_token = _attr_token(flat_fun, in_type)
abstract_mesh = (
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
else mesh_lib.abstract_mesh_context.mesh)
with abstract_mesh:
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
IgnoreKey(ji.inline))
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
tuple(out_avals), jaxpr.jaxpr.debug_info, device_or_backend_set)
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
if config.dynamic_shapes.value:
implicit_args = _extract_implicit_args(
cast(core.InputType, in_type), explicit_args)
else:
implicit_args = []
args_flat = [*implicit_args, *explicit_args]
num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked)
num_extra_args = len(implicit_args) + num_states_in + len(consts)
in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == num_states_in + len(consts) + len(args_flat))
params = dict(
jaxpr=jaxpr,
in_shardings=in_shardings_flat,
out_shardings=out_shardings_flat,
in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=fun_qual_name(flat_fun),
keep_unused=ji.keep_unused,
inline=ji.inline,
compiler_options_kvs=ji.compiler_options_kvs,
)
return PjitParams(consts, params, in_avals, in_tree, out_tree(),
donated_invars, dbg.arg_names if dbg else None, len(consts),
attrs_tracked, abstract_mesh), args_flat
def get_abstract_mesh(in_avals):
if not config.sharding_in_types.value:
return mesh_lib.null_mesh_context()
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if a.sharding is None: # type: ignore
continue
if m is not None and m != a.sharding.mesh:
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
f' another mesh: {a.sharding.mesh}')
m = a.sharding.mesh # type: ignore
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if m is None:
return mesh_lib.null_mesh_context()
assert isinstance(m, AbstractMesh)
return m
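# Illustrative note (not part of the implementation): with sharding_in_types
# enabled, if every input aval that carries a sharding refers to the same
# AbstractMesh, get_abstract_mesh returns that mesh so tracing can run under
# it; inputs from two different meshes raise the ValueError above, and if no
# input carries a sharding a null mesh context is returned instead.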
class InferParamsCacheEntry:
"""Mutable value object for _infer_params_cached."""
__slots__ = ['pjit_params']
pjit_params: PjitParams | None
def __init__(self):
self.pjit_params = None
# We use an outer cache that is keyed on the signature of the arguments, but
# when populating a cache entry using _infer_params_impl, we need to provide
# actual arguments. In principle we could refactor _infer_params_impl to look
# only at an argument signature instead of args/kwargs in those cases that we
# cache, but this was a more minimal change.
@util.weakref_lru_cache
def _infer_params_cached(
fun: Callable,
jit_info: PjitInfo,
signature: jax_jit.ArgumentSignature,
in_avals: tuple[core.AbstractValue, ...],
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
) -> InferParamsCacheEntry:
return InferParamsCacheEntry()
def _infer_params(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
if ji.use_resource_env:
# We need to fetch the mesh from inside the wrapped function, because
# meshes are dynamically scoped (i.e., with a context manager).
resource_env = mesh_lib.thread_resources.env
pjit_mesh = resource_env.physical_mesh
else:
resource_env = None
pjit_mesh = None
skip_cache = config.dynamic_shapes.value
if not skip_cache:
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
try:
avals = tuple(shaped_abstractify(a) for a in dynargs)
except (OverflowError, TypeError):
# If we see something we don't understand, use the slow path.
skip_cache = True
if skip_cache:
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
kwargs, in_avals=None)
return p, p.consts + args_flat
entry = _infer_params_cached(
fun, ji, signature, avals, pjit_mesh, resource_env)
if entry.pjit_params is None:
p, args_flat = _infer_params_impl(
fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
if p.attrs_tracked:
# If there are attrs_tracked, don't use the cache.
return p, p.consts + args_flat
else:
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs
def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
explicit_args: Sequence[Any]
) -> Sequence[core.Tracer]:
"""
Given an input type and explicitly-passed arguments (per the user-facing API
calling convention), extract implicit axis size arguments from shapes of
explicit arguments (for the trace-time / jaxpr-level calling convention).
"""
# First, using `in_type` construct a list to represent the full argument list,
# leaving the implicit arguments as None placeholders for now.
explicit_args_ = iter(explicit_args)
args = [next(explicit_args_) if expl else None for _, expl in in_type]
assert next(explicit_args_, None) is None
del explicit_args, explicit_args_
# Next, populate the implicit arguments using the DBIdxs in `in_type`.
for i, (aval, explicit) in enumerate(in_type):
if not explicit or not isinstance(aval, core.DShapedArray):
continue # can't populate an implicit argument
arg = args[i]
assert arg is not None
for d1, d2 in zip(aval.shape, arg.aval.shape):
if isinstance(d1, core.DBIdx):
if args[d1.val] is None:
args[d1.val] = d2
assert core.same_referent(args[d1.val], d2)
assert all(x is not None for x in args)
return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore
def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec] | None:
if abstracted_axes is None: return None
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)
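# Illustrative example with hypothetical values: for a call like
#   jax.jit(f, abstracted_axes=({0: 'n'},))(x)
# where x is the only positional argument, _flat_axes_specs broadcasts the
# prefix ({0: 'n'},) over the flattened args and returns [{0: 'n'}]: one spec
# per flat argument, marking x's leading axis as the abstracted size 'n'.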
class JitWrapped(stages.Wrapped):
def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
raise NotImplementedError
def trace(self, *args, **kwargs) -> stages.Traced:
raise NotImplementedError
# in_shardings and out_shardings can't be None as the default value
# because `None` means that the input is fully replicated.
def pjit(
fun: Callable,
in_shardings=UNSPECIFIED,
out_shardings=UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
static_argnames: str | Iterable[str] | None = None,
donate_argnums: int | Sequence[int] | None = None,
donate_argnames: str | Iterable[str] | None = None,
keep_unused: bool = False,
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> JitWrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
NOTE: This function is now equivalent to jax.jit; please use that instead.
The returned function has semantics equivalent to those of ``fun``, but is
compiled to an XLA computation that runs across multiple devices
(e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted
version of ``fun`` would not fit in a single device's memory, or to speed up
``fun`` by running each operation in parallel across multiple devices.
The partitioning over devices happens automatically based on the
propagation of the input partitioning specified in ``in_shardings`` and
the output partitioning specified in ``out_shardings``. The resources
specified in those two arguments must refer to mesh axes, as defined by
the :py:func:`jax.sharding.Mesh` context manager. Note that the mesh
definition at :func:`~pjit` application time is ignored, and the returned function
will use the mesh definition available at each call site.
Inputs to a :func:`~pjit`'d function will be automatically partitioned across devices
if they're not already correctly partitioned based on ``in_shardings``.
In some scenarios, ensuring that the inputs are already correctly pre-partitioned
can increase performance. For example, if passing the output of one
:func:`~pjit`'d function to another :func:`~pjit`’d function (or the same
:func:`~pjit`’d function in a loop), make sure the relevant
``out_shardings`` match the corresponding ``in_shardings``.
.. note::
**Multi-process platforms:** On multi-process platforms such as TPU pods,
:func:`~pjit` can be used to run computations across all available devices across
processes. To achieve this, :func:`~pjit` is designed to be used in SPMD Python
programs, where every process is running the same Python code such that all
processes run the same :func:`~pjit`'d function in the same order.
When running in this configuration, the mesh should contain devices across
all processes. All input arguments must be globally shaped.
``fun`` will still be executed across *all* devices in the mesh,
including those from other processes, and will be given a global view of the
data spread across multiple processes as a single array.
The SPMD model also requires that the same multi-process :func:`~pjit`'d
functions must be run in the same order on all processes, but they can be
interspersed with arbitrary operations running in a single process.
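As a minimal illustrative sketch (assuming at least four local devices and a
hypothetical element-wise function; adapt the mesh shape to your platform)::

  import numpy as np
  import jax
  from jax.sharding import Mesh, PartitionSpec as P
  from jax.experimental.pjit import pjit

  # A 2x2 device mesh with named axes 'x' and 'y'.
  mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
  # Shard both the input and the output of an element-wise function over the
  # mesh; PartitionSpec entries can only be used under the mesh context manager.
  f = pjit(lambda a: a * 2, in_shardings=P('x', 'y'), out_shardings=P('x', 'y'))
  with mesh:
    out = f(np.arange(16.).reshape(4, 4))  # each device holds a 2x2 shard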
Args:
fun: Function to be compiled. Should be a pure function, as side-effects may
only be executed once. Its arguments and return value should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Positional arguments indicated by ``static_argnums`` can be anything at
all, provided they are hashable and have an equality operation defined.
Static arguments are included as part of a compilation cache key, which is
why hash and equality operators must be defined.
in_shardings: Pytree of structure matching that of arguments to ``fun``,
with all actual arguments replaced by resource assignment specifications.
It is also valid to specify a pytree prefix (e.g. one value in place of a
whole subtree), in which case the leaves get broadcast to all values in
that subtree.
The ``in_shardings`` argument is optional. JAX will infer the shardings
from the input :py:class:`jax.Array`'s, and defaults to replicating the input
if the sharding cannot be inferred.
The valid resource assignment specifications are:
- :py:class:`Sharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None` is a special case whose semantics are:
- if the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
For in_shardings, JAX will mark it as replicated, but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- If the mesh context manager is provided, None will imply that the
value will be replicated on all devices of the mesh.
- For backwards compatibility, in_shardings still supports ingesting
:py:class:`PartitionSpec`. This option can *only* be used with the
mesh context manager.
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
axis or a tuple of mesh axes, and specifies the set of resources assigned
to partition the value's dimension matching its position in the spec.
The size of every dimension has to be a multiple of the total number of
resources assigned to it.
out_shardings: Like ``in_shardings``, but specifies resource
assignment for function outputs.
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
will use GSPMD's sharding propagation to determine how to shard the outputs.
static_argnums: An optional int or collection of ints that specify which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded in
Python (during tracing), and so the corresponding argument values can be
any Python object.
Static arguments should be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and immutable. Calling the jitted function
with different values for these constants will trigger recompilation.
Arguments that are not arrays or containers thereof must be marked as
static.
If ``static_argnums`` is not provided, no arguments are treated as static.
static_argnames: An optional string or collection of strings specifying
which named arguments to treat as static (compile-time constant). See the
comment on ``static_argnums`` for details. If not
provided but ``static_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
donate_argnums: Specify which positional argument buffers are "donated" to
the computation. It is safe to donate argument buffers if you no longer
need them once the computation has finished. In some cases XLA can make
use of donated buffers to reduce the amount of memory needed to perform a
computation, for example recycling one of your input buffers to store a
result. You should not reuse buffers that you donate to a computation; JAX
will raise an error if you try to. By default, no argument buffers are
donated.
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
arguments are donated. If ``donate_argnums`` is not provided but
``donate_argnames`` is, or vice versa, JAX uses
:code:`inspect.signature(fun)` to find any positional arguments that
correspond to ``donate_argnames``
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
provided, ``inspect.signature`` is not used, and only actual
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
be donated.
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
donate_argnames: An optional string or collection of strings specifying
which named arguments are donated to the computation. See the
comment on ``donate_argnums`` for details. If not
provided but ``donate_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
device: This argument is deprecated. Please put your arguments on the
device you want before passing them to jit.
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited
from XLA's DeviceAssignment logic and is usually to use
``jax.devices()[0]``.
backend: This argument is deprecated. Please put your arguments on the
backend you want before passing them to jit.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation and