Skip to content

Commit

Permalink
Add scripts to visualize predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
kbrodt committed Jun 20, 2021
1 parent b25f258 commit 22d3538
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ python submit.py \
--save ${SAVE}
```


## Visualize predictions

```bash
python visualize.py \
--model ${MODEL_PATH_TO_JIT} \
--data ${DATA_PATH} \
--save ./viz
```

## Useful links

* [kaggle lyft 3rd place solution](https://gdude.de/blog/2021-02-05/Kaggle-Lyft-solution)
92 changes: 92 additions & 0 deletions visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from torch.utils.data import DataLoader

from train import WaymoLoader, pytorch_neg_multi_log_likelihood_batch


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--data", type=str, required=True)
parser.add_argument("--save", type=str, required=True)
parser.add_argument("--n-samples", type=int, required=False, default=50)

args = parser.parse_args()

return args


def main():
args = parse_args()
if not os.path.exists(args.save):
os.mkdir(args.save)

model = torch.jit.load(args.model).cuda().eval()
loader = DataLoader(
WaymoLoader(args.data, return_vector=True),
batch_size=1,
num_workers=1,
shuffle=False,
)

iii = 0
with torch.no_grad():
for x, y, is_available, vector_data in loader:
x, y, is_available = map(lambda x: x.cuda(), (x, y, is_available))

confidences_logits, logits = model(x)

argmax = confidences_logits.argmax()
confidences_logits = confidences_logits[:, argmax].unsqueeze(1)
logits = logits[:, argmax].unsqueeze(1)

loss = pytorch_neg_multi_log_likelihood_batch(
y, logits, confidences_logits, is_available
)
confidences = torch.softmax(confidences_logits, dim=1)
V = vector_data[0]

X, idx = V[:, :44], V[:, 44].flatten()

figure(figsize=(15, 15), dpi=80)
for i in np.unique(idx):
_X = X[idx == i]
if _X[:, 5:12].sum() > 0:
plt.plot(_X[:, 0], _X[:, 1], linewidth=4, color="purple")
else:
plt.plot(_X[:, 0], _X[:, 1], color="black")
logits = logits.cpu().numpy()[0]
y = y.cpu().numpy()[0]
is_available = is_available.long().cpu().numpy()[0]
plt.plot(
y[is_available > 0][::10, 0],
y[is_available > 0][::10, 1],
"-o",
label="GT",
)
plt.plot(
logits[confidences[0].argmax()][is_available > 0][::10, 0],
logits[confidences[0].argmax()][is_available > 0][::10, 1],
"-o",
label="PRED",
)

plt.title(loss.item())
plt.legend()
plt.savefig(
os.path.join(args.save, f"{iii:0>2}_{loss.item():.3f}.png")
)
plt.close()
iii += 1
if iii == args.n_samples:
break


if __name__ == "__main__":
main()

0 comments on commit 22d3538

Please sign in to comment.