diff --git a/lerobot/common/policies/vla/modeling_vla.py b/lerobot/common/policies/vla/modeling_vla.py
index ded833f0c..3a964e766 100644
--- a/lerobot/common/policies/vla/modeling_vla.py
+++ b/lerobot/common/policies/vla/modeling_vla.py
@@ -12,14 +12,7 @@
 from transformers.modeling_utils import PreTrainedModel
 from lerobot.common.policies.normalize import Normalize, Unnormalize
-from lerobot.common.policies.vla.configuration_qwen2_vl import Qwen2VLConfig
 from lerobot.common.policies.vla.configuration_vla import VLAConfig
-from lerobot.common.policies.vla.modeling_language import (
-    Qwen2RMSNorm,
-    Qwen2VLDecoderLayer,
-    Qwen2VLRotaryEmbedding,
-)
-from lerobot.common.policies.vla.modeling_vision import Qwen2VisionTransformerPretrainedModel
 from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
 
 
 class VLAPolicy(
@@ -61,13 +54,12 @@ def __init__(
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
-
-        self.language_model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, device_map = 'cuda')
         self.device= self.language_model.device
         self.model = VLA(config).to(self.device)
         self.processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")# Updated Qwen2VL without loss and lm_head
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
@@ -92,7 +84,9 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         processed_inputs = self.processor(
             text=batch["prompt"], videos=list(batch["observation.images"]), return_tensors="pt", padding=True, do_rescale=False
         ).to(self.device)
+        processed_inputs["pixel_values_videos"] = processed_inputs["pixel_values_videos"].to(self.device).to(torch.float16)
+        breakpoint()
 
         # Forward pass through Llava (to get hidden states)
         llava_output = self.language_model( # Calling the Llava model inside VLA
             **processed_inputs,
@@ -101,6 +95,7 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         )
         hidden_states = llava_output.hidden_states[-1] # Use last layer's hidden state
+        hidden_states = hidden_states[:, -4:, :] #make 4 a config parameter
 
         # Pass the hidden states to the VLA model for action decoding
         predicted_actions = self.model(hidden_states)
@@ -112,6 +107,7 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         return self._action_queue.popleft()
 
     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        #batch = self.normalize_inputs(batch)
         if len(self.expected_image_keys) > 0:
@@ -132,6 +128,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
             output_hidden_states=True
         )
         hidden_states = llava_output.hidden_states[-1]
+        hidden_states = hidden_states[:, -4:, :] #hidden_states.to(dtype=torch.float16).to(self.device)
         breakpoint()
         # Forward pass through VLA
@@ -141,7 +138,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         if "action" in batch:
             true_actions = batch["action"]
             breakpoint()
-            l2_loss = F.mse_loss(predicted_actions, true_actions.view(predicted_actions.shape), reduction="mean")
+            l2_loss = F.mse_loss(predicted_actions, true_actions, reduction="mean")
             loss_dict["l2_loss"] = l2_loss.item()
             loss_dict["loss"] = l2_loss
@@ -195,6 +192,7 @@ def forward(
         if self.pre_norm:
             x = self.norm1(x)
         q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
+        breakpoint()
         # Self-attention
         x = self.self_attn(q, k, value=x)[0] # select just the output, not attention weights
@@ -261,22 +259,50 @@ def __init__(self, config: VLAConfig):
         # Initialize the Qwen2VLForConditionalGeneration and ActionDecoder
         #qwen2_vl_config = make_qwen2_vl_config(config)
-
+        self.chunk_size = config.chunk_size
+        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.hidden_size)
         self.action_decoder = ActionDecoder(config) # Use the updated ActionDecoder
         self.action_head = nn.Linear(config.hidden_size, config.output_shapes["action"][0])
+
+        self.half()
 
     def forward(self, hidden_states):
-        # Get the hidden states from the Qwen2VLForConditionalGeneration model
-        # Use the last hidden state for action decoding
-        #action_embedding = hidden_states[:, -1:, :] # Assuming last token is the action token
-        #
-        encoder_out = hidden_states
-        breakpoint()
-        # Decode the action
-        action_logits = self.action_decoder(x=hidden_states, encoder_out=encoder_out)
+        """
+        Forward pass to compute action logits using hidden states from Qwen2VL (Llava).
+
+        Args:
+            hidden_states: Tensor of shape [batch_size, seq_len, hidden_size] from Llava model.
+
+        Returns:
+            action_logits: Tensor of predicted actions.
+        """
+        batch_size = hidden_states.size(0)  # Ensure batch size is extracted
+        seq_len = hidden_states.size(1)  # Sequence length of hidden states
+        hidden_size = hidden_states.size(2)  # Hidden size
+
+        # Ensure encoder_out has the correct shape [chunk_size, batch_size, seq_len, hidden_size]
+        # Repeat the encoder output for chunk size across the batch dimension
+        #encoder_out = hidden_states.unsqueeze(0).repeat(self.chunk_size, 1, 1, 1) # [chunk_size, batch_size, seq_len, hidden_size]
+        #encoder_out = encoder_out.view(self.chunk_size * seq_len, batch_size, hidden_size)
+
+        # Repeat the decoder input (hidden states) as well, maintaining batch and hidden size
+        repeated_hidden_states = hidden_states.unsqueeze(0).repeat(self.chunk_size // seq_len, 1, 1, 1) # [chunk_size, batch_size, seq_len, hidden_size]
+        breakpoint()
+        repeated_hidden_states = repeated_hidden_states.view(self.chunk_size, batch_size, hidden_size)
+
+        # Generate positional embeddings for the decoder
+        decoder_pos_embeddings = self.decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
+
+        # Decode the action with positional embeddings and encoder output
+        action_logits = self.action_decoder(
+            x=repeated_hidden_states,
+            encoder_out=repeated_hidden_states,
+            decoder_pos_embed=decoder_pos_embeddings
+        )
 
-        # Final action logits through the action head
-        action_logits = self.action_head(action_logits.squeeze())
+        # Final action logits through the action head
+        action_logits = self.action_head(action_logits)
 
-        return action_logits
+        action_logits = action_logits.transpose(0, 1)
+        return action_logits
diff --git a/lerobot/configs/policy/vla.yaml b/lerobot/configs/policy/vla.yaml
index 3c74dc39b..4bc4183dd 100644
--- a/lerobot/configs/policy/vla.yaml
+++ b/lerobot/configs/policy/vla.yaml
@@ -36,8 +36,8 @@ policy:
   # Input / output structure.
   n_obs_steps: 1
-  chunk_size: 4
-  n_action_steps: 4
+  chunk_size: 100
+  n_action_steps: 100
 
   input_shapes:
     observation.images.top: [3, 480, 640] # Video inputs (from video frames)
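Below is a minimal, self-contained sketch of the shape flow that the new `VLA.forward` aims for, useful for sanity-checking the reshaping. It is not a line-for-line extract of the patch: the sizes (`hidden_size=3584`, `action_dim=14`) and the single `nn.Linear` standing in for `ActionDecoder` plus `action_head` are illustrative assumptions. It also assumes `chunk_size` is an integer multiple of the number of retained tokens (4 after `hidden_states[:, -4:, :]`), which `chunk_size: 100` satisfies with a repeat factor of 25.

```python
import torch
import torch.nn as nn

# Placeholder sizes: hidden_size and action_dim are assumptions; chunk_size matches vla.yaml.
batch_size, seq_len, hidden_size, chunk_size, action_dim = 2, 4, 3584, 100, 14
assert chunk_size % seq_len == 0, "retained token count must divide chunk_size"

# Last-layer hidden states for the 4 retained tokens, as produced by the Llava backbone.
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Tile the retained tokens along the time axis until there is one vector per action step,
# then move to the [chunk_size, batch, hidden] layout a time-first transformer decoder expects.
decoder_in = hidden_states.repeat(1, chunk_size // seq_len, 1)   # [batch, chunk_size, hidden]
decoder_in = decoder_in.transpose(0, 1)                          # [chunk_size, batch, hidden]

# One learned positional embedding per action step, broadcast over the batch
# (same construction as self.decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)).
decoder_pos_embed = nn.Embedding(chunk_size, hidden_size)
pos = decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)  # [chunk_size, batch, hidden]

# Stand-in for ActionDecoder + action_head: any per-step map from hidden_size to action_dim.
action_head = nn.Linear(hidden_size, action_dim)
actions = action_head(decoder_in + pos)   # [chunk_size, batch, action_dim]
actions = actions.transpose(0, 1)         # [batch, chunk_size, action_dim]
print(actions.shape)                      # torch.Size([2, 100, 14])
```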
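With `chunk_size` and `n_action_steps` both raised to 100 in `vla.yaml`, the transposed output is `[batch, chunk_size, action_dim]`. Assuming the dataloader supplies `batch["action"]` in the same `[batch, n_action_steps, action_dim]` layout, predicted and true actions already share a shape, which is what allows the simplified `F.mse_loss(predicted_actions, true_actions, reduction="mean")` call without the earlier `.view(predicted_actions.shape)`.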