Fixed issue with adapters not providing gradients with new grad activator
jaretburkett committed Oct 29, 2024
1 parent 22cd40d commit 4747716
Showing 3 changed files with 25 additions and 16 deletions.
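All three files make the same core change: parameter generators are materialized with list() before being placed into optimizer param groups. A generator such as the one returned by nn.Module.parameters() can only be walked once, so if anything iterates a group's 'params' before the optimizer does -- presumably the new gradient activator named in the commit title, whose code is not part of this diff -- the optimizer ends up with an empty group and the adapter never receives updates. A minimal sketch of that failure mode and the fix; the requires_grad toggle loop below is only a hypothetical stand-in for the activator:

import torch
import torch.nn as nn

adapter = nn.Linear(4, 4)

# Passing the raw generator: it is a one-shot iterator.
group = {"params": adapter.parameters(), "lr": 1e-4}

# Hypothetical "activation" pass that walks the group first,
# e.g. to switch requires_grad on. This exhausts the generator.
for p in group["params"]:
    p.requires_grad_(True)

opt = torch.optim.AdamW([group])
print(len(opt.param_groups[0]["params"]))   # 0 -> optimizer silently trains nothing

# With list(), as in this commit, the group can be iterated again.
group = {"params": list(adapter.parameters()), "lr": 1e-4}
for p in group["params"]:
    p.requires_grad_(True)

opt = torch.optim.AdamW([group])
print(len(opt.param_groups[0]["params"]))   # 2 -> weight and bias are trained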
jobs/process/BaseSDTrainProcess.py (2 additions, 2 deletions)
@@ -1473,7 +1473,7 @@ def run(self):
             # self.step_num = self.embedding.step
             # self.start_step = self.step_num
             params.append({
-                'params': self.embedding.get_trainable_params(),
+                'params': list(self.embedding.get_trainable_params()),
                 'lr': self.train_config.embedding_lr
             })

@@ -1491,7 +1491,7 @@ def run(self):
         else:
             # set trainable params
             params.append({
-                'params': self.adapter.parameters(),
+                'params': list(self.adapter.parameters()),
                 'lr': self.train_config.adapter_lr
             })

toolkit/ip_adapter.py (2 additions, 2 deletions)
@@ -1161,13 +1161,13 @@ def get_parameter_groups(self, adapter_lr):
         # when training just scaler, we do not train anything else
         if not self.config.train_scaler:
             param_groups.append({
-                "params": self.get_non_scaler_parameters(),
+                "params": list(self.get_non_scaler_parameters()),
                 "lr": adapter_lr,
             })
         if self.config.train_scaler or self.config.merge_scaler:
             scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
             param_groups.append({
-                "params": self.get_scaler_parameters(),
+                "params": list(self.get_scaler_parameters()),
                 "lr": scaler_lr,
             })
         return param_groups
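toolkit/ip_adapter.py applies the same fix to both of its parameter groups, which also carry separate learning rates. The sketch below mirrors that shape with hypothetical stand-ins for the non-scaler and scaler parameters (the real get_non_scaler_parameters/get_scaler_parameters are not reproduced here); the point is that each group's 'params' is a concrete list by the time the optimizer, or any later pass over the groups, sees it:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the adapter's two parameter sets.
non_scaler = nn.Linear(8, 8)
scaler = nn.Parameter(torch.ones(8))

def get_parameter_groups(adapter_lr, scaler_lr=None):
    # Mirror the diff: materialize each generator with list() so the
    # group can still be iterated after something has walked it once.
    param_groups = [{"params": list(non_scaler.parameters()), "lr": adapter_lr}]
    param_groups.append({
        "params": [scaler],
        "lr": adapter_lr if scaler_lr is None else scaler_lr,
    })
    return param_groups

opt = torch.optim.AdamW(get_parameter_groups(1e-4, scaler_lr=1e-3))
for g in opt.param_groups:
    print(g["lr"], len(g["params"]))    # 0.0001 2, then 0.001 1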
toolkit/models/vd_adapter.py (21 additions, 12 deletions)
@@ -711,20 +711,30 @@ def __init__(
             self.block_scaler.requires_grad = True
         else:
             self.block_scaler = None
 
+        self.pool = None
+
         if self.config.num_tokens is not None:
-            image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+            # image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
             # max_seq_len = CLIP tokens + CLS token
-            max_seq_len = 257
-            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
-                # clip
-                max_seq_len = int(
-                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
-            self.resampler = MLPR(
-                in_dim=self.token_size,
-                in_channels=max_seq_len,
-                out_dim=self.mid_size,
-                out_channels=self.config.num_tokens,
+            # max_seq_len = 257
+            # if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+            #     # clip
+            #     max_seq_len = int(
+            #         image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+            # self.resampler = MLPR(
+            #     in_dim=self.token_size,
+            #     in_channels=max_seq_len,
+            #     out_dim=self.mid_size,
+            #     out_channels=self.config.num_tokens,
+            # )
+            vision_config = self.adapter_ref().vision_encoder.config
+            # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+            # siglip doesnt add 1
+            sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+            self.pool = nn.Sequential(
+                nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
+                Norm(),
             )
 
         elif self.config.image_encoder_arch == "pixtral":
@@ -733,7 +743,6 @@
                 out_dim=self.mid_size,
             )
 
-        self.pool = None
         self.sparse_autoencoder = None
         if self.config.conv_pooling:
             vision_config = self.adapter_ref().vision_encoder.config
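In toolkit/models/vd_adapter.py the commented-out MLPR resampler is replaced by a 1x1 Conv1d that mixes the vision encoder's patch tokens down to config.num_tokens, with the sequence length computed as (image_size / patch_size) ** 2 and no +1 for a CLS token, since SigLIP does not add one. A minimal sketch of that pooling, using hypothetical encoder geometry and LayerNorm as a stand-in for the toolkit's Norm() module:

import torch
import torch.nn as nn

# Hypothetical encoder geometry (not taken from the diff): a SigLIP-style
# encoder with image_size=224 and patch_size=16 emits 14*14 = 196 patch
# tokens and no CLS token, so nothing is added to the sequence length.
image_size, patch_size, hidden_dim = 224, 16, 1152
num_tokens = 16                                          # adapter's config.num_tokens
sequence_length = int((image_size / patch_size) ** 2)    # 196

# A 1x1 Conv1d over the sequence dimension mixes the 196 patch tokens
# down to num_tokens learned tokens; the feature dimension is untouched.
pool = nn.Sequential(
    nn.Conv1d(sequence_length, num_tokens, 1, bias=False),
    nn.LayerNorm(hidden_dim),
)

hidden_states = torch.randn(2, sequence_length, hidden_dim)  # [batch, seq, dim]
pooled = pool(hidden_states)
print(pooled.shape)   # torch.Size([2, 16, 1152])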
