Skip to content

Commit

Permalink
[Bugfix] Fix missing post_layernorm in CLIP (vllm-project#8155)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Sep 10, 2024
1 parent 99ae93b commit d5ca2b8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
29 changes: 25 additions & 4 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,19 @@ def __init__(self,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)

if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

def forward(
self,
pixel_values: torch.Tensor,
Expand All @@ -364,7 +377,10 @@ def forward(
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)

return hidden_states
if self.post_layernorm is None:
return hidden_states

return self.post_layernorm(hidden_states)


class CLIPVisionModel(nn.Module):
Expand All @@ -386,9 +402,12 @@ def __init__(self,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)

def forward(self, pixel_values: Optional[torch.Tensor] = None):
@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None

return self.vision_model(pixel_values=pixel_values)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)

@property
def device(self):
Expand All @@ -408,8 +427,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
continue

# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
Expand Down
32 changes: 17 additions & 15 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,27 +443,26 @@ def __init__(
self.config = config
embed_dim = config.hidden_size

if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False

self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
if self.need_post_layernorm:

if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = nn.Identity()
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
Expand All @@ -482,6 +481,9 @@ def forward(

encoder_outputs = self.encoder(inputs_embeds=hidden_states)

if self.post_layernorm is None:
return encoder_outputs

last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference
# if self.use_head:
Expand Down Expand Up @@ -512,8 +514,8 @@ def __init__(
)

@property
def need_post_layernorm(self):
return self.vision_model.need_post_layernorm
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
Expand Down Expand Up @@ -541,7 +543,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm):
and not self._require_post_layernorm):
continue

# omit layers when num_hidden_layers_override is set
Expand Down

0 comments on commit d5ca2b8

Please sign in to comment.