Skip to content

Commit

Permalink
Remove unused functions (#2017)
Browse files Browse the repository at this point in the history
  • Loading branch information
northern-64bit authored Sep 8, 2024
1 parent 7a67de3 commit 8a518ee
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,6 @@ def stack_dicts(stats_dicts: List[Dict]) -> Dict:
return results


def add_suffix(input_dict: Dict, suffix: str) -> Dict:
"""Add suffix to dict keys."""
return {k + suffix: v for k, v in input_dict.items()}


def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
"""Pad tensor to size."""
t_size = tensor.size()[dim]
if t_size == size:
return tensor
else:
return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)


def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
Expand Down Expand Up @@ -201,14 +187,6 @@ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
return entropy


def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
"""Average values of a list of dicts with torch tensors."""
average_dict = dict()
for key in list_of_dicts[0].keys():
average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
return average_dict


def stats_to_np(stats_dict: Dict) -> Dict:
"""Cast all torch.tensors in dict to numpy arrays."""
new_dict = dict()
Expand Down

0 comments on commit 8a518ee

Please sign in to comment.