Skip to content

Commit

Permalink
Merge pull request #290 from CAVEconnectome/skeleton_dev
Browse files Browse the repository at this point in the history
Skeleton dev
  • Loading branch information
kebwi authored Dec 18, 2024
2 parents 7c802f6 + 1348dc0 commit 52597bf
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 28 deletions.
4 changes: 4 additions & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,12 @@
+ "/{datastack_name}/precomputed/skeleton/{skvn}/info",
"get_cache_contents_via_skvn_ridprefixes": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/query_cache/{skeleton_version}/{root_id_prefixes}/{limit}",
# TODO: DEPRECATED: This GET endpoint is deprecated and will be removed in a future release.
# Use the POST endpoint ("skeletons_exist_via_skvn_rids_as_post") instead.
"skeletons_exist_via_skvn_rids": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/exists/{skeleton_version}/{root_ids}",
"skeletons_exist_via_skvn_rids_as_post": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/exists",
"get_skeleton_via_rid": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/{root_id}",
"get_skeleton_via_skvn_rid": skeleton_v1
Expand Down
104 changes: 76 additions & 28 deletions caveclient/skeletonservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

SERVER_KEY = "skeleton_server_address"

MAX_SKELETONS_EXISTS_QUERY_SIZE = 1000
MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
BULK_ASYNC_SKELETONS_BATCH_SIZE = 100
BULK_SKELETONS_BATCH_SIZE = 100


class NoL2CacheException(Exception):
Expand Down Expand Up @@ -219,7 +220,26 @@ def decompressBytesToDict(inputBytes):
inputBytesStrDict = json.loads(inputBytesStr)
return inputBytesStrDict

