diff --git a/alpa/create_state_parallel.py b/alpa/create_state_parallel.py index 6902a639f..f30be84b4 100644 --- a/alpa/create_state_parallel.py +++ b/alpa/create_state_parallel.py @@ -6,6 +6,7 @@ from jax.core import ClosedJaxpr, Var from jax.interpreters import partial_eval as pe, pxla from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef +import numpy as np from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup from alpa.mesh_executable import (NormalMeshDriverExecutable, @@ -31,11 +32,13 @@ def __init__(self, mesh_group: PhysicalDeviceMeshGroup, pipeshard_config: PipeshardConfig, target_placement_specs: Sequence[PlacementSpec], + in_tree: PyTreeDef, out_tree: Optional[PyTreeDef] = None, static_argnums: Optional[Sequence[int]] = None): super().__init__(mesh_group=mesh_group, pipeshard_config=pipeshard_config, num_batch=1, + in_tree=in_tree, out_tree=out_tree, static_argnums=static_argnums) self.target_placement_specs = target_placement_specs @@ -58,7 +61,7 @@ def launch_on_driver(self, *args): indices = pxla.spec_to_indices(array.shape, sharding_spec) dis_array = self.mesh_group[mesh_id].shard_args_to_arrays( (array.aval,), (indices,), (sharding_spec,), - (array,))[0] + (np.asarray(array),))[0] distributed_arrays.append(dis_array) outputs[idx] = ReplicatedDistributedArray( meshes, distributed_arrays) @@ -104,7 +107,8 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk, return NormalMeshDriverExecutable(physical_mesh, hlo_module, stage_plan, avals, out_avals, - [False] * len(avals)) + [False] * len(avals), static_argnums, + in_tree, out_tree) else: # Construct a new pipelined jaxpr outvars = jaxpr.outvars @@ -126,14 +130,15 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk, # Compile a pipeshard executable with predefined output shardings pipeshard_config = compile_pipeshard_executable_internal( - new_jaxpr, None, 1, in_tree, [False] * len(avals), - [False] * len(avals), executable.mesh_group.parent, 1, "inference", + new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals), + executable.mesh_group.parent, 1, "inference", AutoShardingOption(enable_auto_sharding=False), UniformStageOption(), name, output_shardings) return CreateStateExecutable(mesh_group=executable.mesh_group, pipeshard_config=pipeshard_config, target_placement_specs=placement_specs, + in_tree=in_tree, out_tree=out_tree_thunk(), static_argnums=static_argnums) diff --git a/alpa/data_loader.py b/alpa/data_loader.py index e455dd5e5..1dcb35afa 100644 --- a/alpa/data_loader.py +++ b/alpa/data_loader.py @@ -3,11 +3,11 @@ import itertools import jax -from jax.interpreters import pxla, xla +from jax.interpreters import pxla import numpy as np import ray -from alpa.device_mesh import (LocalPhysicalDeviceMesh, DistributedArray, +from alpa.device_mesh import (DistributedArray, get_global_physical_mesh, create_remote_array_refs) @@ -15,38 +15,28 @@ class DataLoader: """A driver-only dataloader that loads data on the driver process and sends the data to all workers.""" - def __init__(self, - input_iter, - sharding_specs, - physical_mesh=None, - prefetch_size=1): + def __init__(self, input_iter, placement_specs, prefetch_size=1): self.input_iter = input_iter - self.sharding_specs = sharding_specs self.prefetch_size = prefetch_size - if physical_mesh is None: - self.physical_mesh = LocalPhysicalDeviceMesh() - else: - self.physical_mesh = physical_mesh + self.physical_mesh = get_global_physical_mesh() + self.avals = [] + self.indices = [] + self.sharding_specs 
= [] + for ps in jax.tree_leaves(placement_specs): + assert len(ps.mesh_ids) == 1 + assert ps.mesh_ids[0] == self.physical_mesh.mesh_id + + self.avals.append(ps.aval) + self.sharding_specs.append(ps.sharding_specs[0]) + self.indices.append( + tuple(ps.sharding_specs[0].indices(ps.aval.shape))) self.queue = collections.deque() - self.first_iter = True - self.avals = None - self.indices = None def enqueue(self, num_batches): for batch in itertools.islice(self.input_iter, num_batches): flatten_args, tree = jax.tree_flatten(batch) - - # Cache meta info - if self.first_iter: - self.first_iter = False - self.avals = [xla.abstractify(a) for a in flatten_args] - self.indices = [ - tuple(spec.indices(aval.shape)) - for spec, aval in zip(self.sharding_specs, self.avals) - ] - new_args = self.physical_mesh.shard_args_to_arrays( self.avals, self.indices, self.sharding_specs, flatten_args) self.queue.append(jax.tree_unflatten(tree, new_args)) @@ -112,14 +102,19 @@ def __init__(self, batch_size, num_samples, input_iter_func, - avals, - sharding_specs, - physical_mesh, + placement_specs, prefetch_size=1): - indices = [ - tuple(np.ravel(spec.indices(aval.shape))) - for spec, aval in zip(sharding_specs, avals) - ] + physical_mesh = get_global_physical_mesh() + avals = [] + sharding_specs = [] + indices = [] + for ps in jax.tree_leaves(placement_specs): + avals.append(ps.aval) + assert len(ps.mesh_ids) == 1 + assert ps.mesh_ids[0] == physical_mesh.mesh_id + sharding_specs.append(ps.sharding_specs[0]) + indices.append(np.ravel(ps.sharding_specs[0].indices( + ps.aval.shape))) self.uuid = next_mesh_data_loader_uuid() self.physical_mesh = physical_mesh diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index afe71b095..88d5294da 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -809,7 +809,7 @@ def __init__(self, devices: Sequence["Device"] = None): self.devices = devices if devices is not None else xb.local_devices() self.num_hosts = 1 self.num_devices_per_host = len(self.devices) - self.mesh_id = 0 + self.mesh_id = -1 self.device_strs = [] self.operation_executables = {} diff --git a/alpa/global_env.py b/alpa/global_env.py index 70a80c13a..ff44317d2 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -78,6 +78,7 @@ def __init__(self): ########## Options of logging ########## self.print_compilation_time = False + self.print_auto_layer_stats = False global_config = GlobalConfig() diff --git a/alpa/mesh_executable.py b/alpa/mesh_executable.py index 668d9bfa8..ace169d20 100644 --- a/alpa/mesh_executable.py +++ b/alpa/mesh_executable.py @@ -61,7 +61,19 @@ def launch_on_driver(self, *args, **kwargs): raise NotImplementedError() def get_input_placement_specs(self): - """Return the preferred placement specs for input arguments.""" + """ + Return the preferred placement specs for input arguments. + The return value is a pytree of PlacementSpec + with the same structure as the input pytree. + """ + raise NotImplementedError() + + def get_output_placement_specs(self): + """ + Return the preferred placement specs for outputs. + The return value is a pytree of PlacementSpec + with the same structure as the output pytree. 
+ """ raise NotImplementedError() def preshard_dynamic_args(self, *args): @@ -160,6 +172,15 @@ def sync_func_worker(): return sync_func_worker +def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, pytree): + """Wrap avals and sharding specs to a pytree of placement specs.""" + placement_specs = [ + PlacementSpec(aval, (physical_mesh.mesh_id,), (sharding_spec,)) + for aval, sharding_spec in zip(avals, sharding_specs) + ] + return tree_unflatten(pytree, placement_specs) + + class NormalMeshDriverExecutable(MeshDriverExecutable): """The driver part of a normal mesh executable.""" @@ -284,13 +305,24 @@ def launch_on_driver(self, *args, **kwargs): return self.outs_handler(output_bufs) def get_input_placement_specs(self): - """Return the preferred placement specs for input arguments.""" - placement_specs = [ - PlacementSpec(aval, (self.physical_mesh.mesh_id,), - (sharding_spec,)) for aval, sharding_spec in zip( - self.avals, self.input_sharding_specs) - ] - return tree_unflatten(self.in_tree, placement_specs) + """ + Return the preferred placement specs for input arguments. + The return value is a pytree of PlacementSpec + with the same structure as the input pytree. + """ + return wrap_to_placement_spec_tree(self.physical_mesh, self.avals, + self.input_sharding_specs, + self.in_tree) + + def get_output_placement_specs(self): + """ + Return the preferred placement specs for outputs. + The return value is a pytree of PlacementSpec + with the same structure as the output pytree. + """ + return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals, + self.output_sharding_specs, + self.out_tree) def preshard_dynamic_args(self, *args): """Pre-shard the input arguments.""" @@ -505,6 +537,7 @@ def __init__(self, out_avals, physical_mesh.num_devices, logical_mesh_shape)) + self.output_sharding_specs = output_sharding_specs num_grads = len(grad_avals) assert accumulate_grad_input_sharding_specs[ -num_grads:] == grad_sharding_specs @@ -701,13 +734,24 @@ def launch_on_driver(self, *args): return self.outs_handler(output_bufs) def get_input_placement_specs(self): - """Return the preferred placement specs for input arguments.""" - placement_specs = [ - PlacementSpec(aval, (self.physical_mesh.mesh_id,), - (sharding_spec,)) for aval, sharding_spec in zip( - self.avals, self.global_arg_sharding_specs) - ] - return tree_unflatten(self.in_tree, placement_specs) + """ + Return the preferred placement specs for input arguments. + The return value is a pytree of PlacementSpec + with the same structure as the input pytree. + """ + return wrap_to_placement_spec_tree(self.physical_mesh, self.avals, + self.global_arg_sharding_specs, + self.in_tree) + + def get_output_placement_specs(self): + """ + Return the preferred placement specs for outputs. + The return value is a pytree of PlacementSpec + with the same structure as the output pytree. 
+ """ + return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals, + self.output_sharding_specs, + self.out_tree) def get_total_allocation_size(self): """Get the total allocated memory size of this executable.""" diff --git a/alpa/pipeline_parallel/compile_executable.py b/alpa/pipeline_parallel/compile_executable.py index cc98c2a4a..2ac39561f 100644 --- a/alpa/pipeline_parallel/compile_executable.py +++ b/alpa/pipeline_parallel/compile_executable.py @@ -74,14 +74,15 @@ def compile_pipeshard_executable( debug_compilation_time("trace") pipeshard_config = compile_pipeshard_executable_internal( - closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, in_tree, - donated_invars, batch_invars, virtual_mesh, num_microbatch, - pipeline_schedule, default_as_option, stage_option, name_base, None) + closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars, + batch_invars, virtual_mesh, num_microbatch, pipeline_schedule, + default_as_option, stage_option, name_base, None) executable = PipeshardDriverExecutable( mesh_group=virtual_mesh.launched_physical_mesh_group, pipeshard_config=pipeshard_config, num_batch=num_microbatch, + in_tree=in_tree, out_tree=out_tree_thunk(), static_argnums=static_argnums) debug_compilation_time("driver executable") @@ -91,11 +92,10 @@ def compile_pipeshard_executable( def compile_pipeshard_executable_internal( closed_jaxpr: ClosedJaxpr, full_batch_closed_jaxpr: Optional[ClosedJaxpr], micro_batch_size: int, - in_tree: PyTreeDef, donated_invars: Sequence[bool], - batch_invars: Sequence[bool], virtual_mesh: VirtualPhysicalMesh, - num_microbatch: int, pipeline_schedule: str, - default_as_option: AutoShardingOption, stage_option: StageOption, - name_base: str, + donated_invars: Sequence[bool], batch_invars: Sequence[bool], + virtual_mesh: VirtualPhysicalMesh, num_microbatch: int, + pipeline_schedule: str, default_as_option: AutoShardingOption, + stage_option: StageOption, name_base: str, output_shardings: Optional[Sequence[pxla.ShardingSpec]]): global_invars = closed_jaxpr.jaxpr.invars global_outvars = closed_jaxpr.jaxpr.outvars @@ -224,8 +224,7 @@ def compile_pipeshard_executable_internal( schedule=schedule, is_batch=batch_invars, num_batch=num_microbatch, - flop_count=total_flops, - in_tree=in_tree).compile() + flop_count=total_flops).compile() debug_compilation_time("runtime emitter") return pipeshard_config diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 1c0030078..214e1ccce 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -13,9 +13,11 @@ new_jaxpr_eqn, gensym, raise_to_shaped, get_aval) from jax.interpreters.partial_eval import remat_call_p +from alpa.global_env import global_config from alpa.pipeline_parallel.layer_stats import (global_invar_size, is_nontrivial, eqn_flops, - heavy_count) + heavy_count, + log_layer_slicing_stats) from alpa.pipeline_parallel.primitive_def import (pipeline_p, mark_pipeline_jaxpreqn) from alpa.util import (clone_jaxpr, slices_to_jaxpr, OrderedSet, @@ -448,6 +450,8 @@ def wrapped(*args): eps, costs, cost_criteria=cost_criteria) + if global_config.print_auto_layer_stats: + log_layer_slicing_stats(jaxpr, sliced_eqns) else: sliced_eqns = slice_eqns_by_layer_boundary(jaxpr) diff --git a/alpa/pipeline_parallel/pipeshard_executable.py b/alpa/pipeline_parallel/pipeshard_executable.py index d1e93cd9d..46d43f3d2 100644 --- a/alpa/pipeline_parallel/pipeshard_executable.py +++ 
b/alpa/pipeline_parallel/pipeshard_executable.py @@ -44,6 +44,7 @@ def __init__(self, mesh_group: PhysicalDeviceMeshGroup, pipeshard_config: PipeshardConfig, num_batch: int, + in_tree: PyTreeDef, out_tree: Optional[PyTreeDef] = None, static_argnums: Optional[Sequence[int]] = None): ##### Input arguments ##### @@ -51,6 +52,7 @@ def __init__(self, self.num_mesh = len(mesh_group) self.num_batch = num_batch self.static_argnums = static_argnums + self.in_tree = in_tree self.out_tree = out_tree ##### For debugging and serialization ##### @@ -58,6 +60,7 @@ def __init__(self, self.schedule = pipeshard_config.schedule self.flop_count = pipeshard_config.flop_count self.input_placement_specs = pipeshard_config.input_placement_specs + self.output_placement_specs = pipeshard_config.output_placement_specs # List[stage_idx -> str] self.fully_optimized_hlo_texts = [] self.sharding_annotated_hlo_texts = ( @@ -192,8 +195,20 @@ def launch_on_driver(self, *args): return self.outs_handler(self.mesh_group, output_bufs) def get_input_placement_specs(self): - """Return the preferred placement specs for input arguments.""" - return self.input_placement_specs + """ + Return the preferred placement specs for input arguments. + The return value is a pytree of PlacementSpec + with the same structure as the input pytree. + """ + return tree_unflatten(self.in_tree, self.input_placement_specs) + + def get_output_placement_specs(self): + """ + Return the preferred placement specs for outputs. + The return value is a pytree of PlacementSpec + with the same structure as the output pytree. + """ + return tree_unflatten(self.out_tree, self.output_placement_specs) def __call__(self, *args): """Fast call without signature matching.""" diff --git a/alpa/pipeline_parallel/runtime_emitter.py b/alpa/pipeline_parallel/runtime_emitter.py index be21055b9..a03601d79 100644 --- a/alpa/pipeline_parallel/runtime_emitter.py +++ b/alpa/pipeline_parallel/runtime_emitter.py @@ -5,7 +5,6 @@ import logging from typing import Any, Callable, Dict, Optional, Sequence, Union, Set -from jax._src.tree_util import PyTreeDef, tree_unflatten from jax.core import Var from jax.interpreters import pxla from jax.lib import xla_bridge as xb @@ -250,8 +249,9 @@ class PipeshardConfig: # Output configs output_local_uuid_list: Sequence[Sequence[int]] outs_handler: Callable - # Others + # Others (debug info) input_placement_specs: Sequence[PlacementSpec] + output_placement_specs: Sequence[PlacementSpec] sharding_annotated_hlo_texts: Sequence[str] flop_count: int @@ -266,7 +266,7 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation], Var], mesh_group: PhysicalDeviceMeshGroup, schedule: PipelineSchedule, is_batch: Sequence[bool], - num_batch: int, in_tree: PyTreeDef, flop_count: int): + num_batch: int, flop_count: int): ##### Input arguments ##### self.stages = stages self.global_invars = global_invars @@ -278,7 +278,6 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation], self.schedule = schedule self.is_batch = is_batch self.num_batch = num_batch - self.in_tree = in_tree self.flop_count = flop_count self.sharding_annotated_hlo_texts = [x.get_hlo_text() for x in stages] @@ -416,8 +415,9 @@ def compile(self): # Compile information for outputs output_local_uuid_list, mesh_output_indices, output_spec_list = ( self._compile_collect_outputs()) - outs_handler = self._get_outs_handler(mesh_output_indices, - output_spec_list) + outs_handler, output_placement_specs = self._get_outs_handler( + mesh_output_indices, output_spec_list) + # 
Add gradient accumulation buffer reduced_var_uuid_lists = [] for mesh_idx in range(num_mesh): @@ -459,6 +459,7 @@ def compile(self): outs_handler, # Others input_placement_specs, + output_placement_specs, self.sharding_annotated_hlo_texts, self.flop_count) @@ -919,6 +920,7 @@ def _get_outs_handler(self, mesh_output_indices, output_spec_list): outvar_index_on_mesh_list = [] spec_list = [] indices_list = [] + output_placement_specs = [] # Generate cached info for i, aval in enumerate(avals): @@ -931,6 +933,9 @@ def _get_outs_handler(self, mesh_output_indices, output_spec_list): outvar_index_on_mesh_list.append(outvar_index_on_mesh) spec_list.append(spec) indices_list.append(pxla.spec_to_indices(aval.shape, spec)) + + output_placement_specs.append( + PlacementSpec(aval, (mesh_idx_list[-1],), (spec_list[-1],))) else: # for RepliatedDistributedArray mesh_idx_list.append([]) @@ -947,6 +952,9 @@ def _get_outs_handler(self, mesh_output_indices, output_spec_list): spec_list[-1].append(spec) indices_list[-1].append( pxla.spec_to_indices(aval.shape, spec)) + output_placement_specs.append( + PlacementSpec(aval, tuple(mesh_idx_list[-1]), + tuple(spec_list[-1]))) def outs_handler(mesh_group, refs): ret = [] @@ -980,12 +988,10 @@ def outs_handler(mesh_group, refs): ret.append(arr) return ret - return outs_handler + return outs_handler, output_placement_specs def _compile_input_placement_spec(self, mesh_arg_indices, input_shard_specs): - assert self.in_tree is not None - # build spec_arr: List[flatten global index -> PlacementSpec] spec_arr = [None] * len(self.is_batch) for mesh_idx, physical_mesh in enumerate(self.mesh_group): @@ -1002,7 +1008,7 @@ def _compile_input_placement_spec(self, mesh_arg_indices, old_val.mesh_ids + (physical_mesh.mesh_id,), old_val.sharding_specs + (shard_spec,)) - return tree_unflatten(self.in_tree, spec_arr) + return spec_arr # TODO(yonghao): set empty buffer is not compatiable with local allgather @staticmethod diff --git a/alpa/torch/optim/adam.py b/alpa/torch/optim/adam.py index 996276e35..7357fd401 100644 --- a/alpa/torch/optim/adam.py +++ b/alpa/torch/optim/adam.py @@ -27,7 +27,7 @@ def optim_gen(params): def optim_func(params, optim_state, params_grad): for k in params: params[k] = params[k] + params_grad[k] * lr - optim_state[k] = optim_state[k] + 1 + optim_state[k] = optim_state[k] + params_grad[k] return params, optim_state optim_state = copy.deepcopy(params) diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index a2d8cf335..77e313470 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -165,12 +165,12 @@ def eval_step(state, batch): def create_input_iter(dataset_builder, batch_size, image_size, dtype, - sharding_specs, physical_mesh, train, cache): + placement_specs, train, cache): ds = input_pipeline.create_split( dataset_builder, batch_size, image_size=image_size, dtype=dtype, train=train, cache=cache) it = map(lambda xs: jax.tree_map(lambda x: x._numpy(), xs), ds) - it = alpa.DataLoader(it, sharding_specs, physical_mesh=physical_mesh, prefetch_size=4) + it = alpa.DataLoader(it, placement_specs, prefetch_size=4) return it @@ -306,17 +306,16 @@ def train_and_evaluate(config: ml_collections.ConfigDict, "label": jax.core.ShapedArray((config.batch_size,), jnp.int32), } executable = p_train_step.get_executable(state, batch) - physical_mesh = executable.physical_mesh logging.info('Initial compilation completed.') - sharding_specs = executable.input_sharding_specs[-2:] + batch_placement_specs = 
executable.get_input_placement_specs()[1] train_iter = create_input_iter( dataset_builder, local_batch_size, image_size, input_dtype, - sharding_specs, physical_mesh, train=True, cache=config.cache) + batch_placement_specs, train=True, cache=config.cache) eval_iter = create_input_iter( dataset_builder, local_batch_size, image_size, input_dtype, - sharding_specs, physical_mesh, train=False, cache=config.cache) + batch_placement_specs, train=False, cache=config.cache) train_metrics = [] hooks = [] diff --git a/tests/runtime/test_create_state.py b/tests/runtime/test_create_state.py index f13d7f245..4874b0d8a 100644 --- a/tests/runtime/test_create_state.py +++ b/tests/runtime/test_create_state.py @@ -69,18 +69,16 @@ def create_state(): state = train_step(state, batch) if isinstance(method, ShardParallel): - actual = create_state.get_last_executable().output_sharding_specs - if method.num_micro_batches == None: # NormalMeshDriverExecutable - expected = train_step.get_last_executable( - ).input_sharding_specs[:len(actual)] - else: # GradAccMeshDriverExecutable - expected = train_step.get_last_executable( - ).global_arg_sharding_specs[:len(actual)] - for x, y in zip(actual, expected): - assert x == y, f"{x} vs. {y}" + actual = jax.tree_flatten(create_state.get_last_executable(). + get_output_placement_specs())[0] + expected = jax.tree_flatten( + train_step.get_last_executable().get_input_placement_specs() + [0])[0] + assert actual == expected elif isinstance(method, PipeshardParallel): # The assertion is already in CreateStateExecutable::launch_on_driver - pass + # Here, we just call the function to test whether it is runnable. + train_step.get_last_executable().get_output_placement_specs() def test_shard_parallel(self): method = ShardParallel(num_micro_batches=None) diff --git a/tests/runtime/test_data_loader.py b/tests/runtime/test_data_loader.py index 8aaeb726f..21a841159 100644 --- a/tests/runtime/test_data_loader.py +++ b/tests/runtime/test_data_loader.py @@ -9,6 +9,7 @@ from jax.interpreters import pxla from alpa import init, MeshDriverDataLoader +from alpa.parallel_plan import PlacementSpec from alpa.device_mesh import get_global_physical_mesh from alpa.testing import (assert_allclose, data_loader_test_input_iter_func as input_iter_func) @@ -21,18 +22,20 @@ def setUp(self): self.physical_mesh = get_global_physical_mesh(create_if_not_exist=True) def run_test(self, sharding_specs): - batch_size = 64 num_samples = 256 avals = [ jax.core.ShapedArray((batch_size, 32), jnp.float32), jax.core.ShapedArray((batch_size,), jnp.int32) ] + placement_specs = [ + PlacementSpec(aval, (self.physical_mesh.mesh_id,), (sharding_spec,)) + for aval, sharding_spec in zip(avals, sharding_specs) + ] prefetch_size = 2 data_loader = MeshDriverDataLoader(batch_size, num_samples, - input_iter_func, avals, - sharding_specs, self.physical_mesh, + input_iter_func, placement_specs, prefetch_size) expected_data_loader = input_iter_func(0, num_samples, batch_size)
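
Illustrative usage of the reworked driver-only DataLoader (a rough sketch; p_train_step, state, batch, and raw_batch_iter are placeholder names, and only the alpa calls mirror the API changed above):

    import alpa

    # Compile once, then ask the executable where each input should live.
    # get_input_placement_specs() now returns a pytree of PlacementSpec with
    # the same structure as the inputs, so [1] selects the batch argument.
    executable = p_train_step.get_executable(state, batch)
    batch_placement_specs = executable.get_input_placement_specs()[1]

    # The DataLoader now looks up the global physical mesh itself and only
    # needs the placement specs for the batch.
    train_iter = alpa.DataLoader(raw_batch_iter, batch_placement_specs,
                                 prefetch_size=4)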
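
The new get_output_placement_specs() mirrors get_input_placement_specs() on the mesh and pipeshard driver executables. A sketch of the consistency check used in the updated create_state test for ShardParallel (create_state and train_step are assumed to be alpa.parallelize-decorated functions that have each run once):

    import jax

    create_exec = create_state.get_last_executable()
    train_exec = train_step.get_last_executable()

    # The placement of the freshly created train state should match the
    # placement train_step expects for its first argument.
    out_specs = jax.tree_flatten(create_exec.get_output_placement_specs())[0]
    in_specs = jax.tree_flatten(train_exec.get_input_placement_specs()[0])[0]
    assert out_specs == in_specs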
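
The print_auto_layer_stats option added to global_config is a debugging aid: when enabled, automatic layer construction logs per-layer statistics (log_layer_slicing_stats) for the sliced jaxpr. A minimal sketch, assuming the flag is set before the parallelized function is traced:

    from alpa.global_env import global_config

    global_config.print_auto_layer_stats = True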