Skip to content

Commit

Permalink
update comments, remove unnecessary default values
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Jan 2, 2025
1 parent 8fd93a7 commit 8e5f693
Showing 1 changed file with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ def make_empty_hessian(
def accumulate_hessian(
inp: torch.Tensor,
module: torch.nn.Module,
H: Optional[torch.Tensor] = None,
num_samples: int = 1,
H: Optional[torch.Tensor],
num_samples: int,
) -> Tuple[torch.Tensor, int]:
inp = inp.to(device=H.device)
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)

num_added = inp.shape[0] # note this is the number of dataset samples, not
# multiplied by the sequence length
num_added = inp.shape[0]

if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
if len(inp.shape) == 3:
Expand Down

0 comments on commit 8e5f693

Please sign in to comment.