Skip to content

Commit

Permalink
Quick fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mannbach committed Aug 20, 2024
1 parent 31fcfd8 commit 4c7e8eb
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 12 deletions.
2 changes: 0 additions & 2 deletions netin/graphs/binary_class_node_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def from_nd_array(cls,
def from_fraction(
cls,
N: int, f_m: float,
node_labels: Optional[List[str]] = None,
class_labels: Optional[List[str]] = None,
rng: Optional[np.random.Generator] = None)\
-> 'BinaryClassNodeVector':
Expand All @@ -47,7 +46,6 @@ def from_fraction(
rng = np.random.default_rng() if rng is None else rng
return cls.from_ndarray(
values=np.where(rng.random(N) < f_m, MINORITY_VALUE, MAJORITY_VALUE),
node_labels=node_labels,
class_labels=class_labels)

def get_minority_mask(self) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion netin/graphs/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_node_class(self, name: str) -> CategoricalNodeVector:
def get_node_classes(self) -> Dict[int, CategoricalNodeVector]:
return self._node_classes

def has_node_classes(self, name: str) -> bool:
def has_node_class(self, name: str) -> bool:
return name in self._node_classes

def add_edge(self, source: int, target: int) -> None:
Expand Down
6 changes: 3 additions & 3 deletions netin/models/binary_class_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def _initialize_node_classes(self):
N=self.get_final_number_of_nodes(),
f_m=self.f_m,
rng=self._rng)
if self.graph.has_node_attribute(CLASS_ATTRIBUTE):
node_class_values_pre = self.graph.get_node_attribute(CLASS_ATTRIBUTE)
if self.graph.has_node_class(CLASS_ATTRIBUTE):
node_class_values_pre = self.graph.get_node_class(CLASS_ATTRIBUTE)
assert isinstance(node_class_values_pre, BinaryClassNodeVector),\
"The node class values must be binary"
node_class_values[:len(node_class_values_pre)] =\
node_class_values_pre.values
self.graph.set_node_attribute(
self.graph.set_node_class(
CLASS_ATTRIBUTE, node_class_values)
4 changes: 3 additions & 1 deletion netin/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
super().__init__(*args, **kwargs)

self.N = N
self.graph = None

self._set_seed(seed)

Expand All @@ -58,7 +59,8 @@ def _set_seed(self, seed: Union[int, np.random.Generator]):
self._rng = np.random.default_rng(seed=seed)
elif isinstance(seed, np.random.Generator):
self._rng = seed
raise ValueError("seed must be an int or np.random.Generator")
else:
raise ValueError(f"`seed` must be an `int` or `np.random.Generator` but is {type(seed)}")

def _initialize_simulation(self):
self.log(f"Initializing simulation of {self.__class__.__name__}")
Expand Down
2 changes: 1 addition & 1 deletion netin/models/pah_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ def _initialize_lfms(self):
"""
super()._initialize_lfms()
self.pa = PreferentialAttachment(
N=self.compute_final_number_of_nodes(),
N=self.get_final_number_of_nodes(),
graph=self.graph)
4 changes: 2 additions & 2 deletions netin/models/undirected_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def __init__(
def get_initial_graph(self) -> Graph:
graph = Graph()
for i in range(self.m):
self.graph.add_node(i)
graph.add_node(i)
for j in range(i):
self.graph.add_edge(i, j)
graph.add_edge(i, j)
return graph

def _simulate(self) -> Graph:
Expand Down
4 changes: 2 additions & 2 deletions netin/utils/event_handling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Any, List, Callable

from collections import defaultdict
from enum import Enum

class Event(Enum):
Expand All @@ -12,7 +12,7 @@ class Event(Enum):

class HasEvents:
EVENTS: List[Event] = []
_event_handlers: Dict[Event, Callable[[Any], None]] = {}
_event_handlers: Dict[Event, Callable[[Any], None]] = defaultdict(list)

def trigger_event(self, *args, event: Event, **kwargs):
assert event in self.EVENTS,\
Expand Down

0 comments on commit 4c7e8eb

Please sign in to comment.