Skip to content

Commit

Permalink
batchv1
Browse files Browse the repository at this point in the history
  • Loading branch information
siqim committed Apr 22, 2024
1 parent 1d0682f commit 350a986
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 50 deletions.
68 changes: 22 additions & 46 deletions example/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
"device = 'cuda:7'\n",
"dataset_name = 'tracking-60k'\n",
"batch_size = 1\n",
"model_configs = {'block_size': 100, 'n_hashes': 3, 'num_regions': 150, 'num_heads': 8, 'h_dim': 24, 'n_layers': 4, 'num_w_per_dist': 10}"
"model_configs = {'block_size': 100, 'n_hashes': 3, 'num_regions': 150, 'num_heads': 8, 'h_dim': 24, 'n_layers': 4, 'num_w_per_dist': 10}\n",
"torch.cuda.set_device(device)"
]
},
{
Expand Down Expand Up @@ -101,33 +102,21 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "239d20f0a8b6d94d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-10T15:26:12.414265Z",
"start_time": "2024-04-10T15:26:01.158690Z"
}
},
"id": "1c365bee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/5 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Epoch 0] test , loss: 0.5851, acc: 0.9193, prec: 0.3807, recall: 0.9749: 100%|██████████| 5/5 [00:03<00:00, 1.28it/s]"
"[Epoch 0] test , loss: 0.5879, acc: 0.9191, prec: 0.3807, recall: 0.9750: 100%|██████████| 5/5 [00:11<00:00, 2.25s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy@0.9: 0.9193\n"
"Test accuracy@0.9: 0.9191\n"
]
},
{
Expand Down Expand Up @@ -157,48 +146,35 @@
"metadata": {},
"outputs": [],
"source": [
"model = torch.compile(model)"
"# model = torch.compile(model)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "65848375",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7f0247245900>\n",
"model(data.x, data.coords, data.batch)\n",
"setup: from __main__ import model, data\n",
" Median: 29.41 ms\n",
" IQR: 0.08 ms (29.37 to 29.45)\n",
" 170 measurements, 1 runs per measurement, 1 thread\n"
]
}
],
"outputs": [],
"source": [
"torch.set_float32_matmul_precision('high')\n",
"for data in loaders[\"test\"]:\n",
" if data.x.shape[0] > 60000:\n",
" data = data.to(device)\n",
" break\n",
"# torch.set_float32_matmul_precision('high')\n",
"# for data in loaders[\"test\"]:\n",
"# if data.x.shape[0] > 60000:\n",
"# data = data.to(device)\n",
"# break\n",
"\n",
"model.eval()\n",
"with torch.no_grad():\n",
" t1 = benchmark.Timer(\n",
" stmt=f\"model(data.x, data.coords, data.batch)\", setup=f\"from __main__ import model, data\"\n",
" )\n",
" m = t1.blocked_autorange(min_run_time=5)\n",
"print(m)"
"# model.eval()\n",
"# with torch.no_grad():\n",
"# t1 = benchmark.Timer(\n",
"# stmt=f\"model(data.x, data.coords, data.batch)\", setup=f\"from __main__ import model, data\"\n",
"# )\n",
"# m = t1.blocked_autorange(min_run_time=5)\n",
"# print(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bb2773a",
"id": "33709f92",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -220,7 +196,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.0"
}
},
"nbformat": 4,
Expand Down
9 changes: 6 additions & 3 deletions example/hept.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ def forward(self, query, key, value, **kwargs):
q_hashed[..., kwargs["raw_size"] :] = float("inf")
k_hashed[..., kwargs["raw_size"] :] = float("inf")

q_shifts, k_shifts = get_geo_shift(kwargs["regions_h"], hash_shift, kwargs["region_indices"], self.n_hashes)
# q_shifts, k_shifts = get_geo_shift(kwargs["regions_h"], hash_shift, kwargs["region_indices"], self.n_hashes)
# q_hashed = q_hashed + q_shifts
# k_hashed = k_hashed + k_shifts

q_hashed = q_hashed + q_shifts
k_hashed = k_hashed + k_shifts
combined_shifts = kwargs["combined_shifts"] * hash_shift
q_hashed = q_hashed + combined_shifts
k_hashed = k_hashed + combined_shifts

q_positions = q_hashed.argsort(dim=-1)
k_positions = k_hashed.argsort(dim=-1)
Expand Down
15 changes: 14 additions & 1 deletion example/transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import torch
from torch import nn
from torch_geometric.nn import MLP
Expand All @@ -6,6 +7,12 @@
from hept_utils import quantile_partition, get_regions, pad_to_multiple


def bit_shift(base, shift_idx):
max_base = base.max()
num_bits = math.ceil(math.log2(max_base + 1))
return (shift_idx << num_bits) | base


def prepare_input(x, coords, batch, helper_params):
kwargs = {}
key_padding_mask = None
Expand All @@ -14,7 +21,7 @@ def prepare_input(x, coords, batch, helper_params):
kwargs["coords"] = coords

with torch.no_grad():
block_size = helper_params["block_size"]
block_size, num_heads = helper_params["block_size"], helper_params["num_heads"]
kwargs["raw_size"] = x.shape[0]
x = pad_to_multiple(x, block_size, dims=0)
kwargs["coords"] = pad_to_multiple(kwargs["coords"], block_size, dims=0, value=float("inf"))
Expand All @@ -27,6 +34,11 @@ def prepare_input(x, coords, batch, helper_params):
kwargs["region_indices"] = [region_indices_eta, region_indices_phi]
kwargs["regions_h"] = regions_h
kwargs["coords"][kwargs["raw_size"] :] = 0.0

combined_shifts = bit_shift(region_indices_eta.long(), region_indices_phi.long())
paded_batch = pad_to_multiple(batch, block_size, dims=0, value=batch.max() + 1)
combined_shifts = bit_shift(combined_shifts, paded_batch[None])
kwargs["combined_shifts"] = rearrange(combined_shifts, "(c h) n -> c h n", h=num_heads)
return x, mask, kwargs


Expand Down Expand Up @@ -65,6 +77,7 @@ def __init__(self, in_dim, coords_dim, num_classes, dropout=0.1, **kwargs):
self.helper_params["block_size"] = kwargs["block_size"]
self.regions = nn.Parameter(get_regions(kwargs["num_regions"], kwargs["n_hashes"], kwargs["num_heads"]), requires_grad=False)
self.helper_params["regions"] = self.regions
self.helper_params["num_heads"] = kwargs["num_heads"]

if self.num_classes:
self.out_proj = nn.Linear(int(self.h_dim // 2), num_classes)
Expand Down

0 comments on commit 350a986

Please sign in to comment.