Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix STACObject inheritance #451

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pystac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,18 @@ def read_file(href: str) -> STACObject:
a :class:`~pystac.STACObject` and must be read using
:meth:`ItemCollection.from_file <pystac.ItemCollection.from_file>`
"""
return STACObject.from_file(href)
stac_io = StacIO.default()
duckontheweb marked this conversation as resolved.
Show resolved Hide resolved
d = stac_io.read_json(href)
typ = pystac.serialization.identify.identify_stac_object_type(d)

if typ == STACObjectType.CATALOG:
return Catalog.from_file(href)
elif typ == STACObjectType.COLLECTION:
return Collection.from_file(href)
elif typ == STACObjectType.ITEM:
return Item.from_file(href)
else:
raise STACTypeError(f"Cannot read file of type {typ}")


def write_file(
Expand Down
9 changes: 7 additions & 2 deletions pystac/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def from_dict(
if migrate:
result = pystac.read_dict(d, href=href, root=root)
if not isinstance(result, Catalog):
raise pystac.STACError(f"{result} is not a Catalog")
raise pystac.STACTypeError(f"{result} is not a Catalog")
return result

catalog_type = CatalogType.determine_type(d)
Expand All @@ -919,7 +919,7 @@ def from_dict(

d.pop("stac_version")

cat = Catalog(
cat = cls(
id=id,
description=description,
title=title,
Expand All @@ -946,7 +946,12 @@ def full_copy(

@classmethod
def from_file(cls, href: str, stac_io: Optional[pystac.StacIO] = None) -> "Catalog":
if stac_io is None:
stac_io = pystac.StacIO.default()

result = super().from_file(href, stac_io)
if not isinstance(result, Catalog):
raise pystac.STACTypeError(f"{result} is not a {Catalog}.")
result._stac_io = stac_io

return result
2 changes: 1 addition & 1 deletion pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def from_dict(

d.pop("stac_version")

collection = Collection(
collection = cls(
id=id,
description=description,
extent=extent,
Expand Down
46 changes: 0 additions & 46 deletions pystac/serialization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,8 @@
# flake8: noqa
from typing import Any, Dict, Optional, TYPE_CHECKING

import pystac
from pystac.serialization.identify import (
STACVersionRange,
identify_stac_object,
identify_stac_object_type,
)
from pystac.serialization.common_properties import merge_common_properties
from pystac.serialization.migrate import migrate_to_latest

if TYPE_CHECKING:
from pystac.stac_object import STACObject
from pystac.catalog import Catalog


def stac_object_from_dict(
d: Dict[str, Any], href: Optional[str] = None, root: Optional["Catalog"] = None
) -> "STACObject":
"""Determines how to deserialize a dictionary into a STAC object.

Args:
d : The dict to parse.
href : Optional href that is the file location of the object being
parsed.
root : Optional root of the catalog for this object.
If provided, the root's resolved object cache can be used to search for
previously resolved instances of the STAC object.

