Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fp16 fixes #222

Merged
merged 2 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions openfold/model/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self, c_z, no_bins, **kwargs):

self.linear = Linear(self.c_z, self.no_bins, init="final")

def forward(self, z): # [*, N, N, C_z]
def _forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
Expand All @@ -149,8 +149,16 @@ def forward(self, z): # [*, N, N, C_z]
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits



def forward(self, z):

float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)

class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
Expand Down
17 changes: 16 additions & 1 deletion openfold/model/outer_product_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _chunk(self,

return outer

def forward(self,
def _forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
Expand Down Expand Up @@ -143,3 +143,18 @@ def forward(self,
outer = outer / norm

return outer

def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:

float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)

3 changes: 3 additions & 0 deletions openfold/model/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def forward(
q, k, v = self._prep_qkv(q_x, kv_x)

# [*, Q, H, C_hidden]
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled:
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
raise ValueError(
Expand Down
16 changes: 12 additions & 4 deletions openfold/model/structure_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,18 @@ def forward(
z[0] = z[0].cpu()

# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
)
else:
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

Expand Down
7 changes: 6 additions & 1 deletion openfold/model/triangular_multiplicative_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,12 @@ def forward(self,
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
x = self._combine_projections(a, b)
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
Expand Down