"""
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
Action format derived from VPT https://github.com/openai/Video-Pre-Training
"""

import math
import torch
from torch import nn
from torchvision.io import read_image, read_video
from torchvision.transforms.functional import resize
from einops import rearrange
from typing import Mapping, Sequence


def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    sigmoid schedule
    proposed in https://arxiv.org/abs/2212.11972 - Figure 8
    better for images > 64x64, when used during training
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    # cumulative alphas follow a sigmoid in t, normalized so that alphas_cumprod[0] == 1
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # recover per-step betas from the ratio of consecutive cumulative alphas
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
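

# Example (illustrative sketch, not part of the original file): a diffusion model would
# typically turn these betas into cumulative alpha products for noising. The 1000-step
# count below is an assumed value, not something this file prescribes.
#
#   betas = sigmoid_beta_schedule(timesteps=1000)        # shape: (1000,)
#   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # cumulative alpha-bar_t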


ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key == "cameraX":
                    value = current_actions["camera"][0]
                elif action_key == "cameraY":
                    value = current_actions["camera"][1]
                else:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                # camera values are quantized bucket indices; shift and rescale them
                # so they land in roughly [-1, 1]
                max_val = 20
                bin_size = 0.5
                num_buckets = int(max_val / bin_size)
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            else:
                # non-camera actions are 0/1 button states
                value = current_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1], got {value}"
            actions_one_hot[i, j] = value
    return actions_one_hot
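

# Example (illustrative sketch, not part of the original file): a single VPT-style action
# dict. The "camera" pair holds quantized bucket indices (40 maps to 0.0 after the
# normalization above); every other ACTION_KEYS entry must be present as a 0/1 value.
#
#   actions = [{"camera": (40, 40), "forward": 1, "jump": 0, ...}]  # "..." = remaining keys
#   actions_tensor = one_hot_actions(actions)                       # shape: (1, len(ACTION_KEYS))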


IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
VIDEO_EXTENSIONS = {"mp4"}


def load_prompt(path, video_offset=None, n_prompt_frames=1):
    if path.lower().split(".")[-1] in IMAGE_EXTENSIONS:
        print("prompt is image; ignoring video_offset and n_prompt_frames")
        prompt = read_image(path)
        # add frame dimension
        prompt = rearrange(prompt, "c h w -> 1 c h w")
    elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS:
        prompt = read_video(path, pts_unit="sec")[0]
        if video_offset is not None:
            prompt = prompt[video_offset:]
        prompt = prompt[:n_prompt_frames]
        # read_video returns (t, h, w, c); convert to (t, c, h, w) to match the image branch
        prompt = rearrange(prompt, "t h w c -> t c h w")
    else:
        raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
    assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had fewer than n_prompt_frames={n_prompt_frames} frames"
    prompt = resize(prompt, (360, 640))
    # add batch dimension and scale to [0, 1]
    prompt = rearrange(prompt, "t c h w -> 1 t c h w")
    prompt = prompt.float() / 255.0
    return prompt
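

# Example (illustrative sketch, not part of the original file): "prompt.mp4" and the
# offsets are placeholder values. The returned tensor has shape
# (1, n_prompt_frames, 3, 360, 640) with values scaled to [0, 1].
#
#   prompt = load_prompt("prompt.mp4", video_offset=100, n_prompt_frames=4)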


def load_actions(path, action_offset=None):
    if path.endswith(".actions.pt"):
        # raw VPT-style action dicts; convert to the one-hot tensor format
        actions = one_hot_actions(torch.load(path))
    elif path.endswith(".one_hot_actions.pt"):
        actions = torch.load(path, weights_only=True)
    else:
        raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
    if action_offset is not None:
        actions = actions[action_offset:]
    # prepend an all-zero action row (e.g. for the initial prompt frame)
    actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
    # add batch dimension
    actions = rearrange(actions, "t d -> 1 t d")
    return actions
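

# Example (illustrative sketch, not part of the original file): "session.actions.pt" is a
# placeholder path. Because of the prepended zero-action row, the returned tensor has one
# more timestep than the file on disk, with shape (1, T + 1, len(ACTION_KEYS)).
#
#   actions = load_actions("session.actions.pt", action_offset=100)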