Add type hints to core.py (huggingface#1097)
* Add type hinting to core.py functions

* Fixes

* Remove unused functions

* Remove unused import
zachschillaci27 authored and Andrew Lapp committed May 10, 2024
1 parent f736557 commit 1b9b397
Showing 1 changed file with 31 additions and 58 deletions.
trl/core.py — 89 changes: 31 additions & 58 deletions
@@ -15,10 +15,11 @@
import random
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering
@@ -35,24 +36,24 @@
WANDB_PADDING = -1


def flatten_dict(nested, sep="/"):
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
"""Flatten dictionary and concatenate nested keys with separator."""

def rec(nest, prefix, into):
def recurse(nest: Dict, prefix: str, into: Dict) -> None:
for k, v in nest.items():
if sep in k:
raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
if isinstance(v, Mapping):
rec(v, prefix + k + sep, into)
recurse(v, prefix + k + sep, into)
else:
into[prefix + k] = v

flat = {}
rec(nested, "", flat)
recurse(nested, "", flat)
return flat
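
A quick illustration (not part of the diff) of how flatten_dict joins nested keys with the separator; the sample stats dict is made up:

from trl.core import flatten_dict

stats = {"ppo": {"loss": 0.25, "kl": 0.02}, "env": {"reward_mean": 1.3}}
flatten_dict(stats)
# {'ppo/loss': 0.25, 'ppo/kl': 0.02, 'env/reward_mean': 1.3}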


def convert_to_scalar(stats):
def convert_to_scalar(stats: Dict) -> Dict:
"""
Converts the stats from a flattened dict to single scalar dicts
"""
@@ -68,7 +69,7 @@ def convert_to_scalar(stats):
return tensorboard_stats


def stack_dicts(stats_dicts):
def stack_dicts(stats_dicts: List[Dict]) -> Dict:
"""Stack the values of a dict."""
results = dict()
for k in stats_dicts[0]:
@@ -77,12 +78,12 @@ def stack_dicts(stats_dicts):
return results


def add_suffix(input_dict, suffix):
def add_suffix(input_dict: Dict, suffix: str) -> Dict:
"""Add suffix to dict keys."""
return dict((k + suffix, v) for k, v in input_dict.items())


def pad_to_size(tensor, size, dim=1, padding=50256):
def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
"""Pad tensor to size."""
t_size = tensor.size()[dim]
if t_size == size:
@@ -91,7 +92,7 @@ def pad_to_size(tensor, size, dim=1, padding=50256):
return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
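
For illustration only (token ids are made up), pad_to_size right-pads along dim with the given padding id:

import torch
from trl.core import pad_to_size

ids = torch.tensor([[15496, 11, 995]])     # shape (1, 3)
pad_to_size(ids, 6, dim=1, padding=50256)  # shape (1, 6), right-padded with 50256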


def logprobs_from_logits(logits, labels, gather=True):
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
@@ -103,7 +104,7 @@ def logprobs_from_logits(logits, labels, gather=True):
return logpy
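
A hypothetical call (shapes chosen for illustration); with gather=True the result holds the log-probability of each label token:

import torch
from trl.core import logprobs_from_logits

logits = torch.randn(2, 5, 50257)           # (batch, seq, vocab)
labels = torch.randint(0, 50257, (2, 5))    # (batch, seq)
logprobs_from_logits(logits, labels).shape  # torch.Size([2, 5])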


def whiten(values, shift_mean=True):
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
"""Whiten values."""
mean, var = torch.mean(values), torch.var(values)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
@@ -112,15 +113,15 @@ def whiten(values, shift_mean=True):
return whitened
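
A sketch of whiten on a reward vector (numbers are arbitrary); with shift_mean=False the original mean is assumed to be added back after scaling:

import torch
from trl.core import whiten

rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
whiten(rewards)                    # zero mean, unit variance
whiten(rewards, shift_mean=False)  # unit variance, original mean preserved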


def masked_mean(values, mask, axis=None):
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
else:
return (values * mask).sum() / mask.sum()
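
A small worked example (made-up numbers); mask selects which positions count, and axis picks an optional reduction dimension:

import torch
from trl.core import masked_mean

values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
masked_mean(values, mask)          # (1 + 2 + 4) / 3 = 2.3333
masked_mean(values, mask, axis=1)  # tensor([1.5000, 4.0000])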


def masked_var(values, mask, unbiased=True):
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
centered_values = values - mean
@@ -139,7 +140,7 @@ def masked_var(values, mask, unbiased=True):
return variance


def masked_whiten(values, mask, shift_mean=True):
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
@@ -148,31 +149,31 @@ def masked_whiten(values, mask, shift_mean=True):
return whitened
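
An illustrative use on PPO-style advantages (tensor contents are arbitrary): only positions where mask is 1 contribute to the mean and variance:

import torch
from trl.core import masked_whiten

advantages = torch.randn(2, 8)
mask = torch.tensor([[1.0] * 6 + [0.0] * 2, [1.0] * 4 + [0.0] * 4])
masked_whiten(advantages, mask)  # whitened with masked mean/variance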


def clip_by_value(x, tensor_min, tensor_max):
def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
"""
Tensor extenstion to torch.clamp
Tensor extension to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
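
A brief sketch with tensor-valued bounds (numbers are illustrative):

import torch
from trl.core import clip_by_value

x = torch.tensor([0.5, 2.0, -1.0])
low = torch.zeros(3)
high = torch.ones(3)
clip_by_value(x, low, high)  # tensor([0.5000, 1.0000, 0.0000])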


def entropy_from_logits(logits):
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
"""Calculate entropy from logits."""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
return entropy
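
A shape-only sketch (random logits for illustration): the entropy is taken over the last, vocabulary-sized dimension:

import torch
from trl.core import entropy_from_logits

logits = torch.randn(4, 50257)     # (batch, vocab)
entropy_from_logits(logits).shape  # torch.Size([4])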


def average_torch_dicts(list_of_dicts):
def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
"""Average values of a list of dicts with torch tensors."""
average_dict = dict()
for key in list_of_dicts[0].keys():
average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
return average_dict
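
A toy example (invented stats) of averaging per-key tensors across mini-batch dicts:

import torch
from trl.core import average_torch_dicts

batches = [{"loss": torch.tensor(0.4)}, {"loss": torch.tensor(0.6)}]
average_torch_dicts(batches)  # {'loss': tensor(0.5000)}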


def stats_to_np(stats_dict):
def stats_to_np(stats_dict: Dict) -> Dict:
"""Cast all torch.tensors in dict to numpy arrays."""
new_dict = dict()
for k, v in stats_dict.items():
@@ -188,37 +189,9 @@ def stats_to_np(stats_dict):
return new_dict


def listify_batch(tensor):
"""Turns the first dimension of a tensor into a list."""
return [tensor[i] for i in range(tensor.shape[0])]


def build_bert_batch_from_txt(text_list, tokenizer, device):
"""Create token id and attention mask tensors from text list for BERT classification."""

# tokenize
tensors = [tokenizer.encode(txt, return_tensors="pt").to(device) for txt in text_list]

# find max length to pad to
max_len = max([t.size()[1] for t in tensors])

# get padded tensors and attention masks
# (attention masks make bert ignore padding)
padded_tensors = []
attention_masks = []
for tensor in tensors:
attention_mask = torch.ones(tensor.size(), device=device)
padded_tensors.append(pad_to_size(tensor, max_len, padding=0))
attention_masks.append(pad_to_size(attention_mask, max_len, padding=0))

# stack all tensors
padded_tensors = torch.cat(padded_tensors)
attention_masks = torch.cat(attention_masks)

return padded_tensors, attention_masks


def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0):
def respond_to_batch(
model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
) -> torch.LongTensor:
"""Sample text from language model."""
input_ids = queries
for i in range(txt_len):
@@ -233,7 +206,7 @@ def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0):
return input_ids[:, -txt_len:]
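
An illustrative generation call; the GPT-2 model and prompt are placeholders, not from this commit, and queries is the batched tensor of prompt token ids:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from trl.core import respond_to_batch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
queries = tokenizer("My dog is", return_tensors="pt").input_ids
responses = respond_to_batch(model, queries, txt_len=20)
print(tokenizer.decode(responses[0]))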


def set_seed(seed: int):
def set_seed(seed: int) -> None:
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
@@ -256,10 +229,10 @@ class LengthSampler:
Samples a length
"""

def __init__(self, min_value, max_value):
def __init__(self, min_value: int, max_value: int):
self.values = list(range(min_value, max_value))

def __call__(self):
def __call__(self) -> int:
return np.random.choice(self.values)
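
A minimal sketch: sampling a target response length uniformly from the half-open range [4, 16):

from trl.core import LengthSampler

output_length_sampler = LengthSampler(4, 16)
gen_len = output_length_sampler()  # random int from 4 to 15 inclusive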


@@ -287,11 +260,11 @@ def empty_device_cache(cls):

def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional["torch.device"] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
) -> torch.Tensor:
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
is always created on the CPU.
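
A hedged usage sketch (shape and seed are arbitrary), drawing reproducible noise from a seeded CPU generator:

import torch
from trl.core import randn_tensor

generator = torch.Generator().manual_seed(0)
noise = randn_tensor((2, 4), generator=generator, device=torch.device("cpu"), dtype=torch.float32)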
