Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537067965
  • Loading branch information
mattjj authored and jax authors committed Jun 1, 2023
1 parent e13cfe7 commit c8311c6
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,18 @@ def get_axis_primitive_batcher(self, primitive, frame):
frame.size, frame.name, frame.main_trace.trace_type)

def get_frame(self, vals, dims) -> core.AxisEnvFrame:
frame = core.axis_frame(self.axis_name, self.main)
assert frame.main_trace is self.main
if any(d is not not_mapped for d in dims):
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
for x, d in zip(vals, dims) if d is not not_mapped)
data_axis_size, = core.dedup_referents(sizes)
assert data_axis_size == frame.size
axis_size, = core.dedup_referents(sizes)
else:
axis_size = None # can't be inferred from data
if self.axis_name is core.no_axis_name:
assert axis_size is not None # must be inferrable from data
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
frame = core.axis_frame(self.axis_name, self.main)
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
assert frame.main_trace is self.main
return frame

def process_primitive(self, primitive, tracers, params):
Expand Down

0 comments on commit c8311c6

Please sign in to comment.