Skip to content

Commit

Permalink
batchv2
Browse files Browse the repository at this point in the history
  • Loading branch information
siqim committed Apr 23, 2024
1 parent 350a986 commit 2e40838
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 146 deletions.
Binary file modified example/ckpt/tracking-60k-model.pt
Binary file not shown.
98 changes: 61 additions & 37 deletions example/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"import os\n",
"import sys\n",
"sys.path.append('../src')\n",
Expand Down Expand Up @@ -44,7 +43,7 @@
},
"outputs": [],
"source": [
"device = 'cuda:7'\n",
"device = 'cuda:1'\n",
"dataset_name = 'tracking-60k'\n",
"batch_size = 1\n",
"model_configs = {'block_size': 100, 'n_hashes': 3, 'num_regions': 150, 'num_heads': 8, 'h_dim': 24, 'n_layers': 4, 'num_w_per_dist': 10}\n",
Expand All @@ -64,13 +63,22 @@
"outputs": [],
"source": [
"dataset_dir = Path('../data/') / dataset_name.split(\"-\")[0]\n",
"dataset = get_dataset(dataset_name, dataset_dir)\n",
"loaders = get_data_loader(dataset, dataset.idx_split, batch_size=batch_size)"
"dataset = get_dataset(dataset_name, dataset_dir)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c81e94bc",
"metadata": {},
"outputs": [],
"source": [
"loaders = get_data_loader(dataset, dataset.idx_split, batch_size=batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6b3b391649705898",
"metadata": {
"ExecuteTime": {
Expand All @@ -85,38 +93,37 @@
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8d0d75234ae914dd",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-10T15:13:32.361412Z",
"start_time": "2024-04-10T15:13:32.349108Z"
}
},
"execution_count": 6,
"id": "0e6034a9",
"metadata": {},
"outputs": [],
"source": [
"checkpoint = torch.load(\"./ckpt/tracking-60k-model.pt\", map_location=\"cpu\")\n",
"model.load_state_dict(checkpoint, strict=True)\n",
"model = model.to(device)\n",
"\n",
"criterion = get_loss('infonce', {'dist_metric': 'l2_rbf', 'tau': 0.05})\n",
"metrics = init_metrics(dataset_name)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "1c365bee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Epoch 0] test , loss: 0.5879, acc: 0.9191, prec: 0.3807, recall: 0.9750: 100%|██████████| 5/5 [00:11<00:00, 2.25s/it]"
"[Epoch 0] test , loss: 0.5884, acc: 0.9189, prec: 0.3805, recall: 0.9744: 100%|██████████| 5/5 [00:08<00:00, 1.64s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy@0.9: 0.9191\n"
"Test accuracy@0.9: 0.9189\n"
]
},
{
Expand All @@ -128,53 +135,70 @@
}
],
"source": [
"checkpoint = torch.load(\"./ckpt/tracking-60k-model.pt\", map_location=\"cpu\")\n",
"model.load_state_dict(checkpoint, strict=True)\n",
"model = model.to(device)\n",
"\n",
"with torch.no_grad():\n",
" model.eval()\n",
" test_res = run_one_epoch(model, None, criterion, loaders[\"test\"], \"test\", 0, device, metrics, None)\n",
"\n",
"print(f\"Test accuracy@0.9: {test_res['accuracy@0.9']:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "d490a331",
"metadata": {},
"source": [
"# Benchmark Inference Speed"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "6a1d502a",
"metadata": {},
"outputs": [],
"source": [
"# model = torch.compile(model)"
"model = torch.compile(model)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "65848375",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7ef33f5cf430>\n",
"model(data.x, data.coords, data.batch)\n",
"setup: from __main__ import model, data\n",
" Median: 29.96 ms\n",
" IQR: 0.07 ms (29.92 to 29.99)\n",
" 167 measurements, 1 runs per measurement, 1 thread\n"
]
}
],
"source": [
"# torch.set_float32_matmul_precision('high')\n",
"# for data in loaders[\"test\"]:\n",
"# if data.x.shape[0] > 60000:\n",
"# data = data.to(device)\n",
"# break\n",
"torch.set_float32_matmul_precision('high')\n",
"for data in loaders[\"test\"]:\n",
" if data.x.shape[0] > 60000:\n",
" data = data.to(device)\n",
" break\n",
"\n",
"# model.eval()\n",
"# with torch.no_grad():\n",
"# t1 = benchmark.Timer(\n",
"# stmt=f\"model(data.x, data.coords, data.batch)\", setup=f\"from __main__ import model, data\"\n",
"# )\n",
"# m = t1.blocked_autorange(min_run_time=5)\n",
"# print(m)"
"model.eval()\n",
"with torch.no_grad():\n",
" t1 = benchmark.Timer(\n",
" stmt=f\"model(data.x, data.coords, data.batch)\", setup=f\"from __main__ import model, data\"\n",
" )\n",
" m = t1.blocked_autorange(min_run_time=5)\n",
"print(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33709f92",
"id": "93f727b6",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -196,7 +220,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
24 changes: 0 additions & 24 deletions example/hept.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
import torch
import torch.nn as nn
from einops import rearrange
from typing import List
from hept_utils import E2LSH, invert_permutation, lsh_mapping, sort_to_buckets, unsort_from_buckets


@torch.no_grad()
def get_geo_shift(regions_h: List[List[int]], hash_shift, region_indices, num_or_hashes):
region_indices_eta, region_indices_phi = region_indices

q_hash_shift_eta = region_indices_eta * hash_shift
k_hash_shift_eta = region_indices_eta * hash_shift

q_hash_shift_phi = region_indices_phi * hash_shift * (torch.ceil(regions_h[0][:, None]) + 1)
k_hash_shift_phi = region_indices_phi * hash_shift * (torch.ceil(regions_h[0][:, None]) + 1)
res = torch.stack([q_hash_shift_phi + q_hash_shift_eta, k_hash_shift_phi + k_hash_shift_eta], dim=0)
return rearrange(res, "a (c h) n -> a c h n", c=num_or_hashes)


def qkv_res(s_query, s_key, s_value):
q_sq_05 = -0.5 * (s_query**2).sum(dim=-1, keepdim=True)
k_sq_05 = -0.5 * (s_key**2).sum(dim=-1, keepdim=True)
Expand Down Expand Up @@ -55,7 +41,6 @@ def __init__(self, hash_dim, **kwargs):
self.e2lsh = E2LSH(n_hashes=self.n_hashes, n_heads=self.num_heads, dim=hash_dim)

def forward(self, query, key, value, **kwargs):
# TODO: support batched inputs
query = query.view(-1, self.num_heads, self.dim_per_head)
key = key.view(-1, self.num_heads, self.dim_per_head)
value = value.view(-1, self.num_heads, self.dim_per_head)
Expand All @@ -72,17 +57,8 @@ def forward(self, query, key, value, **kwargs):
q_hat = rearrange(q_hat, "n h d -> h n d")
k_hat = rearrange(k_hat, "n h d -> h n d")
value = rearrange(value, "n h d -> h n d")
q_hat[:, kwargs["raw_size"] :] = 0.0
k_hat[:, kwargs["raw_size"] :] = 0.0
value[:, kwargs["raw_size"] :] = 0.0

q_hashed, k_hashed, hash_shift = lsh_mapping(self.e2lsh, q_hat, k_hat)
q_hashed[..., kwargs["raw_size"] :] = float("inf")
k_hashed[..., kwargs["raw_size"] :] = float("inf")

# q_shifts, k_shifts = get_geo_shift(kwargs["regions_h"], hash_shift, kwargs["region_indices"], self.n_hashes)
# q_hashed = q_hashed + q_shifts
# k_hashed = k_hashed + k_shifts

combined_shifts = kwargs["combined_shifts"] * hash_shift
q_hashed = q_hashed + combined_shifts
Expand Down
27 changes: 0 additions & 27 deletions example/hept_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange


Expand Down Expand Up @@ -41,9 +40,7 @@ def __init__(self, n_hashes, n_heads, dim, r=1):
super(E2LSH, self).__init__()

self.alpha = nn.Parameter(torch.normal(0, 1, (n_heads, dim, n_hashes)))
self.beta = nn.Parameter(uniform(0, r, shape=(1, n_hashes)))
self.alpha.requires_grad = False
self.beta.requires_grad = False

def forward(self, vecs):
projection = torch.bmm(vecs, self.alpha)
Expand Down Expand Up @@ -96,27 +93,3 @@ def sort_to_buckets(x, perm, bucketsz):
def unsort_from_buckets(s_x, perm_inverse):
b_x = rearrange(s_x, "h b nbuckets bucketsz d -> h b (nbuckets bucketsz) d")
return batched_index_select(b_x, perm_inverse)


def pad_to_multiple(tensor, multiple, dims=-1, value=0):
# try:
# dims = list(dims) # If dims is an iterable (e.g., List, Tuple)
# except:
# dims = [dims]
assert isinstance(dims, int)
dims = [dims]
# convert dims from negative to positive
dims = [d if d >= 0 else tensor.ndim + d for d in dims]
padding = [0] * (2 * tensor.ndim)
for d in dims:
size = tensor.size(d)
# Pytorch's JIT doesn't like divmod
# m, remainder = divmod(size, multiple)
m = size // multiple
remainder = size - m * multiple
if remainder != 0:
padding[2 * (tensor.ndim - d - 1) + 1] = multiple - remainder
if all(p == 0 for p in padding):
return tensor
else:
return F.pad(tensor, tuple(padding), value=value)
2 changes: 1 addition & 1 deletion example/trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch_geometric.utils import unbatch, to_undirected
from torch_geometric.utils import unbatch
from torchmetrics import MeanMetric
import numpy as np
from numba import jit
Expand Down
Loading

0 comments on commit 2e40838

Please sign in to comment.