Merge pull request #4 from janelia-cellmap/main
Add features from cellmap fork
rhoadesScholar authored Mar 27, 2024
2 parents 4b4e685 + c2d205d commit 7f8ce63
Showing 8 changed files with 591 additions and 116 deletions.
422 changes: 370 additions & 52 deletions funlib/persistence/arrays/datasets.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions funlib/persistence/graphs/pgsql_graph_database.py
@@ -69,8 +69,8 @@ def __init__(
nodes_table=nodes_table,
edges_table=edges_table,
endpoint_names=endpoint_names,
node_attrs=node_attrs,
edge_attrs=edge_attrs,
node_attrs=node_attrs, # type: ignore
edge_attrs=edge_attrs, # type: ignore
)

def _drop_tables(self) -> None:
@@ -101,12 +101,12 @@ def _create_tables(self) -> None:
f"{self.nodes_table_name}({self.position_attribute})"
)

columns = list(self.edge_attrs.keys())
columns = list(self.edge_attrs.keys()) # type: ignore
types = list([self.__sql_type(t) for t in self.edge_attrs.values()])
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
self.__exec(
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
f"{self.endpoint_names[0]} BIGINT not null, "
f"{self.endpoint_names[0]} BIGINT not null, " # type: ignore
f"{self.endpoint_names[1]} BIGINT not null, "
f"{' '.join([c + ',' for c in column_types])}"
f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
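
For orientation, a minimal sketch of the statement this f-string assembles, assuming endpoint_names = ["u", "v"], an edges table named "edges", and a single edge attribute "score" whose SQL type (hypothetically) maps to DOUBLE PRECISION:

# Illustrative only -- hypothetical table, endpoint, and attribute names.
endpoint_names = ["u", "v"]
edges_table_name = "edges"
column_types = ["score DOUBLE PRECISION"]

statement = (
    f"CREATE TABLE IF NOT EXISTS {edges_table_name}("
    f"{endpoint_names[0]} BIGINT not null, "
    f"{endpoint_names[1]} BIGINT not null, "
    f"{' '.join([c + ',' for c in column_types])}"
    f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})"
    ")"  # closing parenthesis assumed from the elided tail of this hunk
)
# statement ==
# "CREATE TABLE IF NOT EXISTS edges(u BIGINT not null, v BIGINT not null,
#  score DOUBLE PRECISION,PRIMARY KEY (u, v))"
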
56 changes: 27 additions & 29 deletions funlib/persistence/graphs/sql_graph_database.py
@@ -81,7 +81,9 @@ def __init__(
node_attrs: Optional[dict[str, AttributeType]] = None,
edge_attrs: Optional[dict[str, AttributeType]] = None,
):
assert mode in self.valid_modes, f"Mode '{mode}' not in allowed modes {self.valid_modes}"
assert (
mode in self.valid_modes
), f"Mode '{mode}' not in allowed modes {self.valid_modes}"
self.mode = mode

if mode in self.read_modes:
@@ -135,8 +137,8 @@ def get(value, default):

self.directed = get(directed, False)
self.total_roi = get(
total_roi,
Roi((None,) * self.ndims, (None,) * self.ndims))
total_roi, Roi((None,) * self.ndims, (None,) * self.ndims)
)
self.nodes_table_name = get(nodes_table, "nodes")
self.edges_table_name = get(edges_table, "edges")
self.endpoint_names = get(endpoint_names, ["u", "v"])
@@ -229,7 +231,7 @@ def read_graph(
edges = self.read_edges(
roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter
)
u, v = self.endpoint_names
u, v = self.endpoint_names # type: ignore
try:
edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges]
except KeyError as e:
@@ -336,11 +338,7 @@ def read_nodes(

nodes = [
self._columns_to_node_attrs(
{
key: val
for key, val in zip(read_columns, values)
},
read_attrs
{key: val for key, val in zip(read_columns, values)}, read_attrs
)
for values in self._select_query(select_statement)
]
@@ -375,11 +373,11 @@ def read_edges(
return []

node_ids = ", ".join([str(node["id"]) for node in nodes])
node_condition = f"{self.endpoint_names[0]} IN ({node_ids})"
node_condition = f"{self.endpoint_names[0]} IN ({node_ids})" # type: ignore

logger.debug("Reading nodes in roi %s" % roi)
# TODO: AND vs OR here
desired_columns = ", ".join(self.endpoint_names + list(self.edge_attrs.keys()))
desired_columns = ", ".join(self.endpoint_names + list(self.edge_attrs.keys())) # type: ignore
select_statement = (
f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE "
+ node_condition
@@ -390,7 +388,7 @@
)
)

edge_attrs = self.endpoint_names + (
edge_attrs = self.endpoint_names + ( # type: ignore
list(self.edge_attrs.keys()) if read_attrs is None else read_attrs
)
attr_filter = attr_filter if attr_filter is not None else {}
@@ -401,7 +399,7 @@
{
key: val
for key, val in zip(
self.endpoint_names + list(self.edge_attrs.keys()), values
self.endpoint_names + list(self.edge_attrs.keys()), values # type: ignore
)
if key in edge_attrs
}
@@ -486,8 +484,8 @@ def update_edges(
if not roi.contains(pos_u):
logger.debug(
(
f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}},"
+ f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}"
f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore
+ f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore
).format(u, v, data, roi)
)
continue
@@ -497,7 +495,7 @@
update_statement = (
f"UPDATE {self.edges_table_name} SET "
f"{', '.join(setters)} WHERE "
f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}"
f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore
)

self._update_query(update_statement, commit=False)
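
As a quick illustration (hypothetical setters and endpoint names, not taken from the diff), the assembled statement for a single edge would look like:

# Illustrative only -- hypothetical values.
setters = ["score=0.5"]
u, v = 1, 2
update_statement = (
    "UPDATE edges SET "
    f"{', '.join(setters)} WHERE "
    f"u={u} AND v={v}"
)
# update_statement == "UPDATE edges SET score=0.5 WHERE u=1 AND v=2"
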
@@ -528,10 +526,7 @@ def write_nodes(
pos = self.__get_node_pos(data)
if roi is not None and not roi.contains(pos):
continue
values.append(
[node_id]
+ [data.get(attr, None) for attr in attrs]
)
values.append([node_id] + [data.get(attr, None) for attr in attrs])

if len(values) == 0:
logger.debug("No nodes to insert in %s", roi)
@@ -602,12 +597,13 @@ def __load_metadata(self, metadata):

# simple attributes
for attr_name in [
"position_attribute",
"directed",
"nodes_table_name",
"edges_table_name",
"endpoint_names",
"ndims"]:
"position_attribute",
"directed",
"nodes_table_name",
"edges_table_name",
"endpoint_names",
"ndims",
]:

if getattr(self, attr_name) is None:
setattr(self, attr_name, metadata[attr_name])
@@ -657,7 +653,7 @@ def __remove_keys(self, dictionary, keys):

def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]:
try:
return Coordinate(n[self.position_attribute])
return Coordinate(n[self.position_attribute]) # type: ignore
except KeyError:
return None

@@ -681,11 +677,13 @@ def __attr_query(self, attrs: dict[str, Any]) -> str:
def __roi_query(self, roi: Roi) -> str:
query = "WHERE "
pos_attr = self.position_attribute
for dim in range(self.ndims):
for dim in range(self.ndims): # type: ignore
if dim > 0:
query += " AND "
if roi.begin[dim] is not None and roi.end[dim] is not None:
query += f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}"
query += (
f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}"
)
elif roi.begin[dim] is not None:
query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}"
elif roi.begin[dim] is not None:
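
For reference, a minimal sketch (hypothetical position attribute and ROI bounds) of the WHERE clause this loop builds for a fully bounded 2D ROI, before any backend-specific rewriting:

# Illustrative only -- hypothetical values.
pos_attr = "position"
begin, end = (0, 100), (10, 200)

query = "WHERE "
for dim in range(2):
    if dim > 0:
        query += " AND "
    query += f"{pos_attr}[{dim + 1}] BETWEEN {begin[dim]} and {end[dim]}"
# query == "WHERE position[1] BETWEEN 0 and 10 AND position[2] BETWEEN 100 and 200"
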
32 changes: 11 additions & 21 deletions funlib/persistence/graphs/sqlite_graph_database.py
@@ -51,9 +51,7 @@ def __init__(
def node_array_columns(self):
if not self._node_array_columns:
self._node_array_columns = {
attr: [
f"{attr}_{d}" for d in range(attr_type.size)
]
attr: [f"{attr}_{d}" for d in range(attr_type.size)]
for attr, attr_type in self.node_attrs.items()
if isinstance(attr_type, Vec)
}
@@ -63,9 +61,7 @@ def node_array_columns(self):
def edge_array_columns(self):
if not self._edge_array_columns:
self._edge_array_columns = {
attr: [
f"{attr}_{d}" for d in range(attr_type.size)
]
attr: [f"{attr}_{d}" for d in range(attr_type.size)]
for attr, attr_type in self.edge_attrs.items()
if isinstance(attr_type, Vec)
}
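
A minimal sketch (hypothetical attribute and size, standing in for a Vec type) of the per-component column names these comprehensions generate:

# Illustrative only -- hypothetical sizes in place of Vec attribute types.
vec_sizes = {"position": 3}
array_columns = {
    attr: [f"{attr}_{d}" for d in range(size)]
    for attr, size in vec_sizes.items()
}
# array_columns == {"position": ["position_0", "position_1", "position_2"]}
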
@@ -100,16 +96,16 @@ def _create_tables(self) -> None:
f"{', '.join(node_columns)}"
")"
)
if self.ndims > 1:
if self.ndims > 1: # type: ignore
position_columns = self.node_array_columns[self.position_attribute]
else:
position_columns = self.position_attribute
self.cur.execute(
f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})"
)
edge_columns = [
f"{self.endpoint_names[0]} INTEGER not null",
f"{self.endpoint_names[1]} INTEGER not null",
f"{self.endpoint_names[0]} INTEGER not null", # type: ignore
f"{self.endpoint_names[1]} INTEGER not null", # type: ignore
]
for attr in self.edge_attrs.keys():
if attr in self.edge_array_columns:
@@ -119,7 +115,7 @@
self.cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
+ f"{', '.join(edge_columns)}"
+ f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
+ f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" # type: ignore
+ ")"
)

@@ -142,7 +138,7 @@ def _select_query(self, query):
#
# If SQL dialects allow array element access, they start counting at 1.
# We don't want that, we start counting at 0 like normal people.
query = re.sub(r'\[(\d+)\]', lambda m: "_" + str(int(m.group(1)) - 1), query)
query = re.sub(r"\[(\d+)\]", lambda m: "_" + str(int(m.group(1)) - 1), query)

try:
return self.cur.execute(query)
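
To illustrate the rewrite (hypothetical query string), the substitution maps the 1-based array indexing used in the generated SQL onto the 0-based flattened column names:

import re

# Illustrative only -- hypothetical query.
query = "SELECT id FROM nodes WHERE position[1] BETWEEN 0 and 10"
rewritten = re.sub(r"\[(\d+)\]", lambda m: "_" + str(int(m.group(1)) - 1), query)
# rewritten == "SELECT id FROM nodes WHERE position_0 BETWEEN 0 and 10"
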
@@ -201,9 +197,7 @@ def _node_attrs_to_columns(self, attrs):
for attr in attrs:
attr_type = self.node_attrs[attr]
if isinstance(attr_type, Vec):
columns += [
f"{attr}_{d}" for d in range(attr_type.size)
]
columns += [f"{attr}_{d}" for d in range(attr_type.size)]
else:
columns.append(attr)
return columns
@@ -213,8 +207,7 @@ def _columns_to_node_attrs(self, columns, query_attrs):
for attr in query_attrs:
if attr in self.node_array_columns:
value = tuple(
columns[f"{attr}_{d}"]
for d in range(self.node_attrs[attr].size)
columns[f"{attr}_{d}"] for d in range(self.node_attrs[attr].size)
)
else:
value = columns[attr]
@@ -226,9 +219,7 @@ def _edge_attrs_to_columns(self, attrs):
for attr in attrs:
attr_type = self.edge_attrs[attr]
if isinstance(attr_type, Vec):
columns += [
f"{attr}_{d}" for d in range(attr_type.size)
]
columns += [f"{attr}_{d}" for d in range(attr_type.size)]
else:
columns.append(attr)
return columns
@@ -238,8 +229,7 @@ def _columns_to_edge_attrs(self, columns, query_attrs):
for attr in query_attrs:
if attr in self.edge_array_columns:
value = tuple(
columns[f"{attr}_{d}"]
for d in range(self.edge_attrs[attr].size)
columns[f"{attr}_{d}"] for d in range(self.edge_attrs[attr].size)
)
else:
value = columns[attr]
3 changes: 3 additions & 0 deletions mypy.ini
@@ -11,4 +11,7 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-h5py.*]
ignore_missing_imports = True

[mypy-psycopg2.*]
ignore_missing_imports = True
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"pymongo",
"numpy",
"h5py",
"psycopg2",
"psycopg2-binary",
]

[tool.setuptools.dynamic]