-
Notifications
You must be signed in to change notification settings - Fork 279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up and added test cases for data_objects.data_containers
#1831
Changes from 12 commits
cc5f39a
3fca6bc
ef19719
bf72557
680fb92
78b567a
218e33c
e2b4d70
1550eab
11f62ca
77c27d5
9abeca0
70682e5
cb3bf71
4f8fb5d
d12ca6e
0054e09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -80,20 +80,6 @@ def force_array(item, shape): | |||||||||||||
else: | ||||||||||||||
return np.zeros(shape, dtype='bool') | ||||||||||||||
|
||||||||||||||
def restore_field_information_state(func): | ||||||||||||||
""" | ||||||||||||||
A decorator that takes a function with the API of (self, grid, field) | ||||||||||||||
and ensures that after the function is called, the field_parameters will | ||||||||||||||
be returned to normal. | ||||||||||||||
""" | ||||||||||||||
def save_state(self, grid, field=None, *args, **kwargs): | ||||||||||||||
old_params = grid.field_parameters | ||||||||||||||
grid.field_parameters = self.field_parameters | ||||||||||||||
tr = func(self, grid, field, *args, **kwargs) | ||||||||||||||
grid.field_parameters = old_params | ||||||||||||||
return tr | ||||||||||||||
return save_state | ||||||||||||||
|
||||||||||||||
def sanitize_weight_field(ds, field, weight): | ||||||||||||||
field_object = ds._get_field_info(field) | ||||||||||||||
if weight is None: | ||||||||||||||
|
@@ -215,6 +201,10 @@ def _set_center(self, center): | |||||||||||||
self.center = self.ds.find_max(("gas", "density"))[1] | ||||||||||||||
elif center.startswith("max_"): | ||||||||||||||
self.center = self.ds.find_max(center[4:])[1] | ||||||||||||||
elif center.lower() == "min": | ||||||||||||||
self.center = self.ds.find_min(("gas", "density"))[1] | ||||||||||||||
elif center.startswith("min_"): | ||||||||||||||
self.center = self.ds.find_min(center[4:])[1] | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch! |
||||||||||||||
else: | ||||||||||||||
self.center = self.ds.arr(center, 'code_length', dtype='float64') | ||||||||||||||
self.set_field_parameter('center', self.center) | ||||||||||||||
|
@@ -429,11 +419,12 @@ def _parameter_iterate(self, seq): | |||||||||||||
def write_out(self, filename, fields=None, format="%0.16e"): | ||||||||||||||
if fields is None: fields=sorted(self.field_data.keys()) | ||||||||||||||
if self._key_fields is None: raise ValueError | ||||||||||||||
field_order = self._key_fields[:] | ||||||||||||||
field_order = sorted(self._determine_fields(self._key_fields)) | ||||||||||||||
for field in field_order: self[field] | ||||||||||||||
field_order += [field for field in fields if field not in field_order] | ||||||||||||||
fid = open(filename,"w") | ||||||||||||||
fid.write("\t".join(["#"] + field_order + ["\n"])) | ||||||||||||||
field_header = [str(f) for f in field_order] | ||||||||||||||
fid.write("\t".join(["#"] + field_header + ["\n"])) | ||||||||||||||
field_data = np.array([self.field_data[field] for field in field_order]) | ||||||||||||||
for line in range(field_data.shape[1]): | ||||||||||||||
field_data[:,line].tofile(fid, sep="\t", format=format) | ||||||||||||||
|
@@ -1673,8 +1664,8 @@ def to_frb(self, width, resolution, center=None, height=None, | |||||||||||||
"Currently we only support images centered at R=0. " + | ||||||||||||||
"We plan to generalize this in the near future") | ||||||||||||||
from yt.visualization.fixed_resolution import CylindricalFixedResolutionBuffer | ||||||||||||||
if iterable(width): | ||||||||||||||
radius = max(width) | ||||||||||||||
if isinstance(width, tuple): | ||||||||||||||
radius = width[0] | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this change is correct - the original code is correct. We're checking for Also it's a behavior change to pick There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed it as I was getting error for the test case:
Error:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Width is allowed to be a tuple of tuples, like this:
this allows constructing images where the width doesn't equal the height. For the cylindrical FRB the width really represents a radius, and what happens right now is it assumes if you're passing a tuple you're passing two floats and it assumes the radius you really mean is the maximum of the two floats you passed in. That's not great, so maybe it's ok to change this. Rather than what you've done though I'd suggest if we're passed a tuple, to check if it's a valid Basically the code here: yt/yt/data_objects/data_containers.py Lines 1691 to 1696 in cc299b0
But in the if statement raise an error instead of interpreting the tuple as a width and height. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||||||||||
else: | ||||||||||||||
radius = width | ||||||||||||||
if iterable(resolution): resolution = max(resolution) | ||||||||||||||
|
@@ -1990,20 +1981,7 @@ def extract_connected_sets(self, field, num_levels, min_val, max_val, | |||||||||||||
{'contour_slices_%s' % contour_key: cids}) | ||||||||||||||
return cons, contours | ||||||||||||||
|
||||||||||||||
def paint_grids(self, field, value, default_value=None): | ||||||||||||||
""" | ||||||||||||||
This function paints every cell in our dataset with a given *value*. | ||||||||||||||
If default_value is given, the other values for the given in every grid | ||||||||||||||
are discarded and replaced with *default_value*. Otherwise, the field is | ||||||||||||||
mandated to 'know how to exist' in the grid. | ||||||||||||||
|
||||||||||||||
Note that this only paints the cells *in the dataset*, so cells in grids | ||||||||||||||
with child cells are left untouched. | ||||||||||||||
""" | ||||||||||||||
for grid in self._grids: | ||||||||||||||
if default_value is not None: | ||||||||||||||
grid[field] = np.ones(grid.ActiveDimensions)*default_value | ||||||||||||||
grid[field][self._get_point_indices(grid)] = value | ||||||||||||||
|
||||||||||||||
def volume(self): | ||||||||||||||
""" | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,11 @@ | ||
from yt.utilities.answer_testing.level_sets_tests import \ | ||
ExtractConnectedSetsTest | ||
from yt.utilities.answer_testing.framework import \ | ||
requires_ds, \ | ||
data_dir_load | ||
from yt.testing import fake_random_ds | ||
from yt.utilities.answer_testing.level_sets_tests import ExtractConnectedSetsTest | ||
|
||
g30 = "IsolatedGalaxy/galaxy0030/galaxy0030" | ||
@requires_ds(g30, big_data=True) | ||
def test_connected_sets(): | ||
ds = data_dir_load(g30) | ||
data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], | ||
(8, 'kpc'), (1, 'kpc')) | ||
yield ExtractConnectedSetsTest(g30, data_source, ("gas", "density"), | ||
5, 1e-24, 8e-24) | ||
ds = fake_random_ds(16, nprocs=8, particles=16**3) | ||
data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], (8, 'kpc'), (1, 'kpc')) | ||
field = ("gas", "density") | ||
min_val, max_val = data_source[field].min()/2, data_source[field].max()/2 | ||
data_source.extract_connected_sets(field, 5, min_val, max_val, | ||
log_space=True, cumulative=True) | ||
yield ExtractConnectedSetsTest(ds, data_source, field, 5, min_val, max_val) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import shelve | ||
|
||
import numpy as np | ||
from nose.tools import assert_raises | ||
from numpy.testing import assert_array_equal | ||
|
||
from yt.data_objects.data_containers import YTDataContainer | ||
from yt.testing import assert_equal, fake_random_ds, fake_amr_ds, fake_particle_ds | ||
from yt.utilities.exceptions import YTFieldNotFound | ||
|
||
|
||
def test_yt_data_container(): | ||
# Test if ds could be None | ||
with assert_raises(RuntimeError) as err: | ||
YTDataContainer(None, None) | ||
desired = 'Error: ds must be set either through class type or parameter' \ | ||
' to the constructor' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style nit: If you put the string inside parentheses then you don't need to use the line continuation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert_equal(str(err.exception), desired) | ||
|
||
# Test if field_data key exists | ||
ds = fake_random_ds(5) | ||
proj = ds.proj("density", 0, data_source=ds.all_data()) | ||
assert_equal('px' in proj.keys(), True) | ||
assert_equal('pz' in proj.keys(), False) | ||
|
||
# Delete the key and check if exits | ||
proj.__delitem__('px') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert_equal('px' in proj.keys(), False) | ||
proj.__delitem__('density') | ||
assert_equal('density' in proj.keys(), False) | ||
|
||
# Delete a non-existent field | ||
with assert_raises(YTFieldNotFound) as ex: | ||
proj.__delitem__('p_mass') | ||
desired = "Could not find field '('stream', 'p_mass')' in UniformGridData." | ||
assert_equal(str(ex.exception), desired) | ||
|
||
# Test the convert method | ||
assert_equal(proj.convert('HydroMethod'), -1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually a relic of the old pre yt-3.0 API that should have been removed a long time ago. Can you remove this test and the function calling it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I delete There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't delete You should delete |
||
|
||
def test_write_out(): | ||
filename = "sphere.txt" | ||
ds = fake_particle_ds() | ||
sp = ds.sphere(ds.domain_center, 0.25) | ||
sp.write_out(filename) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that you've added a test for this function (awesome!) would you mind also adding a docstring in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
with open(filename, "r") as file: | ||
file_row_1 = file.readline() | ||
file_row_2 = file.readline() | ||
file_row_2 = np.array(file_row_2.split('\t'), dtype=np.float64) | ||
sorted_keys = sorted(sp.field_data.keys()) | ||
_keys = [str(k) for k in sorted_keys] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see the point of signifying these as private with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
_keys = "\t".join(["#"] + _keys + ["\n"]) | ||
_data = [sp.field_data[k][0] for k in sorted_keys] | ||
|
||
assert_equal(_keys, file_row_1) | ||
assert_array_equal(_data, file_row_2) | ||
|
||
def test_save_object(): | ||
ds = fake_particle_ds() | ||
sp = ds.sphere(ds.domain_center, 0.25) | ||
sp.save_object("my_sphere_1", filename="test_save_obj") | ||
obj = shelve.open("test_save_obj", protocol=-1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be wrapped in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had done it earlier using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, fair enough! |
||
loaded_sphere = obj["my_sphere_1"][1] | ||
assert_array_equal(loaded_sphere.center, sp.center) | ||
assert_equal(loaded_sphere.radius, sp.radius) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe also test field access? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
obj.close() | ||
|
||
sp.save_object("my_sphere_2") | ||
|
||
def test_to_dataframe(): | ||
try: | ||
import pandas as pd | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than wrapping the test logic in a try/except, please use You probably also need to install pandas on travis and appveyor to get this test to actually execute. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
pd # to ignore Flake8 error | ||
fields = ["density", "velocity_z"] | ||
ds = fake_random_ds(6) | ||
dd = ds.all_data() | ||
df1 = dd.to_dataframe(fields) | ||
assert_array_equal(dd[fields[0]], df1[fields[0]]) | ||
assert_array_equal(dd[fields[1]], df1[fields[1]]) | ||
except ImportError: | ||
pass | ||
|
||
def test_std(): | ||
ds = fake_random_ds(3) | ||
ds.all_data().std('density', weight="velocity_z") | ||
|
||
def test_to_frb(): | ||
ds = fake_amr_ds(fields=["density", "cell_mass"], geometry="cylindrical", | ||
particles=16**3) | ||
proj = ds.proj("density", weight_field="cell_mass", axis=1, | ||
data_source=ds.all_data()) | ||
proj.to_frb((1.0, 'unitary'), 64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you test that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case of |
||
|
||
def test_extract_isocontours(): | ||
# Test isocontour properties for AMRGridData | ||
ds = fake_amr_ds(fields=["density", "cell_mass"], particles=16**3) | ||
dd = ds.all_data() | ||
rho = dd.quantities["WeightedAverageQuantity"]("density", | ||
weight="cell_mass") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd format these lines like:
Often when you find yourself hitting a line length limit you can get around it by defining an intermediate variable. This has the side benefit of often making the code more understandable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
dd.extract_isocontours("density", rho, "triangles.obj", True) | ||
dd.calculate_isocontour_flux("density", rho, "x", "y", "z", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This formatting choice seems odd (the complete line is shorter than the lines above). Are you using an auto-formatter like yapf or eyeballing it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd format these two lines like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
"dx") | ||
|
||
# Test error in case of ParticleData | ||
ds = fake_particle_ds() | ||
dd = ds.all_data() | ||
rho = dd.quantities["WeightedAverageQuantity"]("particle_velocity_x", | ||
weight="particle_mass") | ||
with assert_raises(NotImplementedError): | ||
dd.extract_isocontours("density", rho, sample_values='x') | ||
|
||
def test_extract_connected_sets(): | ||
|
||
ds = fake_random_ds(16, nprocs=8, particles=16 ** 3) | ||
data_source = ds.disk([0.5, 0.5, 0.5], [0., 0., 1.], (8, 'kpc'), (1, 'kpc')) | ||
field = ("gas", "density") | ||
min_val, max_val = data_source[field].min() / 2, data_source[field].max() / 2 | ||
|
||
data_source.extract_connected_sets(field, 3, min_val, max_val, | ||
log_space=True, cumulative=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you test that the data returned by this function are correct? Also the formatting right here is a little odd, this line should be aligned with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing this test as it is already present in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd format this as:
_setup_data_source
doesn't return anything, and this way of calling it makes it clearer that that's the case.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done