From 722feacfb2d540902cc9439a9b5f891119f72d05 Mon Sep 17 00:00:00 2001 From: Misko Date: Sun, 24 Mar 2024 16:05:59 -0700 Subject: [PATCH] making space for v4 data format --- spf/data_collector.py | 274 +++++++++++++------------------ spf/dataset/v4_data.py | 51 ++++++ spf/drone_v3_datacollector.py | 122 ++++++++++++++ spf/mavlink_radio_collection.py | 4 +- spf/notebooks/zarr_testing.ipynb | 143 ++++++++++++++++ spf/sdrpluto/README.md | 4 + spf/sdrpluto/sdr_controller.py | 10 +- spf/utils.py | 7 + tests/test_dataset.py | 22 ++- tests/test_in_simulator.py | 12 +- 10 files changed, 476 insertions(+), 173 deletions(-) create mode 100644 spf/dataset/v4_data.py create mode 100644 spf/drone_v3_datacollector.py create mode 100644 spf/notebooks/zarr_testing.ipynb diff --git a/spf/data_collector.py b/spf/data_collector.py index f7fd043d..733a34eb 100644 --- a/spf/data_collector.py +++ b/spf/data_collector.py @@ -1,23 +1,17 @@ -import argparse -import json import logging -import os import struct import sys import threading import time -from datetime import datetime -from pathlib import Path from typing import Optional import numpy as np -import yaml from attr import dataclass from tqdm import tqdm from spf.dataset.rover_idxs import v3rx_column_names +from spf.dataset.v4_data import v4rx_new_dataset from spf.dataset.wall_array_v2_idxs import v2_column_names -from spf.gps.boundaries import franklin_safe from spf.rf import beamformer_given_steering, precompute_steering_vectors from spf.sdrpluto.sdr_controller import ( EmitterConfig, @@ -209,67 +203,53 @@ def start_read_thread(self): self.run = True self.t.start() - def get_data(self): + def get_rx(self, max_retries=15): tries = 0 - try: - signal_matrix = self.pplus.sdr.rx() - rssis = self.pplus.rssis() - gains = self.pplus.gains() - # rssi_and_gain = self.pplus.get_rssi_and_gain() - except Exception as e: - logging.error( - f"Failed to receive RX data! removing file : retry {tries} {e}", - ) - time.sleep(0.1) - tries += 1 - if tries > 15: - logging.error("GIVE UP") - sys.exit(1) + while tries < max_retries: + try: + signal_matrix = self.pplus.sdr.rx() + rssis = self.pplus.rssis() + gains = self.pplus.gains() + # rssi_and_gain = self.pplus.get_rssi_and_gain() + return {"signal_matrix": signal_matrix, "rssis": rssis, "gains": gains} + except Exception as e: + logging.error( + f"Failed to receive RX data! removing file : retry {tries} {e}", + ) + time.sleep(0.1) + tries += 1 + if tries > max_retries: + logging.error("GIVE UP") + sys.exit(1) + return None + + def get_data(self): + sdr_rx = self.get_rx() # process the data - signal_matrix = np.vstack(signal_matrix) + signal_matrix = np.vstack(sdr_rx["signal_matrix"]) current_time = time.time() - self.time_offset # timestamp self.data = data_to_snapshot( current_time=current_time, signal_matrix=signal_matrix, steering_vectors=self.steering_vectors, - rssis=rssis, - gains=gains, + rssis=sdr_rx(["rssis"]), + gains=sdr_rx(["gains"]), rx_config=self.pplus.rx_config, ) - # self.data = { - # "current_time": current_time, - # "signal_matrix": signal_matrix, - # "steering_vectors": steering_vectors, - # } - class DataCollector: - def __init__( - self, yaml_config, filename_npy, position_controller, column_names, tag="" - ): - self.column_names = column_names + def __init__(self, yaml_config, data_filename, position_controller, tag=""): self.yaml_config = yaml_config - self.filename_npy = filename_npy - Path(self.filename_npy).touch() + self.data_filename = data_filename + # self.record_matrix = None self.position_controller = position_controller self.finished_collecting = False - # record matrix - if not self.yaml_config["dry-run"]: - self.record_matrix = np.memmap( - self.filename_npy, - dtype="float32", - mode="w+", - shape=( - 2, # TODO should be nreceivers - self.yaml_config["n-records-per-receiver"], - len(self.column_names), - ), # t,tx,ty,rx,ry,rtheta,rspacing / avg1,avg2 / sds - ) + self.setup_record_matrix() def radios_to_online(self): # lets open all the radios @@ -386,6 +366,9 @@ def done(self): def is_collecting(self): return not self.finished_collecting + def setup_record_matrix(self): + raise NotImplementedError + def write_to_record_matrix(self, thread_idx, record_idx, read_thread: ThreadedRX): raise NotImplementedError @@ -414,14 +397,63 @@ def run_collector_thread(self): read_thread.join() +class DroneDataCollectorChunked(DataCollector): + def __init__(self, *args, **kwargs): + super(DroneDataCollectorChunked, self).__init__( + *args, + **kwargs, + ) + + def setup_record_matrix(self): + # make sure all receivers are sharing a common buffer size + buffer_size = None + for receiver in self.yaml_config["receivers"]: + assert "buffer_size" in receiver + if buffer_size is None: + buffer_size = self.yaml_config["buffer_size"] + else: + assert buffer_size == self.yaml_config["buffer_size"] + # record matrix + self.zarr = v4rx_new_dataset( + self.data_filename, + self.yaml_config["timesteps"], + buffer_size, + len(self.yaml_config["receivers"]), + chunk_size=4096, + compressor=None, + ) + + # def write_to_record_matrix(self, thread_idx, record_idx, data): + # current_pos_heading_and_time = ( + # self.position_controller.get_position_bearing_and_time() + # ) + + # self.record_matrix[thread_idx, record_idx] = prepare_record_entry_v3( + # ds=data, + # current_pos_heading_and_time=current_pos_heading_and_time, + # ) + + class DroneDataCollector(DataCollector): def __init__(self, *args, **kwargs): super(DroneDataCollector, self).__init__( *args, - column_names=v3rx_column_names(nthetas=kwargs["yaml_config"]["n-thetas"]), **kwargs, ) + def setup_record_matrix(self): + # record matrix + self.record_matrix = np.memmap( + self.data_filename, + dtype="float32", + mode="w+", + shape=( + 2, # TODO should be nreceivers + self.yaml_config["n-records-per-receiver"], + v3rx_column_names(nthetas=self.yaml_config["n-thetas"]), + ), # t,tx,ty,rx,ry,rtheta,rspacing / avg1,avg2 / sds + ) + def write_to_record_matrix(self, thread_idx, record_idx, data): current_pos_heading_and_time = ( self.position_controller.get_position_bearing_and_time() @@ -437,10 +469,24 @@ class FakeDroneDataCollector(DataCollector): def __init__(self, *args, **kwargs): super(FakeDroneDataCollector, self).__init__( *args, - column_names=v3rx_column_names(nthetas=kwargs["yaml_config"]["n-thetas"]), **kwargs, ) + def setup_record_matrix(self): + # record matrix + self.record_matrix = np.memmap( + self.data_filename, + dtype="float32", + mode="w+", + shape=( + 2, # TODO should be nreceivers + self.yaml_config["n-records-per-receiver"], + len( + v3rx_column_names(nthetas=self.yaml_config["n-thetas"]), + ), # t,tx,ty,rx,ry,rtheta,rspacing / avg1,avg2 / sds + ), + ) + def write_to_record_matrix(self, thread_idx, record_idx, data): self.record_matrix[thread_idx, record_idx] = prepare_record_entry_v3( ds=data, @@ -497,10 +543,22 @@ class GrblDataCollector(DataCollector): def __init__(self, *args, **kwargs): super(GrblDataCollector, self).__init__( *args, - column_names=v2_column_names(nthetas=kwargs["yaml_config"]["n-thetas"]), **kwargs, ) + def setup_record_matrix(self): + # record matrix + self.record_matrix = np.memmap( + self.data_filename, + dtype="float32", + mode="w+", + shape=( + 2, # TODO should be nreceivers + self.yaml_config["n-records-per-receiver"], + v2_column_names(nthetas=self.yaml_config["n-thetas"]), + ), # t,tx,ty,rx,ry,rtheta,rspacing / avg1,avg2 / sds + ) + def write_to_record_matrix(self, thread_idx, record_idx, data): tx_pos = self.position_controller.controller.position["xy"][ self.yaml_config["emitter"]["motor_channel"] @@ -512,115 +570,3 @@ def write_to_record_matrix(self, thread_idx, record_idx, data): self.record_matrix[thread_idx, record_idx] = prepare_record_entry_v2( ds=data, rx_pos=rx_pos, tx_pos=tx_pos ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-c", - "--yaml-config", - type=str, - help="YAML config file", - required=True, - ) - parser.add_argument( - "-t", "--tag", type=str, help="tag files", required=False, default="" - ) - parser.add_argument( - "--tx-gain", type=int, help="tag files", required=False, default=None - ) - parser.add_argument( - "-l", - "--logging-level", - type=str, - help="Logging level", - default="INFO", - required=False, - ) - parser.add_argument( - "-m", - "--device-mapping", - type=str, - help="Device mapping file", - default=None, - required=True, - ) - args = parser.parse_args() - - run_started_at = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") - # read YAML - with open(args.yaml_config, "r") as stream: - yaml_config = yaml.safe_load(stream) - - # open device mapping and figure out URIs - with open(args.device_mapping, "r") as device_mapping: - port_to_uri = { - int(mapping[0]): f"usb:1.{mapping[1]}.5" - for mapping in [line.strip().split() for line in device_mapping] - } - - for receiver in yaml_config["receivers"] + [yaml_config["emitter"]]: - if "receiver-port" in receiver: - receiver["receiver-uri"] = port_to_uri[receiver["receiver-port"]] - if "emitter-port" in yaml_config["emitter"]: - yaml_config["emitter"]["emitter-uri"] = port_to_uri[ - yaml_config["emitter"]["emitter-port"] - ] - - if args.tx_gain is not None: - assert yaml_config["emitter"]["type"] == "sdr" - yaml_config["emitter"]["tx-gain"] = args.tx_gain - - output_files_prefix = f"rover_{run_started_at}_nRX{len(yaml_config['receivers'])}_{yaml_config['routine']}" - if args.tag != "": - output_files_prefix += f"_tag_{args.tag}" - - # setup filename - # tmpdir = tempfile.TemporaryDirectory() - # temp_dir_name = tmpdir.name - temp_dir_name = "./" - filename_log = f"{temp_dir_name}/{output_files_prefix}.log.tmp" - filename_yaml = f"{temp_dir_name}/{output_files_prefix}.yaml.tmp" - filename_npy = f"{temp_dir_name}/{output_files_prefix}.npy.tmp" - temp_filenames = [filename_log, filename_yaml, filename_npy] - final_filenames = [x.replace(".tmp", "") for x in temp_filenames] - - logger = logging.getLogger(__name__) - - # setup logging - handlers = [ - logging.StreamHandler(), - logging.FileHandler(filename_log), - ] - logging.basicConfig( - handlers=handlers, - format="%(asctime)s:%(levelname)s:%(message)s", - level=getattr(logging, args.logging_level.upper(), None), - ) - - # make a copy of the YAML - with open(filename_yaml, "w") as outfile: - yaml.dump(yaml_config, outfile, default_flow_style=False) - - logging.info(json.dumps(yaml_config, sort_keys=True, indent=4)) - - boundary = franklin_safe - - logging.info("Starting data collector...") - data_collector = FakeDroneDataCollector( - filename_npy=filename_npy, yaml_config=yaml_config, position_controller=None - ) - data_collector.radios_to_online() # blocking - - if len(yaml_config["receivers"]) == 0: - logging.info("EMITTER ONLINE!") - while True: - time.sleep(5) - else: - data_collector.start() - while data_collector.is_collecting(): - time.sleep(5) - - # we finished lets move files out to final positions - for idx in range(len(temp_filenames)): - os.rename(temp_filenames[idx], final_filenames[idx]) diff --git a/spf/dataset/v4_data.py b/spf/dataset/v4_data.py new file mode 100644 index 00000000..22b1aaf4 --- /dev/null +++ b/spf/dataset/v4_data.py @@ -0,0 +1,51 @@ +import zarr +from numcodecs import Blosc + +v4rx_f64_keys = [ + "system_timestamps", + "gps_timestamps", + "lat", + "long", + "heading", + "avg_phase_diff", + "rssi", + "gain", +] + + +def v4rx_keys(): + return v4rx_f64_keys + ["signal_matrix"] + + +def v4rx_new_dataset( + filename, timesteps, buffer_size, n_receivers, chunk_size=4096, compressor=None +): + z = zarr.open( + filename, + mode="w", + ) + if compressor is None: + compressor = Blosc( + cname="zstd", + clevel=1, + shuffle=Blosc.BITSHUFFLE, + ) + z.create_group("receivers") + for receiver_idx in range(n_receivers): + receiver_z = z["receivers"].create_group(f"r{receiver_idx}") + receiver_z.create_dataset( + "signal_matrix", + shape=(timesteps, 2, buffer_size), + chunks=(1, 1, 1024 * chunk_size), + dtype="complex128", + compressor=compressor, + ) + for key in v4rx_f64_keys: + receiver_z.create_dataset( + key, + shape=(timesteps,), + chunks=(1024 * chunk_size), + dtype="float64", + compressor=compressor, + ) + return z diff --git a/spf/drone_v3_datacollector.py b/spf/drone_v3_datacollector.py new file mode 100644 index 00000000..dba2bb47 --- /dev/null +++ b/spf/drone_v3_datacollector.py @@ -0,0 +1,122 @@ +import argparse +import json +import logging +import os +import time +from datetime import datetime + +import yaml + +from spf.data_collector import FakeDroneDataCollector +from spf.gps.boundaries import franklin_safe + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--yaml-config", + type=str, + help="YAML config file", + required=True, + ) + parser.add_argument( + "-t", "--tag", type=str, help="tag files", required=False, default="" + ) + parser.add_argument( + "--tx-gain", type=int, help="tag files", required=False, default=None + ) + parser.add_argument( + "-l", + "--logging-level", + type=str, + help="Logging level", + default="INFO", + required=False, + ) + parser.add_argument( + "-m", + "--device-mapping", + type=str, + help="Device mapping file", + default=None, + required=True, + ) + args = parser.parse_args() + + run_started_at = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + # read YAML + with open(args.yaml_config, "r") as stream: + yaml_config = yaml.safe_load(stream) + + # open device mapping and figure out URIs + with open(args.device_mapping, "r") as device_mapping: + port_to_uri = { + int(mapping[0]): f"usb:1.{mapping[1]}.5" + for mapping in [line.strip().split() for line in device_mapping] + } + + for receiver in yaml_config["receivers"] + [yaml_config["emitter"]]: + if "receiver-port" in receiver: + receiver["receiver-uri"] = port_to_uri[receiver["receiver-port"]] + if "emitter-port" in yaml_config["emitter"]: + yaml_config["emitter"]["emitter-uri"] = port_to_uri[ + yaml_config["emitter"]["emitter-port"] + ] + + if args.tx_gain is not None: + assert yaml_config["emitter"]["type"] == "sdr" + yaml_config["emitter"]["tx-gain"] = args.tx_gain + + output_files_prefix = f"rover_{run_started_at}_nRX{len(yaml_config['receivers'])}_{yaml_config['routine']}" + if args.tag != "": + output_files_prefix += f"_tag_{args.tag}" + + # setup filename + # tmpdir = tempfile.TemporaryDirectory() + # temp_dir_name = tmpdir.name + temp_dir_name = "./" + filename_log = f"{temp_dir_name}/{output_files_prefix}.log.tmp" + filename_yaml = f"{temp_dir_name}/{output_files_prefix}.yaml.tmp" + filename_data = f"{temp_dir_name}/{output_files_prefix}.npy.tmp" + temp_filenames = [filename_log, filename_yaml, filename_data] + final_filenames = [x.replace(".tmp", "") for x in temp_filenames] + + logger = logging.getLogger(__name__) + + # setup logging + handlers = [ + logging.StreamHandler(), + logging.FileHandler(filename_log), + ] + logging.basicConfig( + handlers=handlers, + format="%(asctime)s:%(levelname)s:%(message)s", + level=getattr(logging, args.logging_level.upper(), None), + ) + + # make a copy of the YAML + with open(filename_yaml, "w") as outfile: + yaml.dump(yaml_config, outfile, default_flow_style=False) + + logging.info(json.dumps(yaml_config, sort_keys=True, indent=4)) + + boundary = franklin_safe + + logging.info("Starting data collector...") + data_collector = FakeDroneDataCollector( + filename_npy=filename_data, yaml_config=yaml_config, position_controller=None + ) + data_collector.radios_to_online() # blocking + + if len(yaml_config["receivers"]) == 0: + logging.info("EMITTER ONLINE!") + while True: + time.sleep(5) + else: + data_collector.start() + while data_collector.is_collecting(): + time.sleep(5) + + # we finished lets move files out to final positions + for idx in range(len(temp_filenames)): + os.rename(temp_filenames[idx], final_filenames[idx]) diff --git a/spf/mavlink_radio_collection.py b/spf/mavlink_radio_collection.py index e9c4a616..bc5eb6c9 100644 --- a/spf/mavlink_radio_collection.py +++ b/spf/mavlink_radio_collection.py @@ -191,13 +191,13 @@ def filenames_from_time_in_seconds(time_in_seconds, temp_dir_name, yaml_config): if not args.fake_radio: data_collector = DroneDataCollector( - filename_npy=temp_filenames["npy"], + data_filename=temp_filenames["npy"], yaml_config=yaml_config, position_controller=drone, ) else: data_collector = FakeDroneDataCollector( - filename_npy=temp_filenames["npy"], + data_filename=temp_filenames["npy"], yaml_config=yaml_config, position_controller=None, ) diff --git a/spf/notebooks/zarr_testing.ipynb b/spf/notebooks/zarr_testing.ipynb new file mode 100644 index 00000000..a456d825 --- /dev/null +++ b/spf/notebooks/zarr_testing.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import zarr\n", + "import numpy as np\n", + "\n", + "from numcodecs import Blosc, Zstd\n", + "\n", + "\n", + "def random_signal_matrix(self, buffer_size):\n", + " return np.random.uniform(-1, 1, (buffer_size,)) + 1.0j * np.random.uniform(\n", + " -1, 1, (buffer_size,)\n", + " )\n", + "\n", + "\n", + "total_samples = 2**24\n", + "buffer_size = 2**18\n", + "assert total_samples % buffer_size == 0\n", + "timesteps = total_samples // buffer_size\n", + "chunk_size = 4096\n", + "data_points = 1024\n", + "filename = \"testzarr\"\n", + "\n", + "\n", + "f64_keys = [\n", + " \"system_timestamps\",\n", + " \"gps_timestamps\",\n", + " \"lat\",\n", + " \"long\",\n", + " \"heading\",\n", + " \"avg_phase_diff\",\n", + " \"rssi\",\n", + " \"gain\",\n", + "]\n", + "\n", + "\n", + "def v4rx_new_dataset(filename, n_receivers, compressor=None):\n", + " z = zarr.open(\n", + " filename,\n", + " mode=\"w\",\n", + " )\n", + " if compressor is None:\n", + " compressor = Blosc(\n", + " cname=\"zstd\",\n", + " clevel=1,\n", + " shuffle=Blosc.BITSHUFFLE,\n", + " )\n", + " z.create_group(\"receivers\")\n", + " for receiver_idx in range(n_receivers):\n", + " receiver_z = z[\"receivers\"].create_group(f\"r{receiver_idx}\")\n", + " receiver_z.create_dataset(\n", + " \"signal_matrix\",\n", + " shape=(timesteps, 2, buffer_size),\n", + " chunks=(1, 1, 1024 * chunk_size),\n", + " dtype=\"complex128\",\n", + " compressor=compressor,\n", + " )\n", + " for key in f64_keys:\n", + " receiver_z.create_dataset(\n", + " key,\n", + " shape=(timesteps,),\n", + " chunks=(1024 * chunk_size),\n", + " dtype=\"float64\",\n", + " compressor=compressor,\n", + " )\n", + " return z\n", + "\n", + "\n", + "z = v4rx_new_dataset(\"testdata\", 2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "z.tree()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "z.receivers.r0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "z.create_group(\"test\")\n", + "z.tree()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install ipytree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spf", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/spf/sdrpluto/README.md b/spf/sdrpluto/README.md index 42160b0b..c5a81444 100644 --- a/spf/sdrpluto/README.md +++ b/spf/sdrpluto/README.md @@ -32,6 +32,10 @@ This returns open objects to both receiver and emitter SDR. The emitter is not a ## Benchmarking [ Laptop / rpi4 ](https://docs.google.com/spreadsheets/d/1kEzWVTT2jg84SchoqwFrf9tJ0UjC-6dh6JA90qW3eL0/edit?usp=sharing) +``` +python spf/sdrpluto/benchmark.py --uri ip:192.168.1.17 --buffer-sizes '2**16' '2**18' '2**20' --rx-buffers 2 --write-to-file testdata --chunk-size 512 1024 4096 --compress blosc1 blosc4 none zstd1 zstd4 +``` + ## Command line client diff --git a/spf/sdrpluto/sdr_controller.py b/spf/sdrpluto/sdr_controller.py index 66b7c37e..16c759c1 100644 --- a/spf/sdrpluto/sdr_controller.py +++ b/spf/sdrpluto/sdr_controller.py @@ -139,6 +139,7 @@ def args_to_rx_config(args): intermediate=args.fi, uri=args.receiver_uri, rx_spacing=args.rx_spacing, + rx_buffers=args.kernel_buffers, # rx_theta_in_pis=0.25, ) @@ -680,7 +681,7 @@ def plot_recv_signal( axs[idx][1].set_title("Power recv (%d)" % idx) diff = pi_norm(np.angle(signal_matrix[0]) - np.angle(signal_matrix[1])) axs[0][3].clear() - axs[0][3].scatter(t, diff, s=1) + axs[0][3].scatter(t, diff, s=0.3, alpha=0.1) mean, _mean = circular_mean(diff) axs[0][3].axhline(y=mean, color="black", label="circular mean") axs[0][3].axhline(y=_mean, color="red", label="trimmed circular mean") @@ -772,6 +773,13 @@ def plot_recv_signal( required=False, default=int(2**9), ) # 12 + parser.add_argument( + "--kernel-buffers", + type=int, + help="kernel buffers", + required=False, + default=2, + ) # 12 parser.add_argument( "--rx-spacing", type=float, diff --git a/spf/utils.py b/spf/utils.py index 4b3d85eb..9caf9d04 100644 --- a/spf/utils.py +++ b/spf/utils.py @@ -1,3 +1,6 @@ +import numpy as np + + class dotdict(dict): __getattr__ = dict.get @@ -15,3 +18,7 @@ def is_pi(): return True except (RuntimeError, ImportError): return False + + +def random_signal_matrix(n): + return np.random.uniform(-1, 1, (n,)) + 1.0j * np.random.uniform(-1, 1, (n,)) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 00cdd1b0..d1cfec89 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -6,8 +6,9 @@ from spf.dataset.spf_dataset import SessionsDatasetSimulated from spf.dataset.spf_generate import generate_session_and_dump +from spf.dataset.v4_data import v4rx_f64_keys, v4rx_new_dataset from spf.rf import get_peaks_for_2rx -from spf.utils import dotdict +from spf.utils import dotdict, random_signal_matrix @pytest.fixture @@ -126,3 +127,22 @@ def test_live_data_generation(default_args): "/".join([args.output, "onesession.pkl"]), compression="lzma", ) + + +def testv4_data_create(): + with tempfile.TemporaryDirectory() as tmp: + timesteps = 11 + buffer_size = 2**13 + z = v4rx_new_dataset( + tmp + "/testdata", + timesteps=timesteps, + buffer_size=buffer_size, + n_receivers=2, + ) + for time_idx in range(timesteps): + for receiver_idx in range(2): + z.receivers[f"r{receiver_idx}"].signal_matrix[time_idx, :] = ( + random_signal_matrix(2 * buffer_size).reshape(2, buffer_size) + ) + for k in v4rx_f64_keys: + z.receivers[f"r{receiver_idx}"][k][time_idx] = np.random.rand() diff --git a/tests/test_in_simulator.py b/tests/test_in_simulator.py index 0e37d018..0605050b 100644 --- a/tests/test_in_simulator.py +++ b/tests/test_in_simulator.py @@ -13,14 +13,16 @@ root_dir = os.path.dirname(os.path.dirname(spf.__file__)) +simulator_speedup = 5 + @pytest.fixture(scope="session") def adrupilot_simulator(): client = docker.from_env() container = client.containers.run( "csmisko/ardupilotspf:latest", - "/ardupilot/Tools/autotest/sim_vehicle.py -l 37.76509485,-122.40940127,0,0 \ - -v rover -f rover-skid --out tcpin:0.0.0.0:14590 --out tcpin:0.0.0.0:14591 -S 5", + f"/ardupilot/Tools/autotest/sim_vehicle.py -l 37.76509485,-122.40940127,0,0 \ + -v rover -f rover-skid --out tcpin:0.0.0.0:14590 --out tcpin:0.0.0.0:14591 -S {simulator_speedup}", stdin_open=True, ports={ "14590/tcp": ("127.0.0.1", 14590), @@ -111,12 +113,12 @@ def test_time_since_boot(adrupilot_simulator): def test_reboot(adrupilot_simulator): - time1 = float(get_time_since_boot()[0]) + time1 = float(get_time_since_boot()[0]) / simulator_speedup start_time = time.time() time.sleep(1) end_time = time.time() time.sleep(1) - time2 = float(get_time_since_boot()[0]) + time2 = float(get_time_since_boot()[0]) / simulator_speedup assert (time2 - time1) > (end_time - start_time) assert (end_time - start_time) - (time2 - time1) < 20 @@ -128,7 +130,7 @@ def test_reboot(adrupilot_simulator): stderr=subprocess.STDOUT, env=get_env(), ) - time_since_boot = float(get_time_since_boot()[0]) + time_since_boot = float(get_time_since_boot()[0]) / simulator_speedup time.sleep(0.5) # takes some time to write to disk end_time = time.time() assert (end_time - start_time) > time_since_boot