diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index ed7a6f86a35e..5039269b0a6f 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -333,13 +333,18 @@ def get_axis_primitive_batcher(self, primitive, frame): frame.size, frame.name, frame.main_trace.trace_type) def get_frame(self, vals, dims) -> core.AxisEnvFrame: - frame = core.axis_frame(self.axis_name, self.main) - assert frame.main_trace is self.main if any(d is not not_mapped for d in dims): sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths) for x, d in zip(vals, dims) if d is not not_mapped) - data_axis_size, = core.dedup_referents(sizes) - assert data_axis_size == frame.size + axis_size, = core.dedup_referents(sizes) + else: + axis_size = None # can't be inferred from data + if self.axis_name is core.no_axis_name: + assert axis_size is not None # must be inferrable from data + return core.AxisEnvFrame(self.axis_name, axis_size, self.main) + frame = core.axis_frame(self.axis_name, self.main) + assert axis_size is None or axis_size == frame.size, (axis_size, frame.size) + assert frame.main_trace is self.main return frame def process_primitive(self, primitive, tracers, params):