Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes definition of point of interest and limits the number of POIs #93

Merged
merged 2 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
66 changes: 31 additions & 35 deletions bird_cloud_gnn/radar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ class RadarDataset(DGLDataset):
data_folder (str): Folder with the CSV files.
features (array of str): List of features expected to be present at every CSV file.
target (str): Target column. 0, 1 or missing expected.
max_distance (float): Maximum distance to look for neighbours.
min_neighbours (int): If a point has less than this amount of neighbours, it is ignored.
num_neighbours (int): Number of nearest neighbours selected for each point of interest.
max_edge_distance (float): Creates an edge between two nodes if their distance is less than this value.
max_poi_per_label (int): Select at most this many POIs for each label. If a label has more POIs, a random subset is chosen.
"""

# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=too-many-arguments
def __init__(
self,
data,
features,
target,
name="Radar",
max_distance=500.0,
min_neighbours=100,
num_neighbours=100,
max_edge_distance=50.0,
max_poi_per_label=200,
):
"""Constructor

Expand All @@ -48,15 +48,12 @@ def __init__(
features (array of str): List of features expected to be present in every CSV file.
If "centered_x" and/or "centered_y" are included these are calculated on the fly.
target (str): Target column. 0, 1 or missing expected.
max_distance (float, optional): Maximum distance to look for neighbours. Defaults to
500.0.
min_neighbours (int, optional): If a point has less than this amount of neighbours, it
is ignored. Defaults to 100.
num_neighbours (int, optional): Number of selected neighbours. Defaults to 100.
max_edge_distance (float, optional): Creates an edge between two nodes if their distance
is less than this value. Defaults to 50.0.

Raises:
ValueError: If `data` is not a valid folder, file or pandas.DataFrame
ValueError: If `data` is not a valid folder, file or pandas.DataFrame
"""

self.data_path = None
Expand All @@ -75,9 +72,9 @@ def __init__(
self._name = name
self.features = features
self.target = target
self.max_distance = max_distance
self.min_neighbours = min_neighbours
self.num_neighbours = num_neighbours
self.max_edge_distance = max_edge_distance
self.max_poi_per_label = max_poi_per_label
self.graphs = []
self.labels = []
super().__init__(
Expand All @@ -87,9 +84,9 @@ def __init__(
data_hash,
features,
target,
max_distance,
min_neighbours,
num_neighbours,
max_edge_distance,
max_poi_per_label,
),
)

Expand Down Expand Up @@ -129,32 +126,31 @@ def _process_data(self, data):

data_features = data[temp_features]

na_index = data[data[self.target].isna()].index

data_xyz_notna = data_xyz.drop(na_index)
data_features_notna = data_features.drop(na_index)

data_target = data[self.target]
data_target_notna = data_target[data_xyz_notna.index]
tree = KDTree(data_xyz)

data_xyz_notna.reset_index(drop=True, inplace=True)
data_features_notna.reset_index(drop=True, inplace=True)
def sample_or_all(input_array, k):
if len(input_array) <= k:
return input_array

tree = KDTree(data_xyz)
tree_notna = KDTree(data_xyz_notna)
rng = np.random.default_rng()
return rng.choice(input_array, k, replace=False)

distance_matrix = tree_notna.sparse_distance_matrix(
tree, self.max_distance, output_type="coo_matrix"
points_of_interest = np.concatenate(
[
sample_or_all(
data[data[self.target] == label].index.to_numpy(),
self.max_poi_per_label,
)
for label in [0, 1] # Current possible labels
]
)

number_neighbours = distance_matrix.getnnz(1)
points_of_interest = np.where(number_neighbours >= self.min_neighbours)[0]

_, poi_indexes = tree.query(
data_xyz_notna.loc[points_of_interest], self.min_neighbours
data_xyz.loc[points_of_interest], self.num_neighbours
)
self.labels = np.concatenate(
(self.labels, data_target_notna.values[points_of_interest])
(self.labels, data_target.values[points_of_interest])
)
for _, indexes in enumerate(poi_indexes):
local_xyz = data_xyz.iloc[indexes]
Expand Down Expand Up @@ -218,8 +214,8 @@ def save(self):
"data_path": self.data_path,
"features": self.features,
"target": self.target,
"max_distance": self.max_distance,
"min_neighbours": self.min_neighbours,
"num_neighbours": self.num_neighbours,
"max_poi_per_label": self.max_poi_per_label,
},
)

Expand All @@ -239,8 +235,8 @@ def load(self):
self.data_path = info["data_path"]
self.features = info["features"]
self.target = info["target"]
self.max_distance = info["max_distance"]
self.min_neighbours = info["min_neighbours"]
self.num_neighbours = info["num_neighbours"]
self.max_poi_per_label = info["max_poi_per_label"]

def cache_dir(self):
if self.data_path is None:
Expand Down
3 changes: 1 addition & 2 deletions project_setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ checklist](https://guide.esciencecenter.nl/#/best_practices/checklist).

This repository is set up with Python versions:

- 3.6
- 3.7
- 3.8
- 3.9
- 3.10

Add or remove Python versions based on project requirements. See [the
guide](https://guide.esciencecenter.nl/#/best_practices/language_guides/python) for more information about Python
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ classifiers =
License :: OSI Approved :: Apache Software License
Natural Language :: English
Programming Language :: Python :: 3
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Expand All @@ -31,7 +30,7 @@ version = 0.1.0

[options]
zip_safe = False
python_requires = >=3.7
python_requires = >=3.8
include_package_data = True
packages = find:
install_requires =
Expand Down
7 changes: 3 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ def dataset_fixture(feat_fixture):
for i in range(0, 5):
generate_data(tmp_path / f"data{i:03}.csv", 2**6)

max_distance = 30_000
min_neighbours = 20
num_neighbours = 20
features = feat_fixture["features"]
target = feat_fixture["target"]
dataset = RadarDataset(
tmp_path,
features,
target,
max_distance=max_distance,
min_neighbours=min_neighbours,
num_neighbours=num_neighbours,
max_poi_per_label=100,
)
return dataset

Expand Down
Loading