diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
index d38de8eb10b9..2c3b30f2fc74 100644
--- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
+++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
@@ -424,26 +424,24 @@ def __init__(self, *args, **kwargs):
         # TODO (yuya): need to handle post_process correctly in order to enable PP
         self.output_dim = kwargs.pop('output_dim')
         super().__init__(*args, **kwargs)
-        if self.post_process:
-            self.final_layernorm = TENorm(
-                config=self.config,
-                hidden_size=self.config.hidden_size,
-                eps=self.config.layernorm_epsilon,
-            )
-            self.head = torch.nn.Linear(
-                self.config.hidden_size,
-                self.output_dim,
-                bias=False,
-            )
+        self.final_layernorm = TENorm(
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+        self.head = torch.nn.Linear(
+            self.config.hidden_size,
+            self.output_dim,
+            bias=False,
+        )
 
     def forward(self, x):
         x = super().forward(
             x,
         )
-        if self.post_process:
-            x = self.final_layernorm(x)
-            x = x[:, 0]
-            x = self.head(x)
+        x = self.final_layernorm(x)
+        x = x[:, 0]
+        x = self.head(x)
         return x
 
 
diff --git a/nemo/collections/vlm/neva/model/llava.py b/nemo/collections/vlm/neva/model/llava.py
index dc27f28373fa..da894f183bbf 100644
--- a/nemo/collections/vlm/neva/model/llava.py
+++ b/nemo/collections/vlm/neva/model/llava.py
@@ -111,7 +111,7 @@ def convert_state(self, source, target):
             "language_model.model.layers.*.post_attention_layernorm.weight": "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
             "language_model.model.norm.weight": "language_model.decoder.final_layernorm.weight",
             "language_model.lm_head.weight": "language_model.output_layer.weight",
-            "vision_tower.vision_model.*": "vision_model.vision_model.*",
+            "vision_tower.vision_model.**": "vision_model.vision_model.**",
         }
         if "vision_projection.encoder.linear_fc1.weight" in target.module.state_dict().keys():
             mapping.update(
diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py
index 2a4588617241..fc2281b9b063 100644
--- a/nemo/lightning/io/state.py
+++ b/nemo/lightning/io/state.py
@@ -326,8 +326,28 @@ def call_transform(self, ctx: TransformCTX, *args, **kwargs):
 
 
 def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
-    regex_pattern = re.compile("^" + pattern.replace("*", r"([^.]+)") + "$")
-    wildcard_matches = [[] for _ in range(pattern.count("*"))]
+    escaped_pattern = ''
+    i = 0
+    wildcard_positions = []
+    while i < len(pattern):
+        if pattern[i : i + 2] == '**':
+            escaped_pattern += r'(.+)'  # Match any characters including dots
+            wildcard_positions.append('**')
+            i += 2
+        elif pattern[i] == '*':
+            escaped_pattern += r'([^.]+)'  # Match any characters except dots
+            wildcard_positions.append('*')
+            i += 1
+        else:
+            if pattern[i] == '.':
+                escaped_pattern += r'\.'  # Escape the dot
+            else:
+                escaped_pattern += pattern[i]
+            i += 1
+
+    regex_pattern = re.compile("^" + escaped_pattern + "$")
+    num_wildcards = len(wildcard_positions)
+    wildcard_matches = [[] for _ in range(num_wildcards)]
     for key in filter(lambda x: x is not None, keys):
         match = regex_pattern.match(key)
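
For reference, a minimal standalone sketch of the wildcard semantics this patch gives _match_keys: '*' captures exactly one dot-delimited segment, while the new '**' captures across dots, which is what lets the single "vision_tower.vision_model.**" entry in llava.py cover every parameter under the vision tower. The helper name translate and the sample keys below are illustrative only, and re.escape stands in for the patch's explicit dot check.

import re

def translate(pattern: str) -> re.Pattern:
    # Same translation as _match_keys above: '**' spans dots, '*' does not.
    out, i = '', 0
    while i < len(pattern):
        if pattern[i : i + 2] == '**':
            out += r'(.+)'  # '**' -> one capture group spanning '.' separators
            i += 2
        elif pattern[i] == '*':
            out += r'([^.]+)'  # '*' -> one capture group, a single segment
            i += 1
        else:
            out += re.escape(pattern[i])  # escapes '.' (and any other metacharacter)
            i += 1
    return re.compile('^' + out + '$')

# Illustrative keys, not taken from a real checkpoint:
key = 'vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight'
assert translate('vision_tower.vision_model.*').match(key) is None
assert translate('vision_tower.vision_model.**').match(key).group(1) == \
    'encoder.layers.0.self_attn.q_proj.weight'
assert translate('language_model.model.layers.*.self_attn.q_proj.weight').match(
    'language_model.model.layers.3.self_attn.q_proj.weight'
).group(1) == '3'

Keeping '*' segment-local preserves the behavior of existing mappings such as the per-layer "layers.*" patterns, so '**' is an additive widening rather than a change to current matches.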