Skip to content

Commit

Permalink
plotting: some stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Dec 14, 2024
1 parent f306433 commit ce809e5
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion blearn/blearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from src.blender import Blender

def main():
b = Blender(config_file="config/config_rhone.yml")
b = Blender(config_file="config/config_rural_area_demo.yml")
b.start()


Expand Down
2 changes: 1 addition & 1 deletion blearn/config/config_mountain_area_demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ output:
format: JPEG
render: CYCLES
paths:
number_of_samples: 100
number_of_samples: 200
x_min: 0.0
x_max: 1400.0
y_min: 0.0
Expand Down
2 changes: 1 addition & 1 deletion blearn/config/config_rhone_demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ output:
render: CYCLES
use_second_view: False
paths:
number_of_samples: 100
number_of_samples: 200
x_min: 0.0
x_max: 200.0
y_min: 0.0
Expand Down
2 changes: 1 addition & 1 deletion blearn/config/config_road_corridor_demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ output:
format: JPEG
render: CYCLES
paths:
number_of_samples: 100
number_of_samples: 200
x_min: 0.0
x_max: 800.0
y_min: 0.0
Expand Down
2 changes: 1 addition & 1 deletion blearn/config/config_rural_area_demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ output:
format: JPEG
render: CYCLES
paths:
number_of_samples: 100
number_of_samples: 200
x_min: 0.0
x_max: 500.0
y_min: 0.0
Expand Down
2 changes: 1 addition & 1 deletion radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def validate_epoch(self, epoch, val_loader, iteration_callback=None):

# TODO: Expand on this interface
if iteration_callback is not None:
iteration_callback(int(sample['index']), depth_prediction)
iteration_callback(sample, depth_prediction)

if mask is not None:
current_results = eval_depth(depth_prediction[mask], depth_target[mask])
Expand Down
57 changes: 57 additions & 0 deletions scripts/plotting/plot_rgbr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np
import cv2
import radarmeetsvision as rmv

from pathlib import Path

rgbr_dir = None

def prediction_callback(sample, prediction):
image = sample['image'].cpu().numpy().squeeze()
image = np.moveaxis(image, 0, -1)
depth_prior = sample['depth_prior'].cpu().numpy().squeeze()

image = cv2.resize(image, (640, 480))
depth_prior = cv2.resize(depth_prior, (640, 480))

nonzero_mask = (depth_prior > 0.0)
ones = np.ones(nonzero_mask.sum())
zeros = np.zeros(nonzero_mask.sum())
red = np.vstack((ones, zeros, zeros)).T
image[nonzero_mask] = red
image = np.clip(image, 0.0, 1.0)
plt.imsave(str(rgbr_dir / f"{int(sample['index']):05d}_rgbr.jpg"), image)

def main(args):
rmv.setup_global_logger()
interface = rmv.Interface()
interface.set_encoder('vitb')
depth_min = 0.19983673095703125
depth_max = 120.49285888671875
interface.set_depth_range((depth_min, depth_max))
interface.set_output_channels(2)

interface.set_size(480, 640)
interface.set_batch_size(1)
interface.set_criterion()
interface.set_use_depth_prior(True)

interface.load_model(pretrained_from=args.network)
_, loader = interface.get_single_dataset(args.dataset, min_index=0, max_index=-1)

global rgbr_dir
rgbr_dir = Path(args.dataset) / 'rgbr'
if not rgbr_dir.is_dir():
rgbr_dir.mkdir(exist_ok=True)

interface.validate_epoch(0, loader, iteration_callback=prediction_callback)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Purely evalute a network')
parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
parser.add_argument('--network', type=str, help='Path to the network file')
args = parser.parse_args()
main(args)

0 comments on commit ce809e5

Please sign in to comment.