try to fix Zero3 Memory Leak following @tohtana idea #363

Closed
wants to merge 9 commits
2 changes: 0 additions & 2 deletions README.md
@@ -358,8 +358,6 @@ any GPU memory savings. Please refer issue [[FSDP] FSDP with CPU offload consume
`P_TUNING`/`PROMPT_TUNING` appends soft prompt embeddings to `input_embeds` to create
new `input_embeds` to be given to the model. Therefore, `generate` doesn't support this yet.

4. When using ZeRO3 with zero3_init_flag=True, if you find that GPU memory increases with training steps, you might need to set zero3_init_flag=false in the accelerate config.yaml. The related issue is [[BUG] memory leak under zero.Init](https://github.com/microsoft/DeepSpeed/issues/2637)

## Backlog:
- [x] Add tests
- [x] Multi Adapter training and inference support
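For context, the note removed above described a configuration-level workaround: keep ZeRO-3 but disable `zero.Init` during model construction. Below is a minimal sketch of applying that workaround programmatically via accelerate's `DeepSpeedPlugin`; the `zero_stage` and `zero3_init_flag` arguments are assumed from the accelerate API and may differ across versions.

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Workaround from the removed README note: keep ZeRO stage 3 but skip
# deepspeed.zero.Init when building the model, which was reported to
# avoid the step-by-step GPU memory growth in DeepSpeed issue #2637.
ds_plugin = DeepSpeedPlugin(zero_stage=3, zero3_init_flag=False)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```

The equivalent setting in an accelerate config.yaml is the `zero3_init_flag: false` entry mentioned in the removed note.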
6 changes: 5 additions & 1 deletion src/peft/tuners/adalora.py
@@ -431,7 +431,11 @@ def forward(self, x: torch.Tensor):
self.unmerge()
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
result = torch.matmul(x, transpose(self.weight, not self.fan_in_fan_out))

if self.bias is not None:
    result += self.bias

result += (
(
self.lora_dropout[self.active_adapter](x)
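The core change here (and the matching one in `lora.py` below) swaps `F.linear` for an explicit `torch.matmul` plus a bias add; because `matmul` multiplies by the weight directly rather than by its transpose, the flag passed to `transpose` is inverted. A minimal sketch of that equivalence, using a stand-in for peft's `transpose` helper:

```python
import torch
import torch.nn.functional as F

def transpose(weight, fan_in_fan_out):
    # Stand-in for peft's helper: transpose weights stored in (in, out) layout.
    return weight.T if fan_in_fan_out else weight

x = torch.randn(4, 16)         # batch of activations
weight = torch.randn(32, 16)   # standard nn.Linear layout (out_features, in_features)
bias = torch.randn(32)
fan_in_fan_out = False

# Original formulation: F.linear computes x @ W.T + b.
out_linear = F.linear(x, transpose(weight, fan_in_fan_out), bias=bias)

# Reformulation from this PR: the transpose flag is flipped so the same
# orientation is reached without going through F.linear.
out_matmul = torch.matmul(x, transpose(weight, not fan_in_fan_out))
if bias is not None:
    out_matmul += bias

print(torch.allclose(out_linear, out_matmul, atol=1e-6))  # True
```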
5 changes: 4 additions & 1 deletion src/peft/tuners/lora.py
@@ -490,7 +490,10 @@ def forward(self, x: torch.Tensor):
self.unmerge()
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
result = torch.matmul(x, transpose(self.weight, not self.fan_in_fan_out))

if self.bias is not None:
    result += self.bias

x = x.to(self.lora_A[self.active_adapter].weight.dtype)

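The `lora.py` hunk mirrors the `adalora.py` change above. To check whether the leak described in the removed README note is actually gone, one simple approach is to log allocated CUDA memory every few training steps and confirm that it plateaus; a small hypothetical helper (the step interval and the training loop it would be called from are assumptions):

```python
import torch

def log_cuda_memory(step, every=50):
    # Hypothetical check for the leak this PR targets: under ZeRO-3 with
    # zero3_init_flag=True, allocated memory was reported to grow with steps.
    if step % every == 0 and torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 2**20
        peak = torch.cuda.max_memory_allocated() / 2**20
        print(f"step {step}: allocated {allocated:.1f} MiB, peak {peak:.1f} MiB")
```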