Skip to content

Commit

Permalink
Merge pull request #1831 from git-abhishek/cc-do-containers
Browse files Browse the repository at this point in the history
Clean up and added test cases for `data_objects.data_containers`
  • Loading branch information
Nathan Goldbaum authored Jun 29, 2018
2 parents a209a67 + 0054e09 commit a5db4fc
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 86 deletions.
9 changes: 5 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ env:
FLAKE8=flake8
CODECOV=codecov
COVERAGE=coverage
PANDAS=pandas

before_install:
- |
Expand Down Expand Up @@ -50,15 +51,15 @@ install:
pip install --upgrade wheel
pip install --upgrade setuptools
# Install dependencies
pip install mock $NUMPY $SCIPY $H5PY $CYTHON $MATPLOTLIB $SYMPY $FASTCACHE $IPYTHON $FLAKE8 $CODECOV $COVERAGE nose nose-timer
pip install mock $NUMPY $SCIPY $H5PY $CYTHON $MATPLOTLIB $SYMPY $FASTCACHE $IPYTHON $FLAKE8 $CODECOV $COVERAGE $PANDAS nose nose-timer
# install yt
pip install -e .
jobs:
include:
- stage: lint
python: 3.4
env: NUMPY=numpy==1.10.4 CYTHON=cython==0.24 MATPLOTLIB= SYMPY= H5PY= SCIPY= FASTCACHE= IPYTHON=
python: 3.6
env: NUMPY=numpy==1.10.4 CYTHON=cython==0.24 MATPLOTLIB= SYMPY= H5PY= SCIPY= FASTCACHE= IPYTHON= PANDAS=
script: flake8 yt/

- stage: tests
Expand All @@ -73,7 +74,7 @@ jobs:

- stage: tests
python: 3.4
env: FLAKE8=
env: FLAKE8= PANDAS=pandas==0.20
script: coverage run $(which nosetests) -c nose.cfg yt

- stage: tests
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ install:
- "python --version"

# Install specified version of numpy and dependencies
- "conda install --yes -c conda-forge numpy scipy nose setuptools ipython Cython sympy fastcache h5py matplotlib mock"
- "conda install --yes -c conda-forge numpy scipy nose setuptools ipython Cython sympy fastcache h5py matplotlib mock pandas"
- "pip install -e ."

# Not a .NET project
Expand Down
4 changes: 3 additions & 1 deletion yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,9 @@ def __init__(self, *args, **kwargs):
self._final_start_index = self.global_startindex

def _setup_data_source(self, level_state = None):
if level_state is None: return
if level_state is None:
super(YTSmoothedCoveringGrid, self)._setup_data_source()
return
# We need a buffer region to allow for zones that contribute to the
# interpolation but are not directly inside our bounds
level_state.data_source = self.ds.region(
Expand Down
138 changes: 74 additions & 64 deletions yt/data_objects/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
mylog, \
ensure_list, \
fix_axis, \
iterable
iterable, validate_width_tuple
from yt.units.unit_object import UnitParseError
from yt.units.yt_array import \
YTArray, \
Expand All @@ -50,7 +50,7 @@
YTDataSelectorNotImplemented, \
YTDimensionalityError, \
YTBooleanObjectError, \
YTBooleanObjectsWrongDataset
YTBooleanObjectsWrongDataset, YTException
from yt.utilities.lib.marching_cubes import \
march_cubes_grid, march_cubes_grid_flux
from yt.utilities.parallel_tools.parallel_analysis_interface import \
Expand All @@ -71,29 +71,6 @@

data_object_registry = {}

def force_array(item, shape):
try:
return item.copy()
except AttributeError:
if item:
return np.ones(shape, dtype='bool')
else:
return np.zeros(shape, dtype='bool')

def restore_field_information_state(func):
"""
A decorator that takes a function with the API of (self, grid, field)
and ensures that after the function is called, the field_parameters will
be returned to normal.
"""
def save_state(self, grid, field=None, *args, **kwargs):
old_params = grid.field_parameters
grid.field_parameters = self.field_parameters
tr = func(self, grid, field, *args, **kwargs)
grid.field_parameters = old_params
return tr
return save_state

def sanitize_weight_field(ds, field, weight):
field_object = ds._get_field_info(field)
if weight is None:
Expand Down Expand Up @@ -215,6 +192,10 @@ def _set_center(self, center):
self.center = self.ds.find_max(("gas", "density"))[1]
elif center.startswith("max_"):
self.center = self.ds.find_max(center[4:])[1]
elif center.lower() == "min":
self.center = self.ds.find_min(("gas", "density"))[1]
elif center.startswith("min_"):
self.center = self.ds.find_min(center[4:])[1]
else:
self.center = self.ds.arr(center, 'code_length', dtype='float64')
self.set_field_parameter('center', self.center)
Expand Down Expand Up @@ -242,13 +223,6 @@ def has_field_parameter(self, name):
"""
return name in self.field_parameters

def convert(self, datatype):
"""
This will attempt to convert a given unit to cgs from code units.
It either returns the multiplicative factor or throws a KeyError.
"""
return self.ds[datatype]

def clear_data(self):
"""
Clears out all data from the YTDataContainer instance, freeing memory.
Expand Down Expand Up @@ -427,20 +401,72 @@ def _parameter_iterate(self, seq):

_key_fields = None
def write_out(self, filename, fields=None, format="%0.16e"):
if fields is None: fields=sorted(self.field_data.keys())
if self._key_fields is None: raise ValueError
field_order = self._key_fields[:]
"""Write out the YTDataContainer object in a text file.
This function will take a data object and produce a tab delimited text
file containing the fields presently existing and the fields given in
the ``fields`` list.
Parameters
----------
filename : String
The name of the file to write to.
fields : List of string, Default = None
If this is supplied, these fields will be added to the list of
fields to be saved to disk. If not supplied, whatever fields
presently exist will be used.
format : String, Default = "%0.16e"
Format of numbers to be written in the file.
Raises
------
ValueError
Raised when there is no existing field.
YTException
Raised when field_type of supplied fields is inconsistent with the
field_type of existing fields.
Examples
--------
>>> ds = fake_particle_ds()
>>> sp = ds.sphere(ds.domain_center, 0.25)
>>> sp.write_out("sphere_1.txt")
>>> sp.write_out("sphere_2.txt", fields=["cell_volume"])
"""
if fields is None:
fields = sorted(self.field_data.keys())

if self._key_fields is None:
raise ValueError

field_order = self._key_fields
diff_fields = [field for field in fields if field not in field_order]
field_order += diff_fields
field_order = sorted(self._determine_fields(field_order))
field_types = {u for u, v in field_order}

if len(field_types) != 1:
diff_fields = self._determine_fields(diff_fields)
req_ftype = self._determine_fields(self._key_fields[0])[0][0]
f_type = {f for f in diff_fields if f[0] != req_ftype }
msg = ("Field type %s of the supplied field %s is inconsistent"
" with field type '%s'." %
([f[0] for f in f_type], [f[1] for f in f_type], req_ftype))
raise YTException(msg)

for field in field_order: self[field]
field_order += [field for field in fields if field not in field_order]
fid = open(filename,"w")
fid.write("\t".join(["#"] + field_order + ["\n"]))
field_data = np.array([self.field_data[field] for field in field_order])
for line in range(field_data.shape[1]):
field_data[:,line].tofile(fid, sep="\t", format=format)
fid.write("\n")
fid.close()

def save_object(self, name, filename = None):
with open(filename, "w") as fid:
field_header = [str(f) for f in field_order]
fid.write("\t".join(["#"] + field_header + ["\n"]))
field_data = np.array([self.field_data[field] for field in field_order])
for line in range(field_data.shape[1]):
field_data[:, line].tofile(fid, sep="\t", format=format)
fid.write("\n")

def save_object(self, name, filename=None):
"""
Save an object. If *filename* is supplied, it will be stored in
a :mod:`shelve` file of that name. Otherwise, it will be stored via
Expand All @@ -455,7 +481,7 @@ def save_object(self, name, filename = None):
else:
self.index.save_object(self, name)

def to_dataframe(self, fields = None):
def to_dataframe(self, fields=None):
r"""Export a data object to a pandas DataFrame.
This function will take a data object and construct from it and
Expand Down Expand Up @@ -1673,12 +1699,9 @@ def to_frb(self, width, resolution, center=None, height=None,
"Currently we only support images centered at R=0. " +
"We plan to generalize this in the near future")
from yt.visualization.fixed_resolution import CylindricalFixedResolutionBuffer
if iterable(width):
radius = max(width)
else:
radius = width
validate_width_tuple(width)
if iterable(resolution): resolution = max(resolution)
frb = CylindricalFixedResolutionBuffer(self, radius, resolution)
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
return frb

if center is None:
Expand Down Expand Up @@ -1990,20 +2013,7 @@ def extract_connected_sets(self, field, num_levels, min_val, max_val,
{'contour_slices_%s' % contour_key: cids})
return cons, contours

def paint_grids(self, field, value, default_value=None):
"""
This function paints every cell in our dataset with a given *value*.
If default_value is given, the other values for the given in every grid
are discarded and replaced with *default_value*. Otherwise, the field is
mandated to 'know how to exist' in the grid.

Note that this only paints the cells *in the dataset*, so cells in grids
with child cells are left untouched.
"""
for grid in self._grids:
if default_value is not None:
grid[field] = np.ones(grid.ActiveDimensions)*default_value
grid[field][self._get_point_indices(grid)] = value

def volume(self):
"""
Expand Down
7 changes: 7 additions & 0 deletions yt/data_objects/tests/test_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ def test_clone_sphere():
sp1["density"]

assert_array_equal(sp0["density"], sp1["density"])

def test_clone_cut_region():
ds = fake_random_ds(64, nprocs=4, fields=("density", "temperature"))
dd = ds.all_data()
reg1 = dd.cut_region(["obj['temperature'] > 0.5", "obj['density'] < 0.75"])
reg2 = reg1.clone()
assert_array_equal(reg1['density'], reg2['density'])
21 changes: 9 additions & 12 deletions yt/data_objects/tests/test_connected_sets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from yt.utilities.answer_testing.level_sets_tests import \
ExtractConnectedSetsTest
from yt.utilities.answer_testing.framework import \
requires_ds, \
data_dir_load
from yt.testing import fake_random_ds
from yt.utilities.answer_testing.level_sets_tests import ExtractConnectedSetsTest

g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"
@requires_ds(g30, big_data=True)
def test_connected_sets():
ds = data_dir_load(g30)
data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.],
(8, 'kpc'), (1, 'kpc'))
yield ExtractConnectedSetsTest(g30, data_source, ("gas", "density"),
5, 1e-24, 8e-24)
ds = fake_random_ds(16, nprocs=8, particles=16**3)
data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], (8, 'kpc'), (1, 'kpc'))
field = ("gas", "density")
min_val, max_val = data_source[field].min()/2, data_source[field].max()/2
data_source.extract_connected_sets(field, 5, min_val, max_val,
log_space=True, cumulative=True)
yield ExtractConnectedSetsTest(ds, data_source, field, 5, min_val, max_val)
Loading

0 comments on commit a5db4fc

Please sign in to comment.