Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add devices params to PPO learner #309

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions acme/agents/jax/ppo/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

"""Learner for the PPO agent."""

from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple
from typing import Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple

from absl import logging
import acme
from acme import types
from acme.agents.jax.ppo import networks
Expand Down Expand Up @@ -103,10 +104,18 @@ def __init__(
metrics_logging_period: int = 100,
pmap_axis_name: str = 'devices',
obs_normalization_fns: Optional[normalization.NormalizationFns] = None,
devices: Optional[Sequence[jax.Device]] = None,
):
self.local_learner_devices = jax.local_devices()
self.num_local_learner_devices = jax.local_device_count()
self.learner_devices = jax.devices()
local_devices = jax.local_devices()
process_id = jax.process_index()
logging.info('Learner process id: %s. Devices passed: %s', process_id,
devices)
logging.info('Learner process id: %s. Local devices from JAX API: %s',
process_id, local_devices)
self.learner_devices = devices or jax.devices()
self.local_learner_devices = [d for d in self.learner_devices if d in local_devices]
self.num_local_learner_devices = len(self.local_learner_devices)

self.num_epochs = num_epochs
self.num_minibatches = num_minibatches
self.metrics_logging_period = metrics_logging_period
Expand Down