Skip to content

Commit

Permalink
Merge pull request #315 from creare-com/feature/readonly_tag
Browse files Browse the repository at this point in the history
general readonly tag
  • Loading branch information
jmilloy authored Sep 27, 2019
2 parents f0beb50 + 23c3425 commit 2a55a65
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 140 deletions.
7 changes: 3 additions & 4 deletions podpac/core/data/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ class DataSource(Node):
Custom DataSource Nodes must implement the :meth:`get_data` and :meth:`get_native_coordinates` methods.
"""

source = tl.Any()
native_coordinates = tl.Instance(Coordinates)
source = tl.Any().tag(readonly=True)
native_coordinates = tl.Instance(Coordinates).tag(readonly=True)
interpolation = interpolation_trait()
coordinate_index_type = tl.Enum(["list", "numpy", "xarray", "pandas"], default_value="numpy")
nan_vals = tl.List(allow_none=True)
Expand All @@ -170,8 +170,7 @@ class DataSource(Node):
# when native_coordinates is not defined, default calls get_native_coordinates
@tl.default("native_coordinates")
def _default_native_coordinates(self):
self.native_coordinates = self.get_native_coordinates()
return self.native_coordinates
return self.get_native_coordinates()

# this adds a more helpful error message if user happens to try an inspect _interpolation before evaluate
@tl.default("_interpolation")
Expand Down
3 changes: 2 additions & 1 deletion podpac/core/data/test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def test_init_interpolators(self):
Interpolation({"default": {"method": "nearest", "params": {"spatial_tolerance": "tol"}}})

# should not allow undefined params
interp = Interpolation({"default": {"method": "nearest", "params": {"myarg": 1}}})
with pytest.warns(DeprecationWarning): # eventually, Traitlets will raise an exception here
interp = Interpolation({"default": {"method": "nearest", "params": {"myarg": 1}}})
with pytest.raises(AttributeError):
assert interp.config[("default",)]["interpolators"][0].myarg == "tol"

Expand Down
41 changes: 1 addition & 40 deletions podpac/core/data/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,29 +172,10 @@ def test_auth_session(self):
assert node.auth_session is None

def test_dataset(self):
"""test dataset attribute and traitlet default """
"""test dataset trait """
self.mock_pydap()

node = PyDAP(source=self.source, datakey=self.datakey)

# override/reset source on dataset opening
node._open_dataset(source="newsource")
assert node.source == "newsource"
assert isinstance(node.dataset, pydap.model.DatasetType)

def test_source(self):
"""test source attribute and trailet observer """
self.mock_pydap()

node = PyDAP(source=self.source, datakey=self.datakey, native_coordinates=self.coordinates)

# observe source
node._update_dataset(change={"old": None})
assert node.source == self.source

output = node._update_dataset(change={"new": "newsource", "old": "oldsource"})
assert node.source == "newsource"
assert node.native_coordinates == self.coordinates
assert isinstance(node.dataset, pydap.model.DatasetType)

def test_get_data(self):
Expand Down Expand Up @@ -298,13 +279,6 @@ def test_dataset(self):
RasterReader = rasterio.io.DatasetReader # Rasterio >= v1.0
assert isinstance(node.dataset, RasterReader)

# update source when asked
with pytest.raises(rasterio.errors.RasterioIOError):
node.source = "assets/not-tiff"
node._open_dataset()

assert node.source == "assets/not-tiff"

node.close_dataset()

def test_default_native_coordinates(self):
Expand Down Expand Up @@ -349,19 +323,6 @@ def test_get_band_numbers(self):
assert isinstance(numbers, np.ndarray)
np.testing.assert_array_equal(numbers, np.arange(3) + 1)

def test_source(self):
"""test source attribute and trailets observe"""

node = Rasterio(source=self.source)
assert node.source == self.source

def test_change_source(self):
node = Rasterio(source=self.source)
assert node.band_count == 3

node.source = self.source.replace("RGB.byte.tif", "h5raster.hdf5")
assert node.band_count == 1


class TestH5PY(object):
source = os.path.join(os.path.dirname(__file__), "assets/h5raster.hdf5")
Expand Down
127 changes: 45 additions & 82 deletions podpac/core/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Array(DataSource):
`native_coordinates` need to supplied by the user when instantiating this node.
"""

source = ArrayTrait()
source = ArrayTrait().tag(readonly=True)

