Skip to content

Commit

Permalink
Add support for array data types in GraphDBs
Browse files (browse the repository at this point in the history)
  • Loading branch information
funkey committed Mar 7, 2024
1 parent 7381b3b commit 0273ee9
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 16 deletions.
1 change: 1 addition & 0 deletions funlib/persistence/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .sqlite_graph_database import SQLiteGraphDataBase # noqa
from .pgsql_graph_database import PgSQLGraphDatabase # noqa
from .types import Array # noqa
8 changes: 6 additions & 2 deletions funlib/persistence/graphs/graph_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from networkx import Graph
from funlib.geometry import Roi
from .types import Array

import logging
from abc import ABC, abstractmethod
Expand All @@ -9,6 +10,9 @@
logger = logging.getLogger(__name__)


# Type specification accepted for a node or edge attribute: a Python type,
# a string, or an ``Array`` (imported from ``.types``).
# NOTE(review): the string form is presumably a type name -- confirm.
AttributeType = type | str | Array


class GraphDataBase(ABC):
"""
Interface for graph databases that supports slicing to retrieve
Expand All @@ -33,15 +37,15 @@ def __getitem__(self, roi) -> Graph:

@property
@abstractmethod
def node_attrs(self) -> dict[str, type]:
def node_attrs(self) -> dict[str, AttributeType]:
"""
Return the node attributes supported by the database.
"""
pass

@property
@abstractmethod
def edge_attrs(self) -> dict[str, type]:
def edge_attrs(self) -> dict[str, AttributeType]:
"""
Return the edge attributes supported by the database.
"""
Expand Down
6 changes: 6 additions & 0 deletions funlib/persistence/graphs/pgsql_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .sql_graph_database import SQLGraphDataBase
from .types import Array
from funlib.geometry import Roi

import logging
import psycopg2
import json
from typing import Optional, Any, Iterable
from collections.abc import Iterable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,12 +179,16 @@ def __exec(self, query):
def __sql_value(self, value):
    """Render a Python value as a SQL literal for use in a query string.

    Args:
        value: ``None``, a string, an iterable of such values, or any
            other object whose ``str()`` form is a valid SQL literal
            (e.g. numbers).

    Returns:
        ``"NULL"`` for ``None``; a single-quoted literal for strings, with
        embedded single quotes doubled; ``array[...]`` with recursively
        rendered elements for other iterables; ``str(value)`` otherwise.
    """
    if value is None:
        return "NULL"

    # Strings must be handled before the generic iterable branch, since a
    # string is itself iterable.
    if isinstance(value, str):
        # Double embedded single quotes so the value cannot terminate the
        # literal early (broken statement / SQL injection).
        escaped = value.replace("'", "''")
        return f"'{escaped}'"

    if isinstance(value, Iterable):
        # NOTE(review): an empty iterable renders as "array[]", which
        # PostgreSQL may reject without an explicit cast -- confirm.
        elements = ",".join(self.__sql_value(v) for v in value)
        return f"array[{elements}]"

    return str(value)

def __sql_type(self, type):
if isinstance(type, Array):
return self.__sql_type(type.dtype) + f"[{type.size}]"
try:
return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[
type
Expand Down
23 changes: 12 additions & 11 deletions funlib/persistence/graphs/sql_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .graph_database import GraphDataBase
from .graph_database import GraphDataBase, AttributeType
from .types import Array, type_to_str

from funlib.geometry import Coordinate
from funlib.geometry import Roi
Expand Down Expand Up @@ -55,8 +56,8 @@ class SQLGraphDataBase(GraphDataBase):
The custom attributes to store on each edge.
"""

_node_attrs: Optional[dict[str, type]] = None
_edge_attrs: Optional[dict[str, type]] = None
_node_attrs: Optional[dict[str, AttributeType]] = None
_edge_attrs: Optional[dict[str, AttributeType]] = None

def __init__(
self,
Expand All @@ -67,8 +68,8 @@ def __init__(
nodes_table: str = "nodes",
edges_table: str = "edges",
endpoint_names: Optional[list[str]] = None,
node_attrs: Optional[dict[str, type]] = None,
edge_attrs: Optional[dict[str, type]] = None,
node_attrs: Optional[dict[str, AttributeType]] = None,
edge_attrs: Optional[dict[str, AttributeType]] = None,
):
self.position_attributes = position_attributes
self.ndim = len(self.position_attributes)
Expand Down Expand Up @@ -205,19 +206,19 @@ def write_graph(
)

@property
def node_attrs(self) -> dict[str, type]:
def node_attrs(self) -> dict[str, AttributeType]:
return self._node_attrs if self._node_attrs is not None else {}

@node_attrs.setter
def node_attrs(self, value: dict[str, type]) -> None:
def node_attrs(self, value: dict[str, AttributeType]) -> None:
self._node_attrs = value

@property
def edge_attrs(self) -> dict[str, type]:
def edge_attrs(self) -> dict[str, AttributeType]:
return self._edge_attrs if self._edge_attrs is not None else {}

@edge_attrs.setter
def edge_attrs(self, value: dict[str, type]) -> None:
def edge_attrs(self, value: dict[str, AttributeType]) -> None:
self._edge_attrs = value

def read_nodes(
Expand Down Expand Up @@ -523,8 +524,8 @@ def __create_metadata(self):
"directed": self.directed,
"total_roi_offset": self.total_roi.offset,
"total_roi_shape": self.total_roi.shape,
"node_attrs": {k: v.__name__ for k, v in self.node_attrs.items()},
"edge_attrs": {k: v.__name__ for k, v in self.edge_attrs.items()},
"node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()},
"edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()},
}

return metadata
Expand Down
6 changes: 3 additions & 3 deletions funlib/persistence/graphs/sqlite_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .sql_graph_database import SQLGraphDataBase
from .sql_graph_database import SQLGraphDataBase, AttributeType

from funlib.geometry import Roi

Expand All @@ -22,8 +22,8 @@ def __init__(
nodes_table: str = "nodes",
edges_table: str = "edges",
endpoint_names: Optional[list[str]] = None,
node_attrs: Optional[dict[str, type]] = None,
edge_attrs: Optional[dict[str, type]] = None,
node_attrs: Optional[dict[str, AttributeType]] = None,
edge_attrs: Optional[dict[str, AttributeType]] = None,
):
self.db_file = db_file
self.meta_collection = self.db_file.parent / f"{self.db_file.stem}-meta.json"
Expand Down
14 changes: 14 additions & 0 deletions funlib/persistence/graphs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass


@dataclass
class Array:
    """Attribute type describing an array of ``size`` elements of ``dtype``."""

    # Element type of the array: a Python type or a string.
    dtype: type | str
    # Number of elements in the array.
    size: int


def type_to_str(type):
    """Return the string representation of an attribute type specification.

    Args:
        type: A Python type (e.g. ``int``), a string, or an ``Array``.
            The parameter name shadows the builtin; it is kept so that
            existing keyword callers continue to work.

    Returns:
        The string itself if ``type`` is a string;
        ``"Array(<dtype>, <size>)"`` for an ``Array``, with the dtype
        converted recursively; otherwise ``type.__name__``.
    """
    # Strings are allowed as attribute types (and as ``Array.dtype``) but
    # have no ``__name__``; return them unchanged instead of raising.
    if isinstance(type, str):
        return type
    if isinstance(type, Array):
        return f"Array({type_to_str(type.dtype)}, {type.size})"
    return type.__name__

0 comments on commit 0273ee9

Please sign in to comment.