def _build_endpoint(
def _build_skeletons_exist_endpoint(
    self,
    root_ids: List,
    datastack_name: str,
    skeleton_version: int,
    post: bool = False,
):
    """Build the URL for a "do these skeletons exist" query.

    Parameters
    ----------
    root_ids : List
        Root ids to ask about. For the GET form they are stringified and
        comma-joined into the URL; for the POST form they are not added to
        the URL mapping at all.
    datastack_name : str
        Value substituted for ``{datastack_name}`` in the endpoint template.
    skeleton_version : int
        Skeleton version to ask about. Only placed in the URL mapping for
        the GET form.
    post : bool, optional
        If False (default), use the "skeletons_exist_via_skvn_rids"
        endpoint. If True, use "skeletons_exist_via_skvn_rids_as_post".

    Returns
    -------
    str
        The selected endpoint template formatted with the URL mapping.
    """
    # Work on a private copy so that per-call values (root ids, version,
    # datastack) can never leak into the mapping other URL builders read.
    endpoint_mapping = dict(self.default_url_mapping)
    endpoint_mapping["datastack_name"] = datastack_name

    if post:
        endpoint = "skeletons_exist_via_skvn_rids_as_post"
    else:
        endpoint_mapping["root_ids"] = ",".join(str(v) for v in root_ids)
        endpoint_mapping["skeleton_version"] = skeleton_version
        endpoint = "skeletons_exist_via_skvn_rids"

    return self._endpoints[endpoint].format_map(endpoint_mapping)

def _build_get_skeleton_endpoint(
self,
root_id: int,
datastack_name: str,
Expand Down Expand Up @@ -379,6 +399,11 @@ def skeletons_exist(
"""
Confirm or deny that a set of root ids have H5 skeletons in the cache.
"""
if self._server_version < Version("0.9.0"):
logging.warning(
"Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
)

if datastack_name is None:
datastack_name = self._datastack_name
assert datastack_name is not None
Expand All @@ -389,32 +414,54 @@ def skeletons_exist(
f"Unknown skeleton version: {skeleton_version}. Valid options: {valid_skeleton_versions}"
)

if isinstance(root_ids, int):
root_ids = str(root_ids)
if isinstance(root_ids, np.ndarray):
root_ids = root_ids.tolist()
if not isinstance(root_ids, List): # If not a list, it can only be a string at this point
root_ids = [root_ids]

if isinstance(root_ids, int):
root_ids = str(root_ids)
elif isinstance(root_ids, List):
root_ids = ",".join([str(v) for v in root_ids])
if len(root_ids) > MAX_SKELETONS_EXISTS_QUERY_SIZE:
logging.warning(
f"The number of root_ids exceeds the current limit of {MAX_SKELETONS_EXISTS_QUERY_SIZE}. Only the first {MAX_SKELETONS_EXISTS_QUERY_SIZE} will be processed."
)
root_ids = root_ids[:MAX_SKELETONS_EXISTS_QUERY_SIZE]

endpoint_mapping = self.default_url_mapping
endpoint_mapping["datastack_name"] = datastack_name
endpoint_mapping["root_ids"] = root_ids
results = {}
for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]

endpoint_mapping["skeleton_version"] = skeleton_version
url = self._endpoints["skeletons_exist_via_skvn_rids"].format_map(
endpoint_mapping
)
if self._server_version < Version("0.9.0"):
url = self._build_skeletons_exist_endpoint(
rids_one_batch, datastack_name, skeleton_version
)
response = self.session.get(url)
self.raise_for_status(response, log_warning=log_warning)
else:
url = self._build_skeletons_exist_endpoint(
rids_one_batch, datastack_name, skeleton_version, True
)
data = {
"root_ids": rids_one_batch,
"skeleton_version": skeleton_version,
}
response = self.session.post(url, json=data)
response = handle_response(response, as_json=False)

response = self.session.get(url)
self.raise_for_status(response, log_warning=log_warning)
result_json = response.json()
if isinstance(result_json, dict):
# Convert string keys to ints
results.update({int(key): value for key, value in result_json.items()})
elif isinstance(result_json, bool):
assert len(rids_one_batch) == 1
results[int(rids_one_batch[0])] = result_json
else:
raise ValueError(f"Unexpected response type: {type(result_json)}")

result_json = response.json()
if isinstance(result_json, bool):
if len(results) == 1:
# When investigating a single root id, this returns a single bool, not a dict, list, etc.
return result_json
result_json_w_ints = {int(key): value for key, value in result_json.items()}
return result_json_w_ints
return list(results.values())[0]
return results

@cached(TTLCache(maxsize=32, ttl=3600))
def get_precomputed_skeleton_info(
Expand Down Expand Up @@ -511,7 +558,7 @@ def get_skeleton(
skeleton_versions = self.get_versions()
skeleton_version = sorted(skeleton_versions)[-1]

url = self._build_endpoint(
url = self._build_get_skeleton_endpoint(
root_id, datastack_name, skeleton_version, endpoint_format
)

Expand Down Expand Up @@ -685,6 +732,11 @@ def generate_bulk_skeletons_async(
if not self.fc.l2cache.has_cache():
raise NoL2CacheException("SkeletonClient requires an L2Cache.")

if self._server_version < Version("0.8.0"):
logging.warning(
"Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
)

if skeleton_version is None:
logging.warning(
"The optional nature of the 'skeleton_version' parameter will be deprecated in the future. Please specify a skeleton version."
Expand All @@ -697,11 +749,6 @@ def generate_bulk_skeletons_async(
raise ValueError(
f"root_ids must be a list or numpy array of root_ids, not a {type(root_ids)}"
)

if self._server_version < Version("0.8.0"):
logging.warning(
"Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
)

if len(root_ids) > MAX_BULK_ASYNCHRONOUS_SKELETONS:
logging.warning(
Expand All @@ -714,14 +761,15 @@ def generate_bulk_skeletons_async(
# So consider reverting to the unbatched approach in the future.

estimated_async_time_secs_upper_bound_sum = 0
for batch in range(0, len(root_ids), BULK_ASYNC_SKELETONS_BATCH_SIZE):
rids_one_batch = root_ids[batch : batch + BULK_ASYNC_SKELETONS_BATCH_SIZE]
for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]

if self._server_version < Version("0.8.0"):
url = self._build_bulk_async_endpoint(
rids_one_batch, datastack_name, skeleton_version
)
response = self.session.get(url)
self.raise_for_status(response, log_warning=log_warning)
else:
url = self._build_bulk_async_endpoint(
rids_one_batch, datastack_name, skeleton_version, post=True
Expand Down

0 comments on commit 52597bf

Please sign in to comment.