@tl.validate("source")
def _validate_source(self, d):
Expand Down Expand Up @@ -123,16 +123,17 @@ class PyDAP(DataSource):
auth_session instead if you have security concerns.
"""

# required inputs
source = tl.Unicode(default_value="")
source = tl.Unicode().tag(readonly=True)
dataset = tl.Instance("pydap.model.DatasetType").tag(readonly=True)

# node attrs
datakey = tl.Unicode().tag(attr=True)

# optional inputs and later defined traits
auth_session = tl.Instance(authentication.Session, allow_none=True)
# optional inputs
auth_class = tl.Type(authentication.Session)
username = tl.Unicode(None, allow_none=True)
password = tl.Unicode(None, allow_none=True)
dataset = tl.Instance("pydap.model.DatasetType")
auth_session = tl.Instance(authentication.Session, allow_none=True)
username = tl.Unicode(default_value=None, allow_none=True)
password = tl.Unicode(default_value=None, allow_none=True)

@tl.default("auth_session")
def _auth_session_default(self):
Expand All @@ -157,7 +158,7 @@ def _auth_session_default(self):
return session

@tl.default("dataset")
def _open_dataset(self, source=None):
def _open_dataset(self):
"""Summary
Parameters
Expand All @@ -170,46 +171,26 @@ def _open_dataset(self, source=None):
TYPE
Description
"""
# TODO: is source ever None?
# TODO: enforce string source
if source is None:
source = self.source
else:
self.source = source

# auth session
# if self.auth_session:
try:
dataset = pydap.client.open_url(source, session=self.auth_session)
dataset = self._open_url()
except Exception:
# TODO handle a 403 error
# TODO: Check Url (probably inefficient...)
try:
self.auth_session.get(self.source + ".dds")
dataset = pydap.client.open_url(source, session=self.auth_session)
dataset = self._open_url()
except Exception:
# TODO: handle 403 error
print ("Warning, dataset could not be opened. Check login credentials.")
dataset = None

return dataset

@tl.observe("source")
def _update_dataset(self, change=None):
if change is None:
return

if change["old"] == None or change["old"] == "":
return

if self.dataset is not None and "new" in change:
self.dataset = self._open_dataset(source=change["new"])

try:
if self.native_coordinates is not None:
self.native_coordinates = self.get_native_coordinates()
except NotImplementedError:
pass
def _open_url(self):
return pydap.client.open_url(self.source, session=self.auth_session)

@common_doc(COMMON_DATA_DOC)
def get_native_coordinates(self):
Expand Down Expand Up @@ -279,14 +260,16 @@ class CSV(DataSource):
Raw Pandas DataFrame used to read the data
"""

source = tl.Unicode()
source = tl.Unicode().tag(readonly=True)
dataset = tl.Instance(pd.DataFrame).tag(readonly=True)

# node attrs
dims = tl.List(default_value=["alt", "lat", "lon", "time"]).tag(attr=True)
alt_col = tl.Union([tl.Unicode(), tl.Int()]).tag(attr=True)
lat_col = tl.Union([tl.Unicode(), tl.Int()]).tag(attr=True)
lon_col = tl.Union([tl.Unicode(), tl.Int()]).tag(attr=True)
time_col = tl.Union([tl.Unicode(), tl.Int()]).tag(attr=True)
data_col = tl.Union([tl.Unicode(), tl.Int()]).tag(attr=True)
dims = tl.List(default_value=["alt", "lat", "lon", "time"]).tag(attr=True)
dataset = tl.Instance(pd.DataFrame)

def _first_init(self, **kwargs):
# First part of if tests to make sure this is the CSV parent class
Expand Down Expand Up @@ -397,9 +380,10 @@ class Rasterio(DataSource):
* Linux: export CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
"""

source = tl.Union([tl.Unicode(), tl.Instance(BytesIO)])
source = tl.Union([tl.Unicode(), tl.Instance(BytesIO)]).tag(readonly=True)
dataset = tl.Any().tag(readonly=True)

dataset = tl.Any(allow_none=True)
# node attrs
band = tl.CInt(1).tag(attr=True)

@tl.default("dataset")
Expand Down Expand Up @@ -431,25 +415,6 @@ def close_dataset(self):
"""
self.dataset.close()

@tl.observe("source")
def _update_dataset(self, change):
if hasattr(self, "_band_count"):
delattr(self, "_band_count")

if hasattr(self, "_band_descriptions"):
delattr(self, "_band_descriptions")

if hasattr(self, "_band_keys"):
delattr(self, "_band_keys")

# only update dataset if dataset trait has been defined the first time
if trait_is_defined(self, "dataset"):
self.dataset = self._open_dataset()

# update native_coordinates if they have been defined
if trait_is_defined(self, "native_coordinates"):
self.native_coordinates = self.get_native_coordinates()

@common_doc(COMMON_DATA_DOC)
def get_native_coordinates(self):
"""{get_native_coordinates}
Expand Down Expand Up @@ -672,13 +637,15 @@ class H5PY(DatasetCoordinatedMixin, DataSource):
Default is 'r'. The mode used to open the HDF5 file. Options are r, r+, w, w- or x, a (see h5py.File).
"""

