Skip to content

Commit

Permalink
debug and fix v2 collection
Browse files Browse the repository at this point in the history
misc

debug v2

ignore temp yaml

docs and checks

fix v2 collection

update v2 notebook
  • Loading branch information
misko committed Jan 11, 2024
1 parent 47736f9 commit 4eea90d
Show file tree
Hide file tree
Showing 12 changed files with 793 additions and 278 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ test_data.txt
**/*.pkl
**/*.log
**/token
**/session_output_*.png
**/session_output_*.png
**/*2024*.yaml
73 changes: 49 additions & 24 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import yaml
from compress_pickle import load
from deepdiff import DeepDiff
from torch.utils.data import Dataset

from spf.dataset.spf_generate import generate_session
Expand All @@ -35,6 +36,17 @@
)


# Adapted from Stack Overflow
def yaml_as_dict(my_file):
    """Load every YAML document in ``my_file`` and merge them into one dict.

    Documents are read in order with ``yaml.safe_load_all``. When the same
    top-level key appears in more than one document, the value from the later
    document wins. Empty documents are skipped.

    Args:
        my_file: Path of the YAML file to read.

    Returns:
        A dict holding the merged top-level keys of all documents.
    """
    my_dict = {}
    with open(my_file, "r") as fp:
        for doc in yaml.safe_load_all(fp):
            # An empty document (e.g. a stray "---") is loaded as None and
            # has no .items(); skip it instead of crashing.
            if doc is None:
                continue
            my_dict.update(doc.items())
    return my_dict


def pos_to_rel(p, width):
    """Rescale a position linearly so that 0 maps to -1 and ``width`` maps to 1.

    Args:
        p: Position to rescale.
        width: Extent that ``p`` is measured against.

    Returns:
        ``2 * (p / width - 0.5)``.
    """
    fraction = p / width
    return 2 * (fraction - 0.5)

Expand Down Expand Up @@ -152,14 +164,17 @@ def get_m(self, filename, bypass=False):
)
return self.m_cache[filename]

def get_all_files_with_extension(self, extension):
    """Iterate over the entries of ``root_dir`` that have the given extension.

    Args:
        extension: File extension without the leading dot, e.g. ``"npy"``.

    Returns:
        An iterator of ``"<root_dir>/<name>"`` strings, one for each entry of
        ``os.listdir(self.root_dir)`` whose name ends in ``.<extension>``.
    """
    root = self.root_dir
    suffix = f".{extension}"
    # Match on the entry name only, and only at its end. A substring test on
    # the joined path would also accept names such as "x.npy.bak", and would
    # accept every entry whenever root_dir itself contains the suffix.
    return (
        f"{root}/{name}" for name in os.listdir(root) if name.endswith(suffix)
    )

def get_all_valid_files(self):
    """Return the sorted paths of the data files that pass ``check_file``.

    Candidates come from ``get_all_files_with_extension("npy")``; each one is
    kept only if ``self.check_file(path)`` is truthy.

    Returns:
        A sorted list of the accepted paths.
    """
    # Only the helper call is kept: passing both the old inline filter and
    # the helper result would hand filter() three arguments.
    candidates = self.get_all_files_with_extension("npy")
    return sorted(filter(self.check_file, candidates))

Expand Down Expand Up @@ -189,6 +204,7 @@ def __init__(
root_dir (string): Directory with all the images.
"""
assert nsources == 1 # TODO implement more
self.check_files = True
self.root_dir = root_dir
self.nthetas = nthetas
self.thetas = np.linspace(-np.pi, np.pi, self.nthetas)
Expand All @@ -208,7 +224,7 @@ def __init__(
)
self.detector_position = np.array([[receiver_pos_x, receiver_pos_y]])
self.snapshots_in_file = snapshots_in_file
self.snapshots_in_sample = snapshots_in_session
self.snapshots_in_session = snapshots_in_session
self.step_size = step_size
self.filenames = self.get_all_valid_files()
if len(self.filenames) == 0:
Expand All @@ -223,26 +239,26 @@ def __init__(
# mode='r',
# shape=(self.snapshots_in_file,self.nthetas+5)) for filename in self.filenames
# ]
self.samples_per_file = [
self.sessions_per_file = [
self.get_m(filename, bypass=True).shape[0]
- (self.snapshots_in_sample * self.step_size)
- (self.snapshots_in_session * self.step_size)
for filename in self.filenames
]
self.cumsum_samples_per_file = np.cumsum([0] + self.samples_per_file)
self.len = sum(self.samples_per_file)
self.zeros = np.zeros((self.snapshots_in_sample, 5))
self.ones = np.ones((self.snapshots_in_sample, 5))
self.cumsum_sessions_per_file = np.cumsum([0] + self.sessions_per_file)
self.len = sum(self.sessions_per_file)
self.zeros = np.zeros((self.snapshots_in_session, 5))
self.ones = np.ones((self.snapshots_in_session, 5))
self.widths = (
np.ones((self.snapshots_in_sample, 1), dtype=np.int32) * self.args.width
np.ones((self.snapshots_in_session, 1), dtype=np.int32) * self.args.width
)
self.halfpis = -np.ones((self.snapshots_in_sample, 1)) * np.pi / 2
self.halfpis = -np.ones((self.snapshots_in_session, 1)) * np.pi / 2
# print("WARNING BY DEFAULT FLIPPING RADIO FEATURE SINCE COORDS WERE WRONG IN PI COLLECT!")

def __getitem__(self, idx):
fileidx, startidx = self.idx_to_fileidx_and_startidx(idx)
m = self.get_m(self.filenames[fileidx])[
startidx : startidx
+ self.snapshots_in_sample * self.step_size : self.step_size
+ self.snapshots_in_session * self.step_size : self.step_size
]

detector_position_at_t = np.broadcast_to(
Expand Down Expand Up @@ -283,10 +299,22 @@ class SessionsDatasetRealV2(SessionsDatasetReal):
def file_shape(self):
    """Return the expected array shape of a single data file.

    Returns:
        A tuple ``(2, snapshots_in_file, number of column names)``.
    """
    # NOTE(review): the leading axis of 2 presumably indexes the two
    # receivers selected by receiver_idx in __getitem__ -- confirm.
    n_columns = len(self.column_names)
    return (2, self.snapshots_in_file, n_columns)

def get_yaml_config(self):
    """Load the YAML config stored next to the data files in ``root_dir``.

    Every path yielded by ``get_all_files_with_extension("yaml")`` is parsed
    with ``yaml.safe_load``. The first parsed config becomes the reference and
    each later one is compared against it with ``DeepDiff`` (order-insensitive).

    Returns:
        The reference config, or ``None`` when no YAML file is found.

    Raises:
        ValueError: If a later YAML file differs from the reference config.
    """
    yaml_config = None
    for yaml_fn in self.get_all_files_with_extension("yaml"):
        # Use a context manager so each handle is closed promptly, even if
        # parsing raises.
        with open(yaml_fn, "r") as stream:
            candidate = yaml.safe_load(stream)
        if yaml_config is None:
            yaml_config = candidate
            continue
        ddiff = DeepDiff(yaml_config, candidate, ignore_order=True)
        if len(ddiff):
            raise ValueError("YAML configs do not match")
    return yaml_config

def __init__(
self,
root_dir,
yaml_config_fn,
yaml_config_fn=None,
snapshots_in_session=128,
nsources=1,
step_size=1,
Expand All @@ -301,21 +329,20 @@ def __init__(
root_dir (string): Directory with all the images.
"""
self.check_files = check_files
# read YAML
with open(yaml_config_fn, "r") as stream:
yaml_config = yaml.safe_load(stream)

assert nsources == 1 # TODO implement more
self.root_dir = root_dir
self.snapshots_in_file = snapshots_in_file
self.snapshots_in_session = snapshots_in_session
self.step_size = step_size
yaml_config = self.get_yaml_config()

assert nsources == 1 # TODO implement more

self.root_dir = root_dir
self.nthetas = yaml_config["n-thetas"]
self.thetas = np.linspace(-np.pi, np.pi, self.nthetas)

self.column_names = v2_column_names(nthetas=self.nthetas)

self.filenames = self.get_all_valid_files()
self.args = dotdict(
{
"width": yaml_config["width"],
Expand All @@ -337,8 +364,6 @@ def __init__(
)
self.rx_antenna_offsets = np.array(self.rx_antenna_offsets)

self.filenames = self.get_all_valid_files()

if len(self.filenames) == 0:
print("SessionsDatasetReal: No valid files to load from")
raise ValueError
Expand Down Expand Up @@ -376,7 +401,7 @@ def __getitem__(self, idx):
if unadjusted_startidx >= sessions_per_receiver: # use B
receiver_idx = 1
startidx = unadjusted_startidx - sessions_per_receiver

# receiver_idx = 1 - receiver_idx
m = self.get_m(self.filenames[fileidx])[
receiver_idx,
startidx : startidx
Expand Down
Loading

0 comments on commit 4eea90d

Please sign in to comment.