Skip to content

Commit

Permalink
Fix PyType errors in input normalization wrapper.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389885508
Change-Id: I4b6d5984e933a61a64ce92732c9a81b3087f2058
  • Loading branch information
bshahr authored and tkiela1 committed Aug 10, 2021
1 parent 1bb5a0b commit eee201a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions acme/agents/jax/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def make_actor(
backend='cpu')


# Have to disable pytype invalid-annotation error, as it fails on Kokorox with
# Have to disable pytype invalid-annotation error, as it fails on Kokoro with
# Invalid type annotation 'TrainingState': Appears only once in the signature.
def wrap_learner_core(
learner_core: learner_core_lib.LearnerCore[reverb.ReplaySample,
Expand All @@ -273,7 +273,7 @@ def wrap_learner_core(

State = NormalizationLearnerWrapperState[TrainingState] # pylint: disable=invalid-name

def init(random_key: PRNGKey) -> State:
def init(random_key: PRNGKey): # Returns a State.
return NormalizationLearnerWrapperState(
learner_core.init(random_key),
running_statistics.init_state(environment_spec.observations))
Expand Down

0 comments on commit eee201a

Please sign in to comment.