source = tl.Unicode()
dataset = tl.Any(allow_none=True)
source = tl.Unicode().tag(readonly=True)
dataset = tl.Any().tag(readonly=True)
file_mode = tl.Unicode(default_value="r").tag(readonly=True)

# node attrs
datakey = tl.Unicode().tag(attr=True)
file_mode = tl.Unicode(default_value="r")

@tl.default("dataset")
def _open_dataset(self, source=None):
def _open_dataset(self):
"""Opens the data source
Parameters
Expand All @@ -691,30 +658,16 @@ def _open_dataset(self, source=None):
Any
raster.open(source)
"""
# TODO: update this to remove block (see Rasterio)
if source is None:
source = self.source
else:
self.source = source

# TODO: dataset should not open by default
# prefer with as: syntax
return h5py.File(source, self.file_mode)
return h5py.File(self.source, self.file_mode)

def close_dataset(self):
"""Closes the file for the datasource
"""
self.dataset.close()

@tl.observe("source")
def _update_dataset(self, change):
# TODO: update this to look like Rasterio
if self.dataset is not None:
self.close_dataset()
self.dataset = self._open_dataset(change["new"])
if trait_is_defined(self, "native_coordinates"):
self.native_coordinates = self.get_native_coordinates()

@common_doc(COMMON_DATA_DOC)
def get_data(self, coordinates, coordinates_index):
"""{get_data}
Expand Down Expand Up @@ -749,10 +702,13 @@ def _find_h5py_keys(obj, keys=[]):


class Zarr(DatasetCoordinatedMixin, DataSource):
source = tl.Unicode(allow_none=True)
dataset = tl.Any()
source = tl.Unicode(default_value=None, allow_none=True).tag(readonly=True)
dataset = tl.Any().tag(readonly=True)

# node attrs
datakey = tl.Unicode().tag(attr=True)

# optional inputs
access_key_id = tl.Unicode()
secret_access_key = tl.Unicode()
region_name = tl.Unicode()
Expand Down Expand Up @@ -852,11 +808,13 @@ class WCS(DataSource):
The coordinates of the WCS source
"""

source = tl.Unicode()
source = tl.Unicode().tag(readonly=True)
wcs_coordinates = tl.Instance(Coordinates).tag(readonly=True) # default below

# node attrs
layer_name = tl.Unicode().tag(attr=True)
version = tl.Unicode(WCS_DEFAULT_VERSION).tag(attr=True)
crs = tl.Unicode(WCS_DEFAULT_CRS).tag(attr=True)
wcs_coordinates = tl.Instance(Coordinates) # default below

_get_capabilities_qs = tl.Unicode("SERVICE=WCS&REQUEST=DescribeCoverage&" "VERSION={version}&COVERAGE={layer}")
_get_data_qs = tl.Unicode(
Expand Down Expand Up @@ -1188,7 +1146,9 @@ class ReprojectedSource(DataSource):
Coordinates where the source node should be evaluated.
"""

source = NodeTrait()
source = NodeTrait().tag(readonly=True)

# node attrs
source_interpolation = interpolation_trait().tag(attr=True)
reprojected_coordinates = tl.Instance(Coordinates).tag(attr=True)

Expand Down Expand Up @@ -1270,9 +1230,12 @@ class Dataset(DataSource):
For example, if the data contains ['lat', 'lon', 'channel'], the second channel can be selected using `extra_dim=dict(channel=1)`
"""

source = tl.Any(allow_none=True).tag(readonly=True)
dataset = tl.Instance(xr.Dataset).tag(readonly=True)

# node attrs
extra_dim = tl.Dict({}).tag(attr=True)
datakey = tl.Unicode().tag(attr=True)
dataset = tl.Instance(xr.Dataset)

@tl.default("dataset")
def _dataset_default(self):
Expand Down
10 changes: 10 additions & 0 deletions podpac/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,18 @@ def _validate_units(self, d):

def __init__(self, **kwargs):
""" Do not overwrite me """

tkwargs = self._first_init(**kwargs)

# make tagged "readonly" and "attr" traits read_only, and set them using set_trait
# NOTE: The set_trait is required because this sets the traits read_only at the *class* level;
# on subsequent initializations, they will already be read_only.
for name, trait in self.traits().items():
if trait.metadata.get("readonly") or trait.metadata.get("attr"):
if name in tkwargs:
self.set_trait(name, tkwargs.pop(name))
trait.read_only = True

# Call traitlest constructor
super(Node, self).__init__(**tkwargs)
self.init()
Expand Down
Loading

0 comments on commit 2a55a65

Please sign in to comment.