Note: This is used internally in StacIO instances to deserialize STAC Objects.
"""
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
collection_cache = None
if root is not None:
collection_cache = root._resolved_objects.as_collection_cache()

# Merge common properties in case this is an older STAC object.
merge_common_properties(d, json_href=href, collection_cache=collection_cache)

info = identify_stac_object(d)

d = migrate_to_latest(d, info)

if info.object_type == pystac.STACObjectType.CATALOG:
return pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)

if info.object_type == pystac.STACObjectType.COLLECTION:
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)

if info.object_type == pystac.STACObjectType.ITEM:
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)

raise pystac.STACTypeError(f"Unknown STAC object type {info.object_type}")
61 changes: 54 additions & 7 deletions pystac/stac_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

import pystac
from pystac.utils import safe_urlparse
import pystac.serialization
from pystac.serialization import (
merge_common_properties,
identify_stac_object_type,
identify_stac_object,
migrate_to_latest,
)

# Use orjson if available
try:
Expand Down Expand Up @@ -95,12 +100,31 @@ def stac_object_from_dict(
href: Optional[str] = None,
root: Optional["Catalog_Type"] = None,
) -> "STACObject_Type":
result = pystac.serialization.stac_object_from_dict(d, href, root)
if isinstance(result, pystac.Catalog):
# Set the stac_io instance for usage by io operations
# where this catalog is the root.
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
collection_cache = None
if root is not None:
collection_cache = root._resolved_objects.as_collection_cache()

# Merge common properties in case this is an older STAC object.
merge_common_properties(
d, json_href=href, collection_cache=collection_cache
)

info = identify_stac_object(d)
d = migrate_to_latest(d, info)

if info.object_type == pystac.STACObjectType.CATALOG:
result = pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)
result._stac_io = self
return result
return result

if info.object_type == pystac.STACObjectType.COLLECTION:
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)

if info.object_type == pystac.STACObjectType.ITEM:
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)

raise ValueError(f"Unknown STAC object type {info.object_type}")

def read_json(
self, source: Union[str, "Link_Type"], *args: Any, **kwargs: Any
Expand Down Expand Up @@ -302,7 +326,30 @@ def stac_object_from_dict(
root: Optional["Catalog_Type"] = None,
) -> "STACObject_Type":
STAC_IO.issue_deprecation_warning()
return pystac.serialization.stac_object_from_dict(d, href, root)
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
collection_cache = None
if root is not None:
collection_cache = root._resolved_objects.as_collection_cache()

# Merge common properties in case this is an older STAC object.
merge_common_properties(
d, json_href=href, collection_cache=collection_cache
)

info = identify_stac_object(d)

d = migrate_to_latest(d, info)

if info.object_type == pystac.STACObjectType.CATALOG:
return pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)

if info.object_type == pystac.STACObjectType.COLLECTION:
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)

if info.object_type == pystac.STACObjectType.ITEM:
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)

raise ValueError(f"Unknown STAC object type {info.object_type}")

# This is set in __init__.py
_STAC_OBJECT_CLASSES = None
Expand Down
7 changes: 6 additions & 1 deletion pystac/stac_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pystac import STACError
from pystac.link import Link
from pystac.utils import is_absolute_href, make_absolute_href
from pystac import serialization
from pystac.serialization.identify import identify_stac_object

if TYPE_CHECKING:
from pystac.catalog import Catalog as Catalog_Type
Expand Down Expand Up @@ -469,7 +471,10 @@ def from_file(
if not is_absolute_href(href):
href = make_absolute_href(href)

o = stac_io.read_stac_object(href)
d = stac_io.read_json(href)
info = identify_stac_object(d)
d = serialization.migrate.migrate_to_latest(d, info)
o = cls.from_dict(d, href=href)

# Set the self HREF, if it's not already set to something else.
if o.get_self_href() is None:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,3 +1107,27 @@ def test_full_copy_4(self) -> None:
].get_absolute_href()
assert href is not None
self.assertTrue(os.path.exists(href))


class CatalogSubClassTest(unittest.TestCase):
"""This tests cases related to creating classes inheriting from pystac.Catalog to
ensure that inheritance, class methods, etc. function as expected."""

TEST_CASE_1 = TestCases.get_path("data-files/catalogs/test-case-1/catalog.json")

class BasicCustomCatalog(pystac.Catalog):
pass

def setUp(self) -> None:
self.stac_io = pystac.StacIO.default()

def test_from_dict_returns_subclass(self) -> None:
catalog_dict = self.stac_io.read_json(self.TEST_CASE_1)
custom_catalog = self.BasicCustomCatalog.from_dict(catalog_dict)

self.assertIsInstance(custom_catalog, self.BasicCustomCatalog)

def test_from_file_returns_subclass(self) -> None:
custom_catalog = self.BasicCustomCatalog.from_file(self.TEST_CASE_1)

self.assertIsInstance(custom_catalog, self.BasicCustomCatalog)
24 changes: 24 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,27 @@ def test_from_items(self) -> None:

self.assertEqual(interval[0], datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))
self.assertEqual(interval[1], datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))


class CollectionSubClassTest(unittest.TestCase):
"""This tests cases related to creating classes inheriting from pystac.Catalog to
ensure that inheritance, class methods, etc. function as expected."""

MULTI_EXTENT = TestCases.get_path("data-files/collections/multi-extent.json")

class BasicCustomCollection(pystac.Collection):
pass

def setUp(self) -> None:
self.stac_io = pystac.StacIO.default()

def test_from_dict_returns_subclass(self) -> None:
collection_dict = self.stac_io.read_json(self.MULTI_EXTENT)
custom_collection = self.BasicCustomCollection.from_dict(collection_dict)

self.assertIsInstance(custom_collection, self.BasicCustomCollection)

def test_from_file_returns_subclass(self) -> None:
custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT)

self.assertIsInstance(custom_collection, self.BasicCustomCollection)
24 changes: 24 additions & 0 deletions tests/test_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,27 @@ def test_asset_updated(self) -> None:
new_a1_value = cm.get_updated(item.assets["analytic"])
self.assertEqual(new_a1_value, set_value)
self.assertEqual(cm.updated, item_value)


class ItemSubClassTest(unittest.TestCase):
"""This tests cases related to creating classes inheriting from pystac.Catalog to
ensure that inheritance, class methods, etc. function as expected."""

SAMPLE_ITEM = TestCases.get_path("data-files/item/sample-item.json")

class BasicCustomItem(pystac.Item):
pass

def setUp(self) -> None:
self.stac_io = pystac.StacIO.default()

def test_from_dict_returns_subclass(self) -> None:
item_dict = self.stac_io.read_json(self.SAMPLE_ITEM)
custom_item = self.BasicCustomItem.from_dict(item_dict)

self.assertIsInstance(custom_item, self.BasicCustomItem)

def test_from_file_returns_subclass(self) -> None:
custom_item = self.BasicCustomItem.from_file(self.SAMPLE_ITEM)

self.assertIsInstance(custom_item, self.BasicCustomItem)