adding positional embeddings to the action decoder
danaaubakirova committed Oct 10, 2024
1 parent d0129a7 commit 903bd23
Showing 2 changed files with 50 additions and 24 deletions.
70 changes: 48 additions & 22 deletions lerobot/common/policies/vla/modeling_vla.py
@@ -12,14 +12,7 @@
from transformers.modeling_utils import PreTrainedModel

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.vla.configuration_qwen2_vl import Qwen2VLConfig
from lerobot.common.policies.vla.configuration_vla import VLAConfig
from lerobot.common.policies.vla.modeling_language import (
Qwen2RMSNorm,
Qwen2VLDecoderLayer,
Qwen2VLRotaryEmbedding,
)
from lerobot.common.policies.vla.modeling_vision import Qwen2VisionTransformerPretrainedModel
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

class VLAPolicy(
@@ -61,13 +54,12 @@ def __init__(
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)



self.language_model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, device_map="cuda")
self.device = self.language_model.device
self.model = VLA(config).to(self.device)
self.processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") # replaces the custom Qwen2VL; only hidden states are used, so no LM loss or lm_head is needed here


self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

@@ -92,7 +84,9 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
processed_inputs = self.processor(
text=batch["prompt"], videos=list(batch["observation.images"]), return_tensors="pt", padding=True, do_rescale=False
).to(self.device)

processed_inputs["pixel_values_videos"] = processed_inputs["pixel_values_videos"].to(self.device).to(torch.float16)
# Forward pass through Llava (to get hidden states)
llava_output = self.language_model( # Calling the Llava model inside VLA
**processed_inputs,
Expand All @@ -101,6 +95,7 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
)

hidden_states = llava_output.hidden_states[-1] # Use last layer's hidden state
hidden_states = hidden_states[:, -4:, :] # TODO: make 4 a config parameter

# Pass the hidden states to the VLA model for action decoding
predicted_actions = self.model(hidden_states)
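The slice above keeps only the last four token embeddings of the final Llava layer as the action query, and the TODO asks for that 4 to become a config parameter. A minimal sketch of what that could look like, with illustrative tensor sizes and a hypothetical n_action_query_tokens field that is not part of this commit:

import torch

n_query = 4  # hypothetical config field, e.g. config.n_action_query_tokens, instead of a hard-coded 4
last_hidden = torch.randn(1, 512, 3584)        # stand-in for llava_output.hidden_states[-1]: [batch, seq_len, hidden]
action_queries = last_hidden[:, -n_query:, :]  # keep only the last n_query token embeddings
assert action_queries.shape == (1, n_query, 3584)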
@@ -112,6 +107,7 @@ def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
return self._action_queue.popleft()

def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

#batch = self.normalize_inputs(batch)

if len(self.expected_image_keys) > 0:
@@ -132,6 +128,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
output_hidden_states=True
)
hidden_states = llava_output.hidden_states[-1]
hidden_states = hidden_states[:, -4:, :]
#hidden_states.to(dtype=torch.float16).to(self.device)
# Forward pass through VLA
@@ -141,7 +138,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if "action" in batch:
true_actions = batch["action"]
l2_loss = F.mse_loss(predicted_actions, true_actions.view(predicted_actions.shape), reduction="mean")
l2_loss = F.mse_loss(predicted_actions, true_actions, reduction="mean")
loss_dict["l2_loss"] = l2_loss.item()
loss_dict["loss"] = l2_loss

@@ -195,6 +192,7 @@ def forward(
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)

# Self-attention
x = self.self_attn(q, k, value=x)[0] # select just the output, not attention weights
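maybe_add_pos_embed itself is not shown in this diff. In DETR/ACT-style decoders the usual convention, which the call above appears to follow, is to add the positional embedding to the queries and keys only and leave the values untouched; a hedged sketch of that helper:

import torch

def maybe_add_pos_embed(x: torch.Tensor, pos_embed: torch.Tensor | None) -> torch.Tensor:
    # Positional information goes into queries/keys; the attention values stay position-free.
    return x if pos_embed is None else x + pos_embed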
@@ -261,22 +259,50 @@ def __init__(self, config: VLAConfig):

# Initialize the Qwen2VLForConditionalGeneration and ActionDecoder
#qwen2_vl_config = make_qwen2_vl_config(config)

self.chunk_size = config.chunk_size
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.hidden_size)
self.action_decoder = ActionDecoder(config) # Use the updated ActionDecoder
self.action_head = nn.Linear(config.hidden_size, config.output_shapes["action"][0])


self.half()

def forward(self, hidden_states):
# Get the hidden states from the Qwen2VLForConditionalGeneration model
# Use the last hidden state for action decoding
#action_embedding = hidden_states[:, -1:, :] # Assuming last token is the action token
#
encoder_out = hidden_states
breakpoint()
# Decode the action
action_logits = self.action_decoder(x=hidden_states, encoder_out=encoder_out)
"""
Forward pass to compute action logits using hidden states from Qwen2VL (Llava).
Args:
hidden_states: Tensor of shape [batch_size, seq_len, hidden_size] from Llava model.
Returns:
action_logits: Tensor of predicted actions.
"""
batch_size, seq_len, hidden_size = hidden_states.shape # [batch_size, seq_len, hidden_size]

# Ensure encoder_out has the correct shape [chunk_size, batch_size, seq_len, hidden_size]
# Repeat the encoder output for chunk size across the batch dimension
#encoder_out = hidden_states.unsqueeze(0).repeat(self.chunk_size, 1, 1, 1) # [chunk_size, batch_size, seq_len, hidden_size]
#encoder_out = encoder_out.view(self.chunk_size * seq_len, batch_size, hidden_size)

# Repeat the decoder input (hidden states) as well, maintaining batch and hidden size
repeated_hidden_states = hidden_states.unsqueeze(0).repeat(self.chunk_size // seq_len, 1, 1, 1) # [chunk_size // seq_len, batch_size, seq_len, hidden_size]; assumes chunk_size % seq_len == 0
# Permute before flattening so every chunk step stays aligned with its own batch element
repeated_hidden_states = repeated_hidden_states.permute(0, 2, 1, 3).reshape(self.chunk_size, batch_size, hidden_size)

# Generate positional embeddings for the decoder
decoder_pos_embeddings = self.decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)

# Decode the action with positional embeddings and encoder output
action_logits = self.action_decoder(
x=repeated_hidden_states,
encoder_out=repeated_hidden_states,
decoder_pos_embed=decoder_pos_embeddings
)

# Final action logits through the action head
action_logits = self.action_head(action_logits.squeeze())
# Final action logits through the action head
action_logits = self.action_head(action_logits)

return action_logits
action_logits = action_logits.transpose(0, 1)
return action_logits
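End to end, the new forward turns a short sequence of Llava hidden states into a fixed-length action chunk: repeat the features to chunk_size steps, add one learned positional embedding per step, decode, and project to action space. The sketch below walks through those shapes with illustrative sizes, using a plain nn.TransformerDecoder as a stand-in for the repository's ActionDecoder, so it shows the tensor plumbing rather than the actual module:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, chunk_size, action_dim = 2, 4, 512, 100, 14  # illustrative sizes

hidden_states = torch.randn(batch_size, seq_len, hidden_size)  # last-layer VLM features

# Repeat the seq_len features to cover chunk_size steps (assumes chunk_size % seq_len == 0);
# permuting before the reshape keeps each chunk step aligned with its own batch element.
repeated = hidden_states.unsqueeze(0).repeat(chunk_size // seq_len, 1, 1, 1)  # [chunk//seq, batch, seq, hidden]
repeated = repeated.permute(0, 2, 1, 3).reshape(chunk_size, batch_size, hidden_size)

# One learned positional embedding per chunk step, broadcast over the batch.
decoder_pos_embed = nn.Embedding(chunk_size, hidden_size)
pos = decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)  # [chunk, batch, hidden]

# Stand-in decoder and head; the real module is ActionDecoder(config), which adds `pos` inside its layers.
decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
action_head = nn.Linear(hidden_size, action_dim)

decoded = decoder(repeated + pos, memory=repeated)             # [chunk, batch, hidden]
actions = action_head(decoded).transpose(0, 1)                 # [batch, chunk, action_dim]
assert actions.shape == (batch_size, chunk_size, action_dim)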
4 changes: 2 additions & 2 deletions lerobot/configs/policy/vla.yaml
@@ -36,8 +36,8 @@ policy:

# Input / output structure.
n_obs_steps: 1
chunk_size: 4
n_action_steps: 4
chunk_size: 100
n_action_steps: 100

input_shapes:
observation.images.top: [3, 480, 640] # Video inputs (from video frames)
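With chunk_size and n_action_steps both raised from 4 to 100, each policy call now predicts a 100-step chunk and select_action drains it one step per environment step through the action queue. A minimal sketch of that deque-based pattern (the queue setup itself is not part of this diff; the action dimension is illustrative):

from collections import deque

import torch

chunk_size, n_action_steps, action_dim = 100, 100, 14  # chunk_size / n_action_steps from vla.yaml
action_queue: deque = deque([], maxlen=n_action_steps)

if len(action_queue) == 0:
    predicted_actions = torch.randn(1, chunk_size, action_dim)  # one chunk per expensive VLM + decoder call
    action_queue.extend(predicted_actions.transpose(0, 1))      # enqueue chunk steps; each entry is [1, action_dim]

next_action = action_queue.popleft()                            # served without re-running the model
assert next_action.shape == (1, action_dim)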
