Skip to content

Commit

Permalink
Rename types.{Array -> Vec} to avoid name clash with funlib.persisten…
Browse files Browse the repository at this point in the history
…ce.Array
  • Loading branch information
funkey committed Mar 8, 2024
1 parent 9a2a0e6 commit 28ccf81
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions funlib/persistence/graphs/graph_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from networkx import Graph
from funlib.geometry import Roi
from ..types import Array
from ..types import Vec

import logging
from abc import ABC, abstractmethod
Expand All @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


AttributeType = type | str | Array
AttributeType = type | str | Vec


class GraphDataBase(ABC):
Expand Down
4 changes: 2 additions & 2 deletions funlib/persistence/graphs/pgsql_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .sql_graph_database import SQLGraphDataBase
from ..types import Array
from ..types import Vec
from funlib.geometry import Roi

import logging
Expand Down Expand Up @@ -186,7 +186,7 @@ def __sql_value(self, value):
return str(value)

def __sql_type(self, type):
if isinstance(type, Array):
if isinstance(type, Vec):
return self.__sql_type(type.dtype) + f"[{type.size}]"
try:
return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[
Expand Down
8 changes: 4 additions & 4 deletions funlib/persistence/graphs/sql_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .graph_database import GraphDataBase, AttributeType
from ..types import Array, type_to_str
from ..types import Vec, type_to_str

from funlib.geometry import Coordinate
from funlib.geometry import Roi
Expand Down Expand Up @@ -120,12 +120,12 @@ def get(value, default):
)

position_type = node_attrs[self.position_attribute]
if isinstance(position_type, Array):
if isinstance(position_type, Vec):
self.ndims = position_type.size
assert self.ndims > 1, (
"Don't use Arrays of size 1 for the position, use the "
"Don't use Vecs of size 1 for the position, use the "
"scalar type directly instead (i.e., 'float' instead of "
"'Array(float, 1)'."
"'Vec(float, 1)'."
)
# if ndims == 1, we know that we have a single scalar now
else:
Expand Down
10 changes: 5 additions & 5 deletions funlib/persistence/graphs/sqlite_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .sql_graph_database import SQLGraphDataBase, AttributeType
from ..types import Array
from ..types import Vec

from funlib.geometry import Roi

Expand Down Expand Up @@ -50,14 +50,14 @@ def __init__(
f"{attr}_{d}" for d in range(attr_type.size)
]
for attr, attr_type in self.node_attrs.items()
if isinstance(attr_type, Array)
if isinstance(attr_type, Vec)
}
self.edge_array_columns = {
attr: [
f"{attr}_{d}" for d in range(attr_type.size)
]
for attr, attr_type in self.edge_attrs.items()
if isinstance(attr_type, Array)
if isinstance(attr_type, Vec)
}

def _drop_tables(self) -> None:
Expand Down Expand Up @@ -175,7 +175,7 @@ def _node_attrs_to_columns(self, attrs):
columns = []
for attr in attrs:
attr_type = self.node_attrs[attr]
if isinstance(attr_type, Array):
if isinstance(attr_type, Vec):
columns += [
f"{attr}_{d}" for d in range(attr_type.size)
]
Expand All @@ -200,7 +200,7 @@ def _edge_attrs_to_columns(self, attrs):
columns = []
for attr in attrs:
attr_type = self.edge_attrs[attr]
if isinstance(attr_type, Array):
if isinstance(attr_type, Vec):
columns += [
f"{attr}_{d}" for d in range(attr_type.size)
]
Expand Down
6 changes: 3 additions & 3 deletions funlib/persistence/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@


@dataclass
class Array:
class Vec:
dtype: type | str
size: int


def type_to_str(type):
if isinstance(type, Array):
return f"Array({type_to_str(type.dtype)}, {type.size})"
if isinstance(type, Vec):
return f"Vec({type_to_str(type.dtype)}, {type.size})"
else:
return type.__name__
28 changes: 14 additions & 14 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from funlib.geometry import Roi, Coordinate
from funlib.persistence.types import Array
from funlib.persistence.types import Vec

import networkx as nx
import pytest


def test_graph_filtering(provider_factory):
graph_writer = provider_factory(
"w", node_attrs={"position": Array(float, 3), "selected": bool}, edge_attrs={"selected": bool}
"w", node_attrs={"position": Vec(float, 3), "selected": bool}, edge_attrs={"selected": bool}
)
roi = Roi((0, 0, 0), (10, 10, 10))
graph = graph_writer[roi]
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_graph_filtering(provider_factory):
def test_graph_filtering_complex(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={"position": Array(float, 3), "selected": bool, "test": str},
node_attrs={"position": Vec(float, 3), "selected": bool, "test": str},
edge_attrs={"selected": bool, "a": int, "b": int},
)
roi = Roi((0, 0, 0), (10, 10, 10))
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_graph_filtering_complex(provider_factory):
def test_graph_read_and_update_specific_attrs(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={"position": Array(float, 3), "selected": bool, "test": str},
node_attrs={"position": Vec(float, 3), "selected": bool, "test": str},
edge_attrs={"selected": bool, "a": int, "b": int, "c": int},
)
roi = Roi((0, 0, 0), (10, 10, 10))
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory):
def test_graph_read_unbounded_roi(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={"position": Array(float, 3), "selected": bool, "test": str},
node_attrs={"position": Vec(float, 3), "selected": bool, "test": str},
edge_attrs={"selected": bool, "a": int, "b": int},
)
roi = Roi((0, 0, 0), (10, 10, 10))
Expand Down Expand Up @@ -197,14 +197,14 @@ def test_graph_read_unbounded_roi(provider_factory):

def test_graph_read_meta_values(provider_factory):
roi = Roi((0, 0, 0), (10, 10, 10))
provider_factory("w", True, roi, node_attrs={"position": Array(float, 3)})
provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)})
graph_provider = provider_factory("r", None, None)
assert True == graph_provider.directed
assert roi == graph_provider.total_roi


def test_graph_default_meta_values(provider_factory):
provider = provider_factory("w", False, None, node_attrs={"position": Array(float, 3)})
provider = provider_factory("w", False, None, node_attrs={"position": Vec(float, 3)})
assert False == provider.directed
assert provider.total_roi is None or provider.total_roi == Roi(
(None, None, None), (None, None, None)
Expand All @@ -220,7 +220,7 @@ def test_graph_io(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_graph_fail_if_exists(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand All @@ -289,7 +289,7 @@ def test_graph_fail_if_not_exists(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand All @@ -316,7 +316,7 @@ def test_graph_write_attributes(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(int, 3),
"position": Vec(int, 3),
"swip": str,
"zap": str,
}
Expand Down Expand Up @@ -372,7 +372,7 @@ def test_graph_write_roi(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand Down Expand Up @@ -411,7 +411,7 @@ def test_graph_connected_components(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand Down Expand Up @@ -449,7 +449,7 @@ def test_graph_has_edge(provider_factory):
graph_provider = provider_factory(
"w",
node_attrs={
"position": Array(float, 3),
"position": Vec(float, 3),
"swip": str,
"zap": str,
}
Expand Down

0 comments on commit 28ccf81

Please sign in to comment.