Arrow io #272

Draft: wants to merge 4 commits into base: main
481 changes: 481 additions & 0 deletions arrowio.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/pgstac/migrations/pgstac.unreleased.sql
@@ -8,7 +8,7 @@ BEGIN
    IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname='btree_gist') THEN
        CREATE EXTENSION IF NOT EXISTS btree_gist;
    END IF;
-    IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname='btree_gist') THEN
    IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname='unaccent') THEN
        CREATE EXTENSION IF NOT EXISTS unaccent;
    END IF;
END;
54 changes: 54 additions & 0 deletions src/pgstac/sql/004_search.sql
@@ -817,6 +817,60 @@ END;
$$ LANGUAGE PLPGSQL SECURITY DEFINER;


CREATE OR REPLACE FUNCTION search_items(_search jsonb = '{}'::jsonb) RETURNS SETOF items AS $$
DECLARE
    searches searches%ROWTYPE;
    _where text;
    orderby text;
    token record;
    token_prev boolean;
    token_item items%ROWTYPE;
    token_where text;
    full_where text;
    timer timestamptz := clock_timestamp(); -- used by age_ms() in the debug output below
    _limit int := coalesce((_search->>'limit')::int, 10);
    _querylimit int;
    has_prev boolean := FALSE;
    has_next boolean := FALSE;
BEGIN
    searches := search_query(_search);
    _where := searches._where;
    orderby := searches.orderby;
    RAISE NOTICE 'SEARCH:TOKEN: %', _search->>'token';
    token := get_token_record(_search->>'token');
    RAISE NOTICE '***TOKEN: %', token;
    _querylimit := _limit + 1;
    IF token IS NOT NULL THEN
        token_prev := token.prev;
        token_item := token.item;
        token_where := get_token_filter(_search->'sortby', token_item, token_prev, FALSE);
        RAISE DEBUG 'TOKEN_WHERE: % (%ms from search start)', token_where, age_ms(timer);
        IF token_prev THEN -- if we are using a prev token, we know has_next is true
            RAISE DEBUG 'There is a previous token, so automatically setting has_next to true';
            has_next := TRUE;
            orderby := sort_sqlorderby(_search, TRUE);
        ELSE
            RAISE DEBUG 'There is a next token, so automatically setting has_prev to true';
            has_prev := TRUE;
        END IF;
    ELSE -- if there was no token, we know there is no prev
        RAISE DEBUG 'There is no token, so we know there is no prev. Setting has_prev to false';
        has_prev := FALSE;
    END IF;

    full_where := concat_ws(' AND ', _where, token_where);

    RETURN QUERY
        SELECT *
        FROM search_rows(
            full_where,
            orderby,
            NULL,
            _querylimit
        ) as i;
END;
$$ LANGUAGE PLPGSQL;
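
As a smoke test, the new set-returning function can be called directly; a minimal sketch using psycopg (the PGSTAC_DSN environment variable and the search body are placeholders, not part of this diff):

import os

import psycopg

# Hypothetical smoke test for search_items(); requires a running PgSTAC
# database with at least one loaded collection.
with psycopg.connect(os.environ["PGSTAC_DSN"]) as conn:
    rows = conn.execute(
        "SELECT id, collection FROM search_items(%s::jsonb);",
        ('{"limit": 5}',),
    ).fetchall()
    for row in rows:
        print(row)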

CREATE OR REPLACE FUNCTION search(_search jsonb = '{}'::jsonb) RETURNS jsonb AS $$
DECLARE
searches searches%ROWTYPE;
1 change: 1 addition & 0 deletions src/pypgstac/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"tenacity==8.1.*",
"cachetools==5.3.*",
"version-parser>= 1.0.1",
"psycopg-infdate>=1.0.3",
]

[project.optional-dependencies]
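
pgstac stores open-ended datetime ranges as PostgreSQL 'infinity' timestamps, which psycopg cannot map to Python datetime objects out of the box; the new psycopg-infdate dependency supplies loaders for them. A minimal sketch of the behavior that the register_inf_date_handler() call in db.py relies on (the DSN is a placeholder):

import os

import psycopg
import psycopg_infdate

# Register loaders so 'infinity'::timestamptz values load instead of raising
# an error; the exact sentinel value returned is up to psycopg-infdate.
psycopg_infdate.register_inf_date_handler(psycopg)

with psycopg.connect(os.environ["PGSTAC_DSN"]) as conn:
    value = conn.execute("SELECT 'infinity'::timestamptz;").fetchone()[0]
    print(value)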
127 changes: 124 additions & 3 deletions src/pypgstac/python/pypgstac/db.py
@@ -1,12 +1,25 @@
"""Base library for database interaction with PgSTAC."""

import atexit
import logging
import time
from types import TracebackType
-from typing import Any, Generator, List, Optional, Tuple, Type, Union
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import orjson
import psycopg
from cachetools import LRUCache, cachedmethod
from psycopg import Connection, sql
from psycopg.types.json import set_json_dumps, set_json_loads
from psycopg_pool import ConnectionPool
@@ -16,7 +29,15 @@
except ImportError:
    from pydantic import BaseSettings  # type:ignore

import psycopg_infdate
import pyarrow as pa
import shapely
from stac_geoparquet.to_arrow import _process_arrow_table as cleanarrow
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from version_parser import Version as V

from .hydration import hydrate
from .version import __version__ as pypgstac_version

logger = logging.getLogger(__name__)

@@ -28,6 +49,7 @@ def dumps(data: dict) -> str:

set_json_dumps(dumps)
set_json_loads(orjson.loads)
psycopg_infdate.register_inf_date_handler(psycopg)


def pg_notice_handler(notice: psycopg.errors.Diagnostic) -> None:
@@ -36,6 +58,15 @@ def pg_notice_handler(notice: psycopg.errors.Diagnostic) -> None:
    logger.info(msg)


def _chunks(
    lst: Sequence[Dict[str, Any]],
    n: int,
) -> Generator[Sequence[Dict[str, Any]], None, None]:
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


class Settings(BaseSettings):
    """Base Settings for Database Connection."""

@@ -63,6 +94,7 @@ def __init__(
        commit_on_exit: bool = True,
        debug: bool = False,
        use_queue: bool = False,
        item_funcs: Optional[List[Callable]] = None,
    ) -> None:
        """Initialize Database."""
        self.dsn: str
@@ -76,8 +108,29 @@
        self.initial_version = "0.1.9"
        self.debug = debug
        self.use_queue = use_queue
        self.item_funcs = item_funcs
        if self.debug:
            logging.basicConfig(level=logging.DEBUG)
        self.cache: LRUCache = LRUCache(maxsize=256)

    def check_version(self) -> None:
        db_version = self.version
        if db_version is None:
            raise Exception("Failed to detect the target database version.")

        if db_version != "unreleased":
            v1 = V(db_version)
            v2 = V(pypgstac_version)
            if (v1.get_major_version(), v1.get_minor_version()) != (
                v2.get_major_version(),
                v2.get_minor_version(),
            ):
                raise Exception(
                    f"pypgstac version {pypgstac_version}"
                    " is not compatible with the target"
                    f" database version {db_version}.",
                )
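
check_version pins compatibility to the major.minor pair and ignores the patch level, so a 0.8.x client can talk to any 0.8.y database. A small illustration of the comparison with made-up version numbers:

from version_parser import Version as V

v_db, v_py = V("0.8.1"), V("0.8.4")
# Compatible: major and minor match, the patch level may differ.
print((v_db.get_major_version(), v_db.get_minor_version()) ==
      (v_py.get_major_version(), v_py.get_minor_version()))  # True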

    def get_pool(self) -> ConnectionPool:
        """Get Database Pool."""
@@ -222,7 +275,7 @@ def query(
            conn.rollback()
            raise e

-    def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
    def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, dict, None]:
        """Return results from a query that returns a single row."""
        try:
            r = next(self.query(*args, **kwargs))
@@ -298,6 +351,74 @@ def func(self, function_name: str, *args: Any) -> Generator:
        base_query = sql.SQL("SELECT * FROM {}({});").format(func, placeholders)
        return self.query(base_query, cleaned_args)

    @cachedmethod(lambda self: self.cache)
    def collection_baseitem(self, collection_id: str) -> dict:
        """Get collection."""
        base_item = self.query_one(
            "SELECT base_item FROM collections WHERE id=%s",
            (collection_id,),
        )
        if not isinstance(base_item, dict):
            raise Exception(
                f"Collection {collection_id} is not present in the database",
            )
        logger.debug(f"Found {collection_id} with base_item {base_item}")
        return base_item

    def pgstac_row_reader(
        self,
        id,
        collection,
        geometry,
        datetime,
        end_datetime,
        content,
    ):
        """Read a pgstac item row, hydrate it, and convert it to a STAC item dict."""
        base_item = self.collection_baseitem(collection)
        content["id"] = id
        content["collection"] = collection
        content["geometry"] = geometry
        if datetime == end_datetime and "datetime" not in content["properties"]:
            content["properties"]["datetime"] = datetime
        elif datetime != end_datetime:
            if "start_datetime" not in content["properties"]:
                content["properties"]["start_datetime"] = datetime
            if "end_datetime" not in content["properties"]:
                content["properties"]["end_datetime"] = end_datetime
        if "bbox" not in content:
            geom = shapely.wkb.loads(geometry)
            content["bbox"] = list(geom.bounds)
        if "type" not in content:
            content["type"] = "Feature"
        content = hydrate(base_item, content)
        if self.item_funcs is not None:
            for func in self.item_funcs:
                content = func(content)
        return content
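
Each row is hydrated against the collection's cached base_item and then run through any caller-supplied item_funcs, which makes per-item post-processing pluggable. A sketch with a made-up hook that drops asset entries to shrink the output (the function name and behavior are illustrative, not part of this diff):

from pypgstac.db import PgstacDB

def drop_assets(item: dict) -> dict:
    # Illustrative item_func: strip assets before the Arrow conversion.
    item.pop("assets", None)
    return item

db = PgstacDB(item_funcs=[drop_assets])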

    def get_table(self, results):
        """Convert pgstac item row results to arrow table."""
        pylist = [self.pgstac_row_reader(*r) for r in results]
        table = pa.Table.from_pylist(pylist)
        table = cleanarrow(table)

        return table

    def search(self, query: Union[dict, str, psycopg.types.json.Jsonb] = "{}") -> str:
        """Search PgSTAC."""
-        return dumps(next(self.func("search", query))[0])
        results = self.query(
            """
            SELECT
                id,
                collection,
                st_asbinary(geometry),
                datetime::text,
                end_datetime::text,
                content
            FROM search_items(%s);
            """,
            (query,),
        )
        return self.get_table(results)
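
With search_items() on the database side and get_table() on the client side, search now yields a pyarrow.Table that can go straight to Parquet. A hedged end-to-end sketch (the collection id and output path are placeholders, and connection settings are assumed to come from the environment):

import pyarrow.parquet as pq

from pypgstac.db import PgstacDB

with PgstacDB() as db:
    # search() accepts a JSON string (or dict/Jsonb) search body.
    table = db.search('{"collections": ["my-collection"], "limit": 100}')
    pq.write_table(table, "items.parquet")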
1 change: 1 addition & 0 deletions src/pypgstac/python/pypgstac/dumper.py
@@ -0,0 +1 @@
"""Utilities to dump data from pgstac tables to items and collections."""
94 changes: 75 additions & 19 deletions src/pypgstac/python/pypgstac/load.py
@@ -38,7 +38,7 @@

from .db import PgstacDB
from .hydration import dehydrate
-from .version import __version__
from .reader import Reader

logger = logging.getLogger(__name__)

@@ -149,29 +149,85 @@ class Loader:

    db: PgstacDB
    _partition_cache: Dict[str, Partition]
    _reader: Optional[Reader]

    def __init__(self, db: PgstacDB):
        self.db = db
        self._partition_cache: Dict[str, Partition] = {}
        self._reader: Optional[Reader] = None  # set lazily by get_reader()

-    def check_version(self) -> None:
-        db_version = self.db.version
-        if db_version is None:
-            raise Exception("Failed to detect the target database version.")
-
-        if db_version != "unreleased":
-            v1 = V(db_version)
-            v2 = V(__version__)
-            if (v1.get_major_version(), v1.get_minor_version()) != (
-                v2.get_major_version(),
-                v2.get_minor_version(),
-            ):
-                raise Exception(
-                    f"pypgstac version {__version__}"
-                    " is not compatible with the target"
-                    f" database version {db_version}.",
-                )

    def get_reader(self, file=None):
        if file is None and self._reader is None:
            raise Exception("No file set")
        elif file is not None:
            self._reader = Reader(file)
        assert self._reader is not None
        return self._reader

    def check_partitions(self):
        reader = self.get_reader()
        stats = orjson.dumps(reader.stats.to_pylist()).decode()
        query = """
            WITH objs AS (
                SELECT
                    value
                FROM jsonb_array_elements(%s::jsonb)
            ), tocheck AS (
                SELECT
                    x.*
                FROM objs, jsonb_to_record(
                    value
                ) AS x(
                    collection text,
                    start_month timestamptz,
                    start_min timestamptz,
                    start_max timestamptz,
                    end_min timestamptz,
                    end_max timestamptz,
                    count_all int
                )
            ), agged AS (
                SELECT
                    partition,
                    tocheck.collection,
                    partition_dtrange,
                    constraint_dtrange,
                    constraint_edtrange,
                    tstzrange(
                        min(start_min),
                        max(start_max),
                        '[]') as in_dtrange,
                    tstzrange(
                        min(end_min),
                        max(end_max),
                        '[]') as in_edtrange,
                    sum(count_all) as count
                FROM tocheck LEFT JOIN partitions_view pv ON (
                    tocheck.collection=pv.collection AND
                    start_month <@ partition_dtrange
                )
                GROUP BY 1,2,3,4,5
            )
            SELECT
                partition,
                collection,
                partition_dtrange::text,
                constraint_dtrange::text,
                constraint_edtrange::text,
                in_dtrange::text,
                in_edtrange::text,
                in_dtrange <@ partition_dtrange as partition_good,
                in_dtrange <@ constraint_dtrange as constraint_dtrange_good,
                in_edtrange <@ constraint_edtrange as constraint_edtrange_good
            FROM agged
        """
        res = self.db.query(query, (stats,))
        for r in res:
            print(r)
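
check_partitions compares per-month statistics from the input file against the partition and constraint ranges already in the database before any rows are written. A sketch of the intended call pattern (the parquet path is a placeholder, and the Reader interface is inferred from this diff):

from pypgstac.db import PgstacDB
from pypgstac.load import Loader

with PgstacDB() as db:
    loader = Loader(db)
    loader.get_reader("items.parquet")  # caches the Reader for later calls
    loader.check_partitions()           # prints one summary row per partition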

    @lru_cache(maxsize=128)
    def collection_json(self, collection_id: str) -> Tuple[Dict[str, Any], int, str]:
@@ -197,7 +253,7 @@ def load_collections(
        insert_mode: Optional[Methods] = Methods.insert,
    ) -> None:
        """Load a collections json or ndjson file."""
-        self.check_version()
        self.db.check_version()

        if file is None:
            file = "stdin"