modify "copied from" comment
HeegyuKim committed Oct 6, 2023
1 parent 194d4f0 commit a670443
Showing 1 changed file with 3 additions and 3 deletions.
src/transformers/models/gpt_neox/modeling_flax_gpt_neox.py (6 changes: 3 additions & 3 deletions)
@@ -436,7 +436,6 @@ def __call__(
         return (hidden_states,) + attn_outputs[1:]
 
 
-# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo -> GPTNeoX
 class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -459,6 +458,7 @@ def __init__(
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -479,6 +479,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
         else:
             return random_params
 
+    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_cache
     def init_cache(self, batch_size, max_length):
         r"""
         Args:
@@ -569,7 +570,6 @@ def __call__(
         return outputs
 
 
-# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoBlockCollection with GPTNeo -> GPTNeoX
 class FlaxGPTNeoXBlockCollection(nn.Module):
     config: GPTNeoXConfig
     dtype: jnp.dtype = jnp.float32
@@ -737,7 +737,6 @@ def __call__(
         return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
 
-# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo -> GPTNeoX
 @add_start_docstrings(
     """
     The GPTNeoX Model transformer with a language modeling head on top.
@@ -768,6 +767,7 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O
             "position_ids": position_ids,
         }
 
+    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM.update_inputs_for_generation
     def update_inputs_for_generation(self, model_outputs, model_kwargs):
         model_kwargs["past_key_values"] = model_outputs.past_key_values
         model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
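For context, the two forms of "Copied from" marker behave differently under transformers' consistency checker (utils/check_copies.py, run via make repo-consistency): a class-level marker with a rename pattern requires the entire class body to match the referenced GPT-Neo source after substituting GPTNeo -> GPTNeoX, while a method-level marker pins only that single method. Below is a minimal sketch of the two forms this commit swaps between; the stub base class and simplified signatures stand in for the real transformers code and are illustrative only.

# Illustrative stub; stands in for transformers.modeling_flax_utils.FlaxPreTrainedModel.
class FlaxPreTrainedModel:
    pass

# Class-level form (what this commit removes): the checker requires the whole
# class body to equal FlaxGPTNeoPreTrainedModel's after rewriting every
# "GPTNeo" to "GPTNeoX", which is too strict once any method diverges.
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo -> GPTNeoX
class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
    ...

# Method-level form (what this commit adds): only the marked method must stay
# a verbatim copy; the rest of the class may diverge freely. (Redefining the
# class name here just keeps the sketch self-contained and runnable.)
class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights
    def init_weights(self, rng, input_shape, params=None):
        ...

Narrowing the markers this way keeps the still-verbatim methods (init_weights, init_cache, update_inputs_for_generation) machine-checked while letting the rest of each GPTNeoX class diverge from its GPT-Neo counterpart.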