Clean up and add test cases for data_objects.data_containers #1831

Merged · 17 commits · Jun 29, 2018
Changes from 12 commits
3 changes: 2 additions & 1 deletion yt/data_objects/construction_data_containers.py
@@ -922,7 +922,8 @@ def __init__(self, *args, **kwargs):
         self._final_start_index = self.global_startindex
 
     def _setup_data_source(self, level_state = None):
-        if level_state is None: return
+        if level_state is None:
+            return super(YTSmoothedCoveringGrid, self)._setup_data_source()
Member:
I'd format this as:

if level_state is None:
    super(YTSmoothedCoveringGrid, self)._setup_data_source()
    return

_setup_data_source doesn't return anything, and this way of calling it makes it clearer that that's the case.

Contributor Author:
done

         # We need a buffer region to allow for zones that contribute to the
         # interpolation but are not directly inside our bounds
         level_state.data_source = self.ds.region(
40 changes: 9 additions & 31 deletions yt/data_objects/data_containers.py
@@ -80,20 +80,6 @@ def force_array(item, shape):
     else:
         return np.zeros(shape, dtype='bool')
 
-def restore_field_information_state(func):
-    """
-    A decorator that takes a function with the API of (self, grid, field)
-    and ensures that after the function is called, the field_parameters will
-    be returned to normal.
-    """
-    def save_state(self, grid, field=None, *args, **kwargs):
-        old_params = grid.field_parameters
-        grid.field_parameters = self.field_parameters
-        tr = func(self, grid, field, *args, **kwargs)
-        grid.field_parameters = old_params
-        return tr
-    return save_state
-
 def sanitize_weight_field(ds, field, weight):
     field_object = ds._get_field_info(field)
     if weight is None:
@@ -215,6 +201,10 @@ def _set_center(self, center):
             self.center = self.ds.find_max(("gas", "density"))[1]
         elif center.startswith("max_"):
             self.center = self.ds.find_max(center[4:])[1]
+        elif center.lower() == "min":
+            self.center = self.ds.find_min(("gas", "density"))[1]
+        elif center.startswith("min_"):
+            self.center = self.ds.find_min(center[4:])[1]
Member:
good catch!
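
For context, the added branches mirror the existing "max" handling, so any data object can now be centered on a field minimum. A hypothetical usage sketch (the named field is chosen for illustration):

    # Center on the minimum of ("gas", "density"), or of a named field.
    sp = ds.sphere("min", (0.25, 'unitary'))
    sp = ds.sphere("min_temperature", (0.25, 'unitary'))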

         else:
             self.center = self.ds.arr(center, 'code_length', dtype='float64')
         self.set_field_parameter('center', self.center)
@@ -429,11 +419,12 @@ def _parameter_iterate(self, seq):
     def write_out(self, filename, fields=None, format="%0.16e"):
         if fields is None: fields=sorted(self.field_data.keys())
         if self._key_fields is None: raise ValueError
-        field_order = self._key_fields[:]
+        field_order = sorted(self._determine_fields(self._key_fields))
         for field in field_order: self[field]
         field_order += [field for field in fields if field not in field_order]
         fid = open(filename,"w")
-        fid.write("\t".join(["#"] + field_order + ["\n"]))
+        field_header = [str(f) for f in field_order]
+        fid.write("\t".join(["#"] + field_header + ["\n"]))
         field_data = np.array([self.field_data[field] for field in field_order])
         for line in range(field_data.shape[1]):
             field_data[:,line].tofile(fid, sep="\t", format=format)
@@ -1673,8 +1664,8 @@ def to_frb(self, width, resolution, center=None, height=None,
"Currently we only support images centered at R=0. " +
"We plan to generalize this in the near future")
from yt.visualization.fixed_resolution import CylindricalFixedResolutionBuffer
if iterable(width):
radius = max(width)
if isinstance(width, tuple):
radius = width[0]
Member:
I don't think this change is correct - the original code is correct. We're checking for iterable(width) because width might be a list or an ndarray or some other kind of iterable. It's easier to just use the iterable function to test for this rather than using isinstance and insisting on only accepting a certain number of types.

Also it's a behavior change to pick width[0] instead of max(width). Can you explain why you did that? Choosing max(width) is arbitrary, but since we already made the decision it's best to not change our minds without a good reason.

Contributor Author:
I changed it because I was getting an error for this test case:

def test_to_frb():
    ds = fake_amr_ds(fields=["density", "cell_mass"], geometry="cylindrical",
                     particles=16**3)
    proj = ds.proj("density", weight_field="cell_mass", axis=1,
                   data_source=ds.all_data())
    proj.to_frb((1.0, 'unitary'), 64)

Error:

line 1668, in to_frb
    radius = max(width)
TypeError: '>' not supported between instances of 'str' and 'float'

Member (@ngoldbaum, Jun 12, 2018):

Width is allowed to be a tuple of tuples, like this:

((1.0, 'unitary'), (0.7, 'unitary'))

this allows constructing images where the width doesn't equal the height.

For the cylindrical FRB the width really represents a radius, and the current code assumes that if you pass a tuple it contains two floats, and that the radius you mean is the maximum of the two.

That's not great, so maybe it's ok to change this. Rather than what you've done, though, I'd suggest checking whether a passed tuple is a valid (length, unit) tuple, and raising an error if not. That is, for a cylindrical dataset it only makes sense to pass in one width to interpret as the radius, so if someone passes in more than one, we should raise an error.

Basically the code here:

if iterable(width):
    w, u = width
    if isinstance(w, tuple) and isinstance(u, tuple):
        height = u
        w, u = w
    width = self.ds.quan(w, input_units = u)

But in the if statement raise an error instead of interpreting the tuple as a width and height.

Contributor Author:
done
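
A minimal sketch of the validation the reviewer suggests (variable names follow the surrounding to_frb code; the exact exception type is an assumption):

    if iterable(width):
        w, u = width
        if isinstance(w, tuple) and isinstance(u, tuple):
            # A ((w, unit), (h, unit)) pair makes no sense here: for a
            # cylindrical dataset the width is interpreted as a radius,
            # so only a single (length, unit) tuple is accepted.
            raise NotImplementedError(
                "Cylindrical datasets accept a single (length, unit) "
                "width, got %s" % (width,))
        width = self.ds.quan(w, input_units=u)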

         else:
             radius = width
         if iterable(resolution): resolution = max(resolution)
@@ -1990,20 +1981,7 @@ def extract_connected_sets(self, field, num_levels, min_val, max_val,
                                    {'contour_slices_%s' % contour_key: cids})
         return cons, contours
 
-    def paint_grids(self, field, value, default_value=None):
-        """
-        This function paints every cell in our dataset with a given *value*.
-        If default_value is given, the other values for the given in every grid
-        are discarded and replaced with *default_value*. Otherwise, the field is
-        mandated to 'know how to exist' in the grid.
-
-        Note that this only paints the cells *in the dataset*, so cells in grids
-        with child cells are left untouched.
-        """
-        for grid in self._grids:
-            if default_value is not None:
-                grid[field] = np.ones(grid.ActiveDimensions)*default_value
-            grid[field][self._get_point_indices(grid)] = value
-
     def volume(self):
         """
7 changes: 7 additions & 0 deletions yt/data_objects/tests/test_clone.py
@@ -22,3 +22,10 @@ def test_clone_sphere():
sp1["density"]

assert_array_equal(sp0["density"], sp1["density"])

def test_clone_cut_region():
ds = fake_random_ds(64, nprocs=4, fields=("density", "temperature"))
dd = ds.all_data()
reg1 = dd.cut_region(["obj['temperature'] > 0.5", "obj['density'] < 0.75"])
reg2 = reg1.clone()
assert_array_equal(reg1['density'], reg2['density'])
21 changes: 9 additions & 12 deletions yt/data_objects/tests/test_connected_sets.py
@@ -1,14 +1,11 @@
-from yt.utilities.answer_testing.level_sets_tests import \
-    ExtractConnectedSetsTest
-from yt.utilities.answer_testing.framework import \
-    requires_ds, \
-    data_dir_load
+from yt.testing import fake_random_ds
+from yt.utilities.answer_testing.level_sets_tests import ExtractConnectedSetsTest
 
-g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"
-@requires_ds(g30, big_data=True)
 def test_connected_sets():
-    ds = data_dir_load(g30)
-    data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.],
-                          (8, 'kpc'), (1, 'kpc'))
-    yield ExtractConnectedSetsTest(g30, data_source, ("gas", "density"),
-                                   5, 1e-24, 8e-24)
+    ds = fake_random_ds(16, nprocs=8, particles=16**3)
+    data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], (8, 'kpc'), (1, 'kpc'))
+    field = ("gas", "density")
+    min_val, max_val = data_source[field].min()/2, data_source[field].max()/2
+    data_source.extract_connected_sets(field, 5, min_val, max_val,
+                                       log_space=True, cumulative=True)
+    yield ExtractConnectedSetsTest(ds, data_source, field, 5, min_val, max_val)
121 changes: 121 additions & 0 deletions yt/data_objects/tests/test_data_containers.py
@@ -0,0 +1,121 @@
import shelve

import numpy as np
from nose.tools import assert_raises
from numpy.testing import assert_array_equal

from yt.data_objects.data_containers import YTDataContainer
from yt.testing import assert_equal, fake_random_ds, fake_amr_ds, fake_particle_ds
from yt.utilities.exceptions import YTFieldNotFound


def test_yt_data_container():
    # Test if ds could be None
    with assert_raises(RuntimeError) as err:
        YTDataContainer(None, None)
    desired = 'Error: ds must be set either through class type or parameter' \
              ' to the constructor'
Member:
style nit: If you put the string inside parentheses then you don't need to use the line continuation

Contributor Author:
done
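
The parenthesized form the reviewer means, for reference:

    desired = ('Error: ds must be set either through class type or parameter'
               ' to the constructor')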

    assert_equal(str(err.exception), desired)

    # Test if field_data key exists
    ds = fake_random_ds(5)
    proj = ds.proj("density", 0, data_source=ds.all_data())
    assert_equal('px' in proj.keys(), True)
    assert_equal('pz' in proj.keys(), False)

    # Delete the key and check if it exists
    proj.__delitem__('px')
Member:
Rather than calling __delitem__ directly I'd format this (and similarly below) as del proj['px']

Contributor Author:
done

    assert_equal('px' in proj.keys(), False)
    proj.__delitem__('density')
    assert_equal('density' in proj.keys(), False)

    # Delete a non-existent field
    with assert_raises(YTFieldNotFound) as ex:
        proj.__delitem__('p_mass')
    desired = "Could not find field '('stream', 'p_mass')' in UniformGridData."
    assert_equal(str(ex.exception), desired)

    # Test the convert method
    assert_equal(proj.convert('HydroMethod'), -1)
Member:
This is actually a relic of the old pre yt-3.0 API that should have been removed a long time ago. Can you remove this test and the function calling it?

Contributor Author:
done

Contributor Author:

Should I delete convert and __delitem__ methods only?

Member:
Don't delete __delitem__, that's still needed. I was saying that in the test don't invoke __delitem__ directly, instead use the del statement to invoke it indirectly.

You should delete convert since it's a relic function that doesn't do what it says in its own docstring.


def test_write_out():
    filename = "sphere.txt"
    ds = fake_particle_ds()
    sp = ds.sphere(ds.domain_center, 0.25)
    sp.write_out(filename)
Member:
Now that you've added a test for this function (awesome!) would you mind also adding a docstring in yt/data_objects/data_containers.py?

Contributor Author:
done
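
A hypothetical sketch of such a docstring (the wording actually merged may differ):

    def write_out(self, filename, fields=None, format="%0.16e"):
        """Write the data object's field data to a tab-delimited text file.

        Parameters
        ----------
        filename : str
            Path of the file to write.
        fields : list of field names, optional
            Fields to write in addition to the object's key fields.
        format : str, optional
            printf-style format string applied to each value.
        """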


    with open(filename, "r") as file:
        file_row_1 = file.readline()
        file_row_2 = file.readline()
    file_row_2 = np.array(file_row_2.split('\t'), dtype=np.float64)
    sorted_keys = sorted(sp.field_data.keys())
    _keys = [str(k) for k in sorted_keys]

Member:
I don't see the point of signifying these as private with _.

Contributor Author:
done

_keys = "\t".join(["#"] + _keys + ["\n"])
_data = [sp.field_data[k][0] for k in sorted_keys]

assert_equal(_keys, file_row_1)
assert_array_equal(_data, file_row_2)

def test_save_object():
    ds = fake_particle_ds()
    sp = ds.sphere(ds.domain_center, 0.25)
    sp.save_object("my_sphere_1", filename="test_save_obj")
    obj = shelve.open("test_save_obj", protocol=-1)
Member:
This could be wrapped in a with statement to prevent the file handle from being left open if the assertion below fails.

Contributor Author:
I had done it earlier using a with statement, but it was failing on Python 2.7 (AttributeError: DbfilenameShelf instance has no attribute '__exit__').

Member:
OK, fair enough!
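
For reference, a portable alternative (not adopted in the PR) is contextlib.closing, which supplies the missing context-manager protocol even on Python 2.7:

    from contextlib import closing

    # closing() guarantees obj.close() on exit even though the Python 2.7
    # DbfilenameShelf has no __exit__ of its own.
    with closing(shelve.open("test_save_obj", protocol=-1)) as obj:
        loaded_sphere = obj["my_sphere_1"][1]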

    loaded_sphere = obj["my_sphere_1"][1]
    assert_array_equal(loaded_sphere.center, sp.center)
    assert_equal(loaded_sphere.radius, sp.radius)
Member:
Maybe also test field access?

Contributor Author:
done

    obj.close()

    sp.save_object("my_sphere_2")

def test_to_dataframe():
    try:
        import pandas as pd

Member:
Does import pandas as _ make flake8 happy?

Member:
Rather than wrapping the test logic in a try/except, please use yt.testing.requires_module, which is a decorator you can apply to a test function to tell nose that the test function requires some optional module (in this case pandas).

You probably also need to install pandas on travis and appveyor to get this test to actually execute.

Contributor Author:
done
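
A sketch of the decorator-based form the reviewer describes (requires_module is named in the review; the rewrite actually merged may differ):

    from yt.testing import requires_module

    @requires_module("pandas")
    def test_to_dataframe():
        # nose skips this test when pandas is not importable.
        fields = ["density", "velocity_z"]
        ds = fake_random_ds(6)
        dd = ds.all_data()
        df1 = dd.to_dataframe(fields)
        assert_array_equal(dd[fields[0]], df1[fields[0]])
        assert_array_equal(dd[fields[1]], df1[fields[1]])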

        pd  # to ignore Flake8 error
        fields = ["density", "velocity_z"]
        ds = fake_random_ds(6)
        dd = ds.all_data()
        df1 = dd.to_dataframe(fields)
        assert_array_equal(dd[fields[0]], df1[fields[0]])
        assert_array_equal(dd[fields[1]], df1[fields[1]])
    except ImportError:
        pass

def test_std():
    ds = fake_random_ds(3)
    ds.all_data().std('density', weight="velocity_z")

def test_to_frb():
    ds = fake_amr_ds(fields=["density", "cell_mass"], geometry="cylindrical",
                     particles=16**3)
    proj = ds.proj("density", weight_field="cell_mass", axis=1,
                   data_source=ds.all_data())
    proj.to_frb((1.0, 'unitary'), 64)
Member:
Can you test that the FixedResolutionBuffer object you get back is correct (e.g. that it produces images with a resolution of (64, 64) and that the image covers the full domain)? For the latter you can just inspect the state of the FixedResolutionBuffer object (e.g. by checking frb.bounds) rather than trying to make sure the image is correct.

Contributor Author:
In the case of CylindricalFixedResolutionBuffer the object does not have a bounds attribute; instead there is buff_size.
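
A sketch of the kind of state check being discussed, adapted to buff_size as the author notes (whether buff_size holds an int or an (nx, ny) pair here is an assumption):

    frb = proj.to_frb((1.0, 'unitary'), 64)
    # Verify the buffer was built at the requested resolution.
    assert_equal(frb.buff_size, 64)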


def test_extract_isocontours():
    # Test isocontour properties for AMRGridData
    ds = fake_amr_ds(fields=["density", "cell_mass"], particles=16**3)
    dd = ds.all_data()
    rho = dd.quantities["WeightedAverageQuantity"]("density",
                                                   weight="cell_mass")
Member:
I'd format these lines like:

q = dd.quantities["WeightedAverageQuantity"]
rho = q("density", weight="cell_mass")

Often when you find yourself hitting a line length limit you can get around it by defining an intermediate variable. This has the side benefit of often making the code more understandable.

Contributor Author:
done

    dd.extract_isocontours("density", rho, "triangles.obj", True)
    dd.calculate_isocontour_flux("density", rho, "x", "y", "z",

Member:
This formatting choice seems odd (the complete line is shorter than the lines above). Are you using an auto-formatter like yapf or eyeballing it?

Member:
I'd format these two lines like:

dd.calculate_isocontour_flux(
    "density", rho, "x", "y", "z", "dx")

Contributor Author:
fixed

"dx")

    # Test error in case of ParticleData
    ds = fake_particle_ds()
    dd = ds.all_data()
    rho = dd.quantities["WeightedAverageQuantity"]("particle_velocity_x",
                                                   weight="particle_mass")
    with assert_raises(NotImplementedError):
        dd.extract_isocontours("density", rho, sample_values='x')

def test_extract_connected_sets():
    ds = fake_random_ds(16, nprocs=8, particles=16 ** 3)
    data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], (8, 'kpc'), (1, 'kpc'))
    field = ("gas", "density")
    min_val, max_val = data_source[field].min() / 2, data_source[field].max() / 2

    data_source.extract_connected_sets(field, 3, min_val, max_val,
                                       log_space=True, cumulative=True)
Member:
Can you test that the data returned by this function are correct? Also the formatting right here is a little odd, this line should be aligned with field, not 3.

Contributor Author:
Removing this test as it is already present in test_connected_sets.py.

20 changes: 16 additions & 4 deletions yt/data_objects/tests/test_spheres.py
@@ -1,10 +1,9 @@
 import numpy as np
+from numpy.testing import assert_array_equal
 
 from yt.data_objects.profiles import create_profile
-from yt.testing import \
-    fake_random_ds, \
-    assert_equal, \
-    periodicity_cases
+from yt.testing import fake_random_ds, assert_equal, periodicity_cases
+
 
 def setup():
     from yt.config import ytcfg
@@ -65,3 +64,16 @@ def test_domain_sphere():
     for f in _fields_to_compare:
         sp[f].sort()
         assert_equal(sp[f], ref_sp[f])
+
+def test_sphere_center():
+    ds = fake_random_ds(16, nprocs=8,
+                        fields=("density", "temperature", "velocity_x"))
+
+    # Test if we obtain same center in different ways
+    sp1 = ds.sphere("max", (0.25, 'unitary'))
+    sp2 = ds.sphere("max_density", (0.25, 'unitary'))
+    assert_array_equal(sp1.center, sp2.center)
+
+    sp1 = ds.sphere("min", (0.25, 'unitary'))
+    sp2 = ds.sphere("min_density", (0.25, 'unitary'))
+    assert_array_equal(sp1.center, sp2.center)
Empty file added yt/utilities/input_validator.py