From d0129a7354d155b7ea395608237f8518c39d9b51 Mon Sep 17 00:00:00 2001
From: danaaubakirova
Date: Wed, 9 Oct 2024 16:50:56 +0200
Subject: [PATCH] Remove Qwen2-VL from modeling_vla.py, add LLaVA-OneVision

---
 .../common/policies/vla/configuration_vla.py |    4 +-
 lerobot/common/policies/vla/modeling_vla.py  | 1032 ++---------------
 lerobot/configs/policy/vla.yaml              |   10 +-
 3 files changed, 76 insertions(+), 970 deletions(-)

diff --git a/lerobot/common/policies/vla/configuration_vla.py b/lerobot/common/policies/vla/configuration_vla.py
index 68e0c3161..733eeb7c7 100644
--- a/lerobot/common/policies/vla/configuration_vla.py
+++ b/lerobot/common/policies/vla/configuration_vla.py
@@ -115,11 +115,13 @@ class VLAConfig:
         }
     )
 
+    prompt: str = "Please insert the tube into the socket."
+
     # Architecture.
     # Language + Main transformer
     vocab_size: int = 152064
-    hidden_size: int = 8192
+    hidden_size: int = 3584
     intermediate_size: int = 29568
     num_hidden_layers: int = 80
     num_decoder_layers: int = 1
diff --git a/lerobot/common/policies/vla/modeling_vla.py b/lerobot/common/policies/vla/modeling_vla.py
index 5bf06e14a..ded833f0c 100644
--- a/lerobot/common/policies/vla/modeling_vla.py
+++ b/lerobot/common/policies/vla/modeling_vla.py
@@ -20,7 +20,7 @@
     Qwen2VLRotaryEmbedding,
 )
 from lerobot.common.policies.vla.modeling_vision import Qwen2VisionTransformerPretrainedModel
-
+from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
 
 class VLAPolicy(
     nn.Module,
@@ -62,7 +62,12 @@ def __init__(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
-        self.model = VLA(config)
+        # LLaVA-OneVision produces the vision-language hidden states; VLA decodes them into actions.
+        self.language_model = LlavaOnevisionForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, device_map="cuda"
+        )
+        self.device = self.language_model.device
+        self.model = VLA(config).to(self.device)
+        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
 
         self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
 
@@ -74,384 +79,74 @@ def reset(self):
 
     @torch.no_grad()
     def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
-        """Select a single action given environment observations."""
         self.eval()
-
-        batch = self.normalize_inputs(batch)
+        # NOTE: input normalization is currently disabled; the LLaVA-OneVision processor
+        # handles image preprocessing instead.
+        # batch = self.normalize_inputs(batch)
+
         if len(self.expected_image_keys) > 0:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
-
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            # actions = self.model(batch)[0][:, : self.config.n_action_steps]
-            predicted_actions = self.model(
-                batch, input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
-            )
-            breakpoint()
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = torch.stack(
+                [batch[k] for k in self.expected_image_keys], dim=-4
+            ).to(self.device)
+
+        # The task prompt is a plain string; tokenization is handled by the processor below.
+        batch["prompt"] = self.config.prompt
+
+        # Process inputs (prompt text and stacked camera frames) with the LLaVA-OneVision processor.
+        processed_inputs = self.processor(
+            text=batch["prompt"],
+            videos=list(batch["observation.images"]),
+            return_tensors="pt",
+            padding=True,
+            do_rescale=False,
+        ).to(self.device)
+        processed_inputs["pixel_values_videos"] = processed_inputs["pixel_values_videos"].to(torch.float16)
+
+        # Forward pass through LLaVA-OneVision to get the vision-language hidden states.
+        llava_output = self.language_model(
+            **processed_inputs,
+            return_dict=True,
+            output_hidden_states=True,
+        )
+        hidden_states = llava_output.hidden_states[-1]  # last layer's hidden states
 
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        # Decode the hidden states into an action chunk.
+        predicted_actions = self.model(hidden_states)
 
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+        # When the action queue is depleted, refill it with the newly predicted chunk.
+        if len(self._action_queue) == 0:
+            actions = self.unnormalize_outputs({"action": predicted_actions})["action"]
             self._action_queue.extend(actions.transpose(0, 1))
+
         return self._action_queue.popleft()
 
     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """
-        Forward pass through the model.
-
-        Args:
-            batch (dict): Dictionary containing the following keys:
-                - "input_ids": Tensor of tokenized inputs (input to language model).
-                - "attention_mask": Tensor mask for the input tokens.
-                - "observation.state": Tensor containing the robot's state.
-                - "action": Tensor containing the ground-truth actions (optional for training).
-
-        Returns:
-            dict: A dictionary containing the loss and predicted actions.
- """ - # Extract inputs for the model - input_ids = batch.get("input_ids") - attention_mask = batch.get("attention_mask") - - # Forward pass through the VLA model - predicted_actions = self.model(batch, input_ids=input_ids, attention_mask=attention_mask) + #batch = self.normalize_inputs(batch) + + if len(self.expected_image_keys) > 0: + batch = dict(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4).to(self.device) + + batch["prompt"] = self.config.prompt + + processed_inputs = self.processor( + text=batch["prompt"], videos=list(batch["observation.images"]), return_tensors="pt", padding=True, do_rescale=False + ).to(self.device) + processed_inputs["pixel_values_videos"] = processed_inputs["pixel_values_videos"].to(self.device).to(torch.float16) + + # Pass inputs through Llava and VLA + llava_output = self.language_model( + **processed_inputs, + return_dict=True, + output_hidden_states=True + ) + hidden_states = llava_output.hidden_states[-1] + #hidden_states.to(dtype=torch.float16).to(self.device) + breakpoint() + # Forward pass through VLA + predicted_actions = self.model(hidden_states) loss_dict = {} - - # If ground-truth actions are available, compute L2 loss for training if "action" in batch: - true_actions = batch["action"] # Ground-truth actions - l2_loss = F.mse_loss(predicted_actions, true_actions, reduction="mean") # L2 loss + true_actions = batch["action"] + breakpoint() + l2_loss = F.mse_loss(predicted_actions, true_actions.view(predicted_actions.shape), reduction="mean") loss_dict["l2_loss"] = l2_loss.item() loss_dict["loss"] = l2_loss return loss_dict - -class Qwen2VLPreTrainedModel(PreTrainedModel): - config_class = VLAConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class Qwen2VLCausalLMOutputWithPast(ModelOutput): - """ - Base class for Qwen2VL causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. 
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - - -class Qwen2VLModel(Qwen2VLPreTrainedModel): - def __init__(self, config: VLAConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # the hard coded `3` is for temporal, height and width. - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.dim() == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - # Causal mask logic - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # Causal mask generation logic - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - 
sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
-
 class ActionDecoderLayer(nn.Module):
     def __init__(self, config: VLAConfig):
         super().__init__()
@@ -559,620 +254,29 @@ def forward(
         return x
 
 
-def make_qwen2_vl_config(config):
-    expected_keys = set(inspect.signature(Qwen2VLConfig).parameters)
-
-    keys_from_pretrained_config_kwargs = [
-        "name_or_path",
-        "output_hidden_states",
-        "output_attentions",
-        "return_dict",
-        "is_encoder_decoder",
-        "is_decoder",
-        "cross_attention_hidden_size",
-        "add_cross_attention",
-        "tie_encoder_decoder",
-        "prune_heads",
-        "chunk_size_feed_forward",
-        # > Parameters for fine-tuning tasks
-        "architectures",
-        "finetuning_task",
-        "id2label",
-        "label2id",
-        "num_labels",
-        "task_specific_params",
-        "problem_type",
-        # > Parameters linked to the tokenizer
-        "tokenizer_class",
-        "prefix",
-        "bos_token_id",
-        "pad_token_id",
-        "eos_token_id",
-        "decoder_start_token_id",
-        "sep_token_id",
-        # > PyTorch specific parameters
-        "torchscript",
-        "tie_word_embeddings",
-        "torch_dtype",
-    ]
-    expected_keys = list(expected_keys) + keys_from_pretrained_config_kwargs
-
-    qwen2_vl_config_kwargs = {}
-    for key in expected_keys:
-        if key == "kwargs":
-            continue
-        if hasattr(config, key):
-            qwen2_vl_config_kwargs[key] = getattr(config, key)
-
-    return Qwen2VLConfig(**qwen2_vl_config_kwargs)
-
 
 class VLA(nn.Module):
     def __init__(self, config: VLAConfig):
         super().__init__()
-        # Initialize the Qwen2VLForConditionalGeneration and ActionDecoder
-        qwen2_vl_config = make_qwen2_vl_config(config)
-        self.model = Qwen2VLForConditionalGeneration(
-            qwen2_vl_config
-        )  # Updated Qwen2VL without loss and lm_head
+        # Only the action decoder and action head remain; the vision-language hidden states
+        # are now produced by LLaVA-OneVision inside VLAPolicy.
         self.action_decoder = ActionDecoder(config)  # Use the updated ActionDecoder
         self.action_head = nn.Linear(config.hidden_size, config.output_shapes["action"][0])
+        self.half()
 
-    def forward(self, batch, input_ids=None, attention_mask=None):
-        # Get the hidden states from the Qwen2VLForConditionalGeneration model
-        model_output = self.model(
-            batch=batch, input_ids=input_ids, attention_mask=attention_mask, return_dict=True
-        )
-
-        hidden_states = model_output.hidden_states
-        # Use the last hidden state for action decoding
-        action_embedding = hidden_states[:, -1:, :]  # Assuming last token is the action token
-        encoder_out = action_embedding
-
+    def forward(self, hidden_states):
+        # `hidden_states`: last-layer hidden states of the vision-language backbone (LLaVA-OneVision).
+        encoder_out = hidden_states
+
         # Decode the action
-        action_logits = self.action_decoder(x=action_embedding, encoder_out=encoder_out)
+        action_logits = self.action_decoder(x=hidden_states, encoder_out=encoder_out)
 
         # Final action logits through the action head
action_logits = self.action_head(action_logits.squeeze()) return action_logits - - -class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): - # _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.visual = Qwen2VisionTransformerPretrainedModel._from_config( - config.vision_config, - attn_implementation="eager", - ) - self.model = Qwen2VLModel(config) - self.vocab_size = config.vocab_size - # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides - if "observation.state" in config.input_shapes: - self.robot_state_embed = nn.Linear( - config.input_shapes["observation.state"][0], config.hidden_size - ) - self.action_embed = nn.Linear(config.output_shapes["action"][0], config.hidden_size) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def get_rope_index( - self, - input_ids: torch.LongTensor, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Modified to handle robot state and action tokens. - """ - - spatial_merge_size = self.config.vision_config.spatial_merge_size - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id - robot_state_token_id = self.config.robot_state_token_id # Assuming a token ID is set for robot state - action_token_id = self.config.action_token_id # Assuming a token ID is set for action - - mrope_position_deltas = [] - - # Check if vision features are included (images/videos) - if image_grid_thw is not None or video_grid_thw is not None: - total_input_ids = input_ids - position_ids = torch.ones( - 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device - ) - image_index, video_index = 0, 0 - for i, input_ids in enumerate(total_input_ids): - if attention_mask is not None: - input_ids = input_ids[attention_mask[i] == 1] - - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, 
h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - ) - h_index = ( - torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - ) - w_index = ( - torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - ) - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - # Handle robot state and action tokens after vision and text tokens - if robot_state_token_id in input_tokens: - robot_state_idx = input_tokens.index(robot_state_token_id, st) - robot_state_len = 1 # Single token length - st_idx = llm_pos_ids_list[-1].max() + 1 - llm_pos_ids_list.append(torch.arange(robot_state_len).view(1, -1).expand(3, -1) + st_idx) - st = robot_state_idx + robot_state_len - - if action_token_id in input_tokens: - action_idx = input_tokens.index(action_token_id, st) - action_len = 1 # Single token length - st_idx = llm_pos_ids_list[-1].max() + 1 - llm_pos_ids_list.append(torch.arange(action_len).view(1, -1).expand(3, -1) + st_idx) - st = action_idx + action_len - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) - - return position_ids, mrope_position_deltas - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - num_new_tokens=num_new_tokens, - ) - - if getattr(outputs, "rope_deltas", None) is not None: - model_kwargs["rope_deltas"] = outputs.rope_deltas - - return model_kwargs - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - robot_state: Optional[torch.Tensor] = None, - action: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - - # Step 1: Process vision inputs (images/videos) if provided - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - # Step 2: Process robot state and action embeddings if provided - if robot_state is not None: - robot_state_embeds = self.robot_state_embed(robot_state).unsqueeze(1) - inputs_embeds = torch.cat([inputs_embeds, robot_state_embeds], dim=1) - - if action is not None: - action_embeds = self.action_embed(action).unsqueeze(1) - inputs_embeds = torch.cat([inputs_embeds, action_embeds], dim=1) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # Step 3: Forward pass through the model - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if not return_dict: - return (hidden_states,) - - return Qwen2VLCausalLMOutputWithPast( - loss=None, - logits=None, # No logits since there's no lm_head - past_key_values=outputs.past_key_values, - hidden_states=hidden_states, - attentions=outputs.attentions, - rope_deltas=None, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - pixel_values_videos=None, - image_grid_thw=None, - 
video_grid_thw=None, - robot_state=None, - action=None, - **kwargs, - ): - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if inputs_embeds is not None: - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - - rope_deltas = kwargs.get("rope_deltas", None) - if attention_mask is not None and position_ids is None: - if cache_position is None or (cache_position is not None and cache_position[0] == 0): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - else: - batch_size, seq_length = input_ids.shape - delta = ( - cache_position[0] + rope_deltas - if cache_position is not None and rope_deltas is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - if cache_position[0] != 0: - pixel_values = None - pixel_values_videos = None - - # Handle embeddings passed only for the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - model_inputs = {"input_ids": input_ids, "inputs_embeds": None} - - # Update attention mask for robot state and action tokens - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, - "rope_deltas": rope_deltas, - "robot_state": robot_state, - "action": action, - } - ) - - return model_inputs - - -""" -class Qwen2VLModel(Qwen2VLPreTrainedModel): - def __init__(self, config: VLAConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Initialize flags for using robot state and environment state - self.use_robot_state = "observation.state" in config.input_shapes - self.use_images = any(k.startswith("observation.image") for k in config.input_shapes) - self.use_env_state = "observation.environment_state" in config.input_shapes - - # Embedding layers for robot observation state and action - if self.use_robot_state: - self.robot_state_embed = nn.Linear( - config.input_shapes["observation.state"][0], config.hidden_size - ) - # Embedding layer for robot action - self.action_embed = nn.Linear( - config.output_shapes["action"][0], config.hidden_size - ) - - # Token embedding for text - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - - # Layers for processing - self.layers = 
nn.ModuleList( - [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - batch: dict, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - - # Step 1: Embed the text input if provided - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # Step 2: Get robot state and action embeddings if applicable - robot_state_embedding = None - if self.use_robot_state: - robot_state_embedding = self.robot_state_embed(batch["observation.state"]).unsqueeze(1) - - action_embedding = self.action_embed(batch["action"]).unsqueeze(1) - - # Step 3: Combine text, robot state, and action embeddings - if robot_state_embedding is not None: - combined_embedding = torch.cat([inputs_embeds, robot_state_embedding, action_embedding], dim=1) - else: - combined_embedding = torch.cat([inputs_embeds, action_embedding], dim=1) - - # Step 4: Proceed with the rest of the forward pass - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + combined_embedding.shape[1], device=combined_embedding.device - ) - - # The hard coded `3` is for temporal, height, and width. 
- if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, combined_embedding.shape[0], -1) - elif position_ids.dim() == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - causal_mask = self._update_causal_mask( - attention_mask, combined_embedding, cache_position, past_key_values, output_attentions - ) - - hidden_states = combined_embedding - - # Create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # Decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # Add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) -""" diff --git a/lerobot/configs/policy/vla.yaml b/lerobot/configs/policy/vla.yaml index 59bbd2779..3c74dc39b 100644 --- a/lerobot/configs/policy/vla.yaml +++ b/lerobot/configs/policy/vla.yaml @@ -2,7 +2,6 @@ seed: 1000 dataset_repo_id: lerobot/aloha_sim_insertion_human -prompt: "Please insert the tube into the socket." override_dataset_stats: observation.images.top: @@ -37,15 +36,15 @@ policy: # Input / output structure. n_obs_steps: 1 - chunk_size: 100 - n_action_steps: 100 - + chunk_size: 4 + n_action_steps: 4 + input_shapes: observation.images.top: [3, 480, 640] # Video inputs (from video frames) observation.state: [128] # State input dimension text.input: [256] # Text input processed by Qwen-VL output_shapes: - action: [4] # Action output dimension (Example: 4D actions) + action: [14] # Action output dimension (Example: 4D actions) # Normalization / Unnormalization input_normalization_modes: @@ -60,6 +59,7 @@ policy: state_encoder: hidden_dim: 256 latent_dim: 64 + prompt: "Please insert the tube into the socket." # Vision-Language Model (Qwen-VL) vlm_backbone: qwen_vl
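
A minimal, untested usage sketch (not part of the patch itself) of the inference path this diff introduces. The checkpoint name, prompt, chunk size, 480x640 top-camera input, and the 14-dim action output come from the diff above; the VLAPolicy constructor call, the dataset_stats handling, and the dummy batch are illustrative assumptions.

# Usage sketch for the LLaVA-OneVision based VLAPolicy (assumptions marked inline).
import torch

from lerobot.common.policies.vla.configuration_vla import VLAConfig
from lerobot.common.policies.vla.modeling_vla import VLAPolicy

config = VLAConfig()  # prompt, hidden_size=3584, output_shapes["action"] = [14] per this patch
# Assumption: VLAPolicy accepts (config, dataset_stats) like other lerobot policies; real
# dataset statistics are needed in practice because select_action calls unnormalize_outputs.
policy = VLAPolicy(config, dataset_stats=None)
policy.reset()

# One observation matching lerobot/configs/policy/vla.yaml: top camera frame (B, C, H, W) and state.
batch = {
    "observation.images.top": torch.rand(1, 3, 480, 640),
    "observation.state": torch.rand(1, 128),
}

# select_action() stacks the camera frames, runs the LLaVA-OneVision processor and backbone,
# decodes the last-layer hidden states with the ActionDecoder and action head, and serves
# actions one step at a time from an internal queue refilled every n_action_steps (4 here).
action = policy.select_action(batch)
print(action.shape)  # expected: a (1, 14) action tensor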