
Commit

Formating + vla.yaml
Cadene committed Oct 7, 2024
1 parent bd49f7d commit ff4306d
Showing 8 changed files with 305 additions and 147 deletions.
3 changes: 1 addition & 2 deletions lerobot/common/policies/act/modeling_act.py
@@ -97,8 +97,7 @@ def reset(self):

@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
"""
"""Select a single action given environment observations."""
self.eval()

batch = self.normalize_inputs(batch)
30 changes: 20 additions & 10 deletions lerobot/common/policies/vla/configuration_qwen2_vl.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union
from transformers.utils import logging # Using standard Python logging module instead of `transformers.utils.logging`
from transformers.configuration_utils import PretrainedConfig

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import (
logging, # Using standard Python logging module instead of `transformers.utils.logging`
)

logger = logging.get_logger(__name__)


def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: set | None = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
rope_type = rope_scaling.get(
"rope_type", rope_scaling.get("type", None)
) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
# _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
@@ -59,6 +63,7 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: set | None = N
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)


class Qwen2VLVisionConfig(PretrainedConfig):
model_type = "qwen2_vl"

@@ -89,30 +94,35 @@ def __init__(
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size


@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)



class Qwen2VLConfig(PretrainedConfig):
r"""
A simplified version of the Qwen2VL model configuration class without the `transformers` dependencies.
"""

model_type = "qwen2_vl"
keys_to_ignore_at_inference = ["past_key_values"]

@@ -122,7 +132,7 @@ def __init__(
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_decoder_layers = 1,
num_decoder_layers=1,
num_attention_heads=64,
num_key_value_heads=8,
# dim_feedforward = 3200,
@@ -166,7 +176,7 @@ def __init__(
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.pad_token_id = pad_token_id
self.pad_token_id = pad_token_id
self.pruned_heads = pruned_heads or {}
self.rope_scaling = rope_scaling
self.num_decoder_layers = num_decoder_layers
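
For reference, a tiny standalone sketch of the backward-compatibility lookup used in `_validate_default_rope_parameters` above ("rope_type" was originally called "type"). The dict values below are illustrative assumptions, not taken from this diff:

# Hypothetical rope_scaling dicts; "rope_type" takes precedence, "type" is the legacy key.
legacy = {"type": "mrope", "mrope_section": [16, 24, 24]}
current = {"rope_type": "default", "mrope_section": [16, 24, 24]}

for rope_scaling in (legacy, current):
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
    print(rope_type)  # prints "mrope", then "default"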
44 changes: 22 additions & 22 deletions lerobot/common/policies/vla/configuration_vla.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from transformers.configuration_utils import PretrainedConfig


@dataclass
@@ -119,27 +118,27 @@ class VLAConfig:
# Architecture.

# Language + Main transformer
vocab_size: int =152064
hidden_size: int =8192
intermediate_size: int =29568
num_hidden_layers: int =80
num_decoder_layers : int = 1
vocab_size: int = 152064
hidden_size: int = 8192
intermediate_size: int = 29568
num_hidden_layers: int = 80
num_decoder_layers: int = 1
attn_implementation: str = "eager"
num_attention_heads: int =64
num_key_value_heads: int =8
dim_feedforward : int = 3200
hidden_act: str ="silu"
pad_token_id:int =0
max_position_embeddings:int =32768
initializer_range: float =0.02
rms_norm_eps: float =1e-05
use_cache: bool=True
tie_word_embeddings: bool=False
rope_theta: float =1000000.0
use_sliding_window: bool=False
sliding_window=4096
max_window_layers=80
attention_dropout=0.0
num_attention_heads: int = 64
num_key_value_heads: int = 8
dim_feedforward: int = 3200
hidden_act: str = "silu"
pad_token_id: int = 0
max_position_embeddings: int = 32768
initializer_range: float = 0.02
rms_norm_eps: float = 1e-05
use_cache: bool = True
tie_word_embeddings: bool = False
rope_theta: float = 1000000.0
use_sliding_window: bool = False
sliding_window = 4096
max_window_layers = 80
attention_dropout = 0.0
rope_scaling: dict = field(
default_factory=lambda: {
"type": "mrope",
@@ -173,7 +172,8 @@ class VLAConfig:
"spatial_merge_size": 2,
"temporal_patch_size": 2,
"attn_implementation": "eager",
})
}
)

def __post_init__(self):
"""Input validation (not exhaustive)."""
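
As a usage sketch for the dataclass above (assuming the defaults shown in these hunks pass the truncated `__post_init__` validation), instantiating `VLAConfig` with no arguments exposes the Qwen2-VL-sized defaults:

from lerobot.common.policies.vla.configuration_vla import VLAConfig

cfg = VLAConfig()                # all defaults, per the dataclass fields above
print(cfg.vocab_size)            # 152064
print(cfg.num_hidden_layers)     # 80
print(cfg.num_decoder_layers)    # 1
print(cfg.rope_scaling["type"])  # "mrope"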
67 changes: 46 additions & 21 deletions lerobot/common/policies/vla/modeling_language.py
@@ -1,22 +1,21 @@
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from lerobot.common.policies.vla.configuration_vla import VLAConfig
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, StaticCache

from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.utils import logging

from lerobot.common.policies.vla.configuration_vla import VLAConfig

logger = logging.get_logger(__name__)


@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
"""
@@ -55,6 +54,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None


class Qwen2VLRotaryEmbedding(nn.Module):
def __init__(
self,
@@ -111,10 +111,14 @@ def _dynamic_frequency_update(self, position_ids, device):
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
if (
seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len
): # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@@ -142,12 +146,14 @@ def forward(self, x, position_ids):

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
@@ -192,6 +198,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
@@ -229,7 +236,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -245,6 +254,7 @@

return causal_mask


# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@@ -293,6 +303,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Qwen2VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -346,7 +357,9 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -372,31 +385,39 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
attn_weights = torch.where(
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
)

# upcast attention to fp32
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
@@ -417,12 +438,14 @@ def forward(

return attn_output, attn_weights, past_key_value


QWEN2_VL_ATTENTION_CLASSES = {
"eager": Qwen2VLAttention,
}


class Qwen2VLDecoderLayer(nn.Module):
# TODO(rcadene, dana): update config type VLAConfig
# TODO(rcadene, dana): update config type VLAConfig
def __init__(self, config: VLAConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@@ -448,7 +471,9 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -504,4 +529,4 @@ def forward(
if use_cache:
outputs += (present_key_value,)

return outputs
return outputs
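
For intuition, a minimal self-contained sketch of the rotary formula that `rotate_half` and `apply_multimodal_rotary_pos_emb` implement, using plain 1-D positions instead of the multimodal temporal/height/width sections (the head size, sequence length, and 10000.0 base below are illustrative assumptions):

import torch

def rotate_half(x):
    # Same helper as in the diff: swap the two halves of the last dim, negating the second.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 8, 4
q = torch.randn(1, 1, seq_len, head_dim)                      # (batch, heads, seq, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq, head_dim)
cos, sin = emb.cos(), emb.sin()

q_embed = (q * cos) + (rotate_half(q) * sin)                  # same formula as in the attention forward
print(q_embed.shape)                                          # torch.Size([1, 1, 4, 8])

In the multimodal variant, `cos`/`sin` are first re-assembled from the per-axis sections given by `mrope_section` before this same formula is applied.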

