Skip to content

Commit

Permalink
BUG: Fix pickling for CRS builder classes
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Aug 13, 2021
1 parent cefc33d commit 316b933
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Latest
- BUG: Make datum name match exact in :func:`pyproj.database.query_utm_crs_info` (pull #887)
- BUG: Update :class:`pyproj.enums.GeodIntermediateFlag` for future Python compatibility (issue #855)
- BUG: Hide unnecessary PROJ ERROR from proj_crs_get_coordoperation (issue #873)
- BUG: Fix pickling for CRS builder classes (issue #897)

3.1.0
-----
Expand Down
11 changes: 7 additions & 4 deletions pyproj/crs/crs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Union

from pyproj._crs import ( # noqa
_CRS,
Expand Down Expand Up @@ -1411,9 +1411,12 @@ def is_geocentric(self) -> bool:
def __eq__(self, other: Any) -> bool:
return self.equals(other)

def __reduce__(self) -> Tuple[Type["CRS"], Tuple[str]]:
"""special method that allows CRS instance to be pickled"""
return self.__class__, (self.srs,)
def __getstate__(self) -> Dict[str, str]:
return {"srs": self.srs}

def __setstate__(self, state: Dict[str, Any]):
self.__dict__.update(state)
self._local = CRSLocal()

def __hash__(self) -> int:
return hash(self.to_wkt())
Expand Down
12 changes: 12 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
from contextlib import contextmanager
from distutils.version import LooseVersion
from pathlib import Path
Expand Down Expand Up @@ -87,3 +88,14 @@ def get_wgs84_datum_name():
if PROJ_GTE_8:
return "World Geodetic System 1984 ensemble"
return "World Geodetic System 1984"


def assert_can_pickle(raw_obj, tmp_path):
file_path = tmp_path / "temporary.pickle"
with open(file_path, "wb") as f:
pickle.dump(raw_obj, f)

with open(file_path, "rb") as f:
unpickled = pickle.load(f)

assert raw_obj == unpickled
5 changes: 5 additions & 0 deletions test/crs/test_crs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from test.conftest import (
HAYFORD_ELLIPSOID_NAME,
PROJ_GTE_8,
assert_can_pickle,
get_wgs84_datum_name,
grids_available,
)
Expand Down Expand Up @@ -1475,3 +1476,7 @@ def test_to_3d(crs_input):
def test_to_3d__name():
crs_3d = CRS("EPSG:2056").to_3d(name="TEST")
assert crs_3d.name == "TEST"


def test_crs__pickle(tmp_path):
assert_can_pickle(CRS("epsg:4326"), tmp_path)
20 changes: 13 additions & 7 deletions test/crs/test_crs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@
from pyproj.crs.coordinate_system import Cartesian2DCS, Ellipsoidal3DCS, VerticalCS
from pyproj.crs.datum import CustomDatum
from pyproj.crs.enums import VerticalCSAxis
from test.conftest import HAYFORD_ELLIPSOID_NAME
from test.conftest import HAYFORD_ELLIPSOID_NAME, assert_can_pickle


def test_make_projected_crs():
def test_make_projected_crs(tmp_path):
aeaop = AlbersEqualAreaConversion(0, 0)
pc = ProjectedCRS(conversion=aeaop, name="Albers")
assert pc.name == "Albers"
assert pc.type_name == "Projected CRS"
assert pc.coordinate_operation == aeaop
assert_can_pickle(pc, tmp_path)


def test_make_geographic_crs():
def test_make_geographic_crs(tmp_path):
gc = GeographicCRS(name="WGS 84")
assert gc.name == "WGS 84"
assert gc.type_name == "Geographic 2D CRS"
assert gc.to_authority() == ("OGC", "CRS84")
assert_can_pickle(gc, tmp_path)


def test_make_geographic_3d_crs():
Expand All @@ -43,15 +45,16 @@ def test_make_geographic_3d_crs():
assert gcrs.to_authority() == ("IGNF", "WGS84GEODD")


def test_make_derived_geographic_crs():
def test_make_derived_geographic_crs(tmp_path):
conversion = RotatedLatitudeLongitudeConversion(o_lat_p=0, o_lon_p=0)
dgc = DerivedGeographicCRS(base_crs=GeographicCRS(), conversion=conversion)
assert dgc.name == "undefined"
assert dgc.type_name == "Geographic 2D CRS"
assert dgc.coordinate_operation == conversion
assert_can_pickle(dgc, tmp_path)


def test_vertical_crs():
def test_vertical_crs(tmp_path):
vc = VerticalCRS(
name="NAVD88 height",
datum="North American Vertical Datum 1988",
Expand All @@ -61,6 +64,7 @@ def test_vertical_crs():
assert vc.type_name == "Vertical CRS"
assert vc.coordinate_system == VerticalCS()
assert vc.to_json_dict()["geoid_model"]["name"] == "GEOID12B"
assert_can_pickle(vc, tmp_path)


@pytest.mark.parametrize(
Expand All @@ -84,7 +88,7 @@ def test_vertical_crs__chance_cs_axis(axis):
assert vc.coordinate_system == VerticalCS(axis=axis)


def test_compund_crs():
def test_compund_crs(tmp_path):
vertcrs = VerticalCRS(
name="NAVD88 height",
datum="North American Vertical Datum 1988",
Expand All @@ -111,9 +115,10 @@ def test_compund_crs():
assert compcrs.type_name == "Compound CRS"
assert compcrs.sub_crs_list[0].type_name == "Projected CRS"
assert compcrs.sub_crs_list[1].type_name == "Vertical CRS"
assert_can_pickle(compcrs, tmp_path)


def test_bound_crs():
def test_bound_crs(tmp_path):
proj_crs = ProjectedCRS(conversion=UTMConversion(12))
bound_crs = BoundCRS(
source_crs=proj_crs,
Expand All @@ -126,6 +131,7 @@ def test_bound_crs():
assert bound_crs.source_crs.coordinate_operation.name == "UTM zone 12N"
assert bound_crs.coordinate_operation.towgs84 == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
assert bound_crs.target_crs.name == "WGS 84"
assert_can_pickle(bound_crs, tmp_path)


def test_bound_crs__example():
Expand Down

0 comments on commit 316b933

Please sign in to comment.