Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor meca to use virtualfile_from_data #1613

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 46 additions & 118 deletions pygmt/src/meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,7 @@
import pandas as pd
from pygmt.clib import Session
from pygmt.exceptions import GMTError, GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias


def data_format_code(convention, component="full"):
Expand Down Expand Up @@ -136,7 +129,7 @@ def meca(

Parameters
----------
spec: dict, 1D array, 2D array, pd.DataFrame, or str
spec : str or dict or numpy.ndarray or pandas.DataFrame
Either a filename containing focal mechanism parameters as columns, a
1- or 2-D array with the same, or a dictionary. If a filename or array,
`convention` is required so we know how to interpret the
Expand Down Expand Up @@ -247,6 +240,29 @@ def update_pointers(data_pointers):
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 237-241 are unnecessary because line 236 already checks spec is dict or pd.DataFrame. Am I right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might be right, let me recheck the logic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@weiji14 Please see my new PR #1784.

raise GMTError("Location not fully specified.")

# check the inputs for longitude, latitude, and depth
# just in case the user entered different length lists
if (
isinstance(longitude, (list, np.ndarray))
or isinstance(latitude, (list, np.ndarray))
or isinstance(depth, (list, np.ndarray))
):
if (len(longitude) != len(latitude)) or (len(longitude) != len(depth)):
raise GMTError("Unequal number of focal mechanism locations supplied.")

if isinstance(spec, dict) and any(
isinstance(s, (list, np.ndarray)) for s in spec.values()
):
# before constructing the 2D array lets check that each key
# of the dict has the same quantity of values to avoid bugs
list_length = len(list(spec.values())[0])
for value in list(spec.values()):
if len(value) != list_length:
raise GMTError(
"Unequal number of focal mechanism "
"parameters supplied in 'spec'."
)

param_conventions = {
"AKI": ["strike", "dip", "rake", "magnitude"],
"GCMT": [
Expand Down Expand Up @@ -313,95 +329,20 @@ def update_pointers(data_pointers):

# create a dict type pointer for easier to read code
if isinstance(spec, dict):
dict_type_pointer = list(spec.values())[0]
elif isinstance(spec, pd.DataFrame):
# use df.values as pointer for DataFrame behavior
dict_type_pointer = spec.values

# assemble the 1D array for the case of floats and ints as values
if isinstance(dict_type_pointer, (int, float)):
# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
(
longitude,
latitude,
depth,
plot_longitude,
plot_latitude,
) = update_pointers(data_pointers)

# Construct the array (order matters)
spec = [longitude, latitude, depth] + [spec[key] for key in foc_params]

# Add in plotting options, if given, otherwise add 0s
for arg in plot_longitude, plot_latitude:
if arg is None:
spec.append(0)
else:
if "A" not in kwargs:
kwargs["A"] = True
spec.append(arg)

# or assemble the 2D array for the case of lists as values
elif isinstance(dict_type_pointer, list):
# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
(
longitude,
latitude,
depth,
plot_longitude,
plot_latitude,
) = update_pointers(data_pointers)

# before constructing the 2D array lets check that each key
# of the dict has the same quantity of values to avoid bugs
list_length = len(list(spec.values())[0])
for value in list(spec.values()):
if len(value) != list_length:
raise GMTError(
"Unequal number of focal mechanism "
"parameters supplied in 'spec'."
)
# lets also check the inputs for longitude, latitude,
# and depth if it is a list or array
if (
isinstance(longitude, (list, np.ndarray))
or isinstance(latitude, (list, np.ndarray))
or isinstance(depth, (list, np.ndarray))
):
if (len(longitude) != len(latitude)) or (
len(longitude) != len(depth)
):
raise GMTError(
"Unequal number of focal mechanism " "locations supplied."
)

# values are ok, so build the 2D array
spec_array = []
for index in range(list_length):
# Construct the array one row at a time (note that order
# matters here, hence the list comprehension!)
row = [longitude[index], latitude[index], depth[index]] + [
spec[key][index] for key in foc_params
]

# Add in plotting options, if given, otherwise add 0s as
# required by GMT
for arg in plot_longitude, plot_latitude:
if arg is None:
row.append(0)
else:
if "A" not in kwargs:
kwargs["A"] = True
row.append(arg[index])
spec_array.append(row)
spec = spec_array

# or assemble the array for the case of pd.DataFrames
elif isinstance(dict_type_pointer, np.ndarray):
# Convert single int, float data to List[int, float] data
_spec = {
"longitude": np.atleast_1d(longitude),
"latitude": np.atleast_1d(latitude),
"depth": np.atleast_1d(depth),
}
_spec.update({key: np.atleast_1d(val) for key, val in spec.items()})
spec = pd.DataFrame.from_dict(_spec)

assert isinstance(spec, pd.DataFrame)
dict_type_pointer = spec.values

# Assemble the array for the case of pd.DataFrames
if isinstance(dict_type_pointer, np.ndarray):
Comment on lines -403 to +345
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to explain this change (clearer if you look at commit e223057), previously the spec could be formatted in 3 different paths depending on the type:

  1. if isinstance(dict_type_pointer, (int, float)): (i.e. a dictionary of int/float like {"depth": 0})
  2. elif isinstance(dict_type_pointer, list): (i.e. a dictionary of lists like {"depth": [0, 1, 2]})
  3. elif isinstance(dict_type_pointer, np.ndarray): (i.e. a pandas.DataFrame with np.ndarray columns)

There was a lot of duplication in these 3 code paths, so I've removed paths 1 and 2 by converting any spec provided by the user in dictionary format into pandas.DataFrame. So now, everything happens via path 3 only.

Another benefit of handling everything using pandas.DataFrame is that columns can have different data types, so this sets things up for the ability to pass in string labels later in a follow-up PR.

# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
Expand All @@ -413,19 +354,7 @@ def update_pointers(data_pointers):
plot_latitude,
) = update_pointers(data_pointers)

# lets also check the inputs for longitude, latitude, and depth
# just in case the user entered different length lists
if (
isinstance(longitude, (list, np.ndarray))
or isinstance(latitude, (list, np.ndarray))
or isinstance(depth, (list, np.ndarray))
):
if (len(longitude) != len(latitude)) or (len(longitude) != len(depth)):
raise GMTError(
"Unequal number of focal mechanism locations supplied."
)

# values are ok, so build the 2D array in the correct order
# build the 2D array in the correct order
spec_array = []
for index in range(len(spec)):
# Construct the array one row at a time (note that order
Expand All @@ -449,20 +378,19 @@ def update_pointers(data_pointers):
else:
raise GMTError("Parameter 'spec' contains values of an unsupported type.")

# Convert 1d array types into 2d arrays
if isinstance(spec, np.ndarray) and spec.ndim == 1:
spec = np.atleast_2d(spec)

# determine data_format from convention and component
data_format = data_format_code(convention=convention, component=component)

# Assemble -S flag
kwargs["S"] = data_format + scale

kind = data_kind(spec)
with Session() as lib:
if kind == "matrix":
file_context = lib.virtualfile_from_matrix(np.atleast_2d(spec))
elif kind == "file":
file_context = dummy_context(spec)
else:
raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}")
# Choose how data will be passed into the module
file_context = lib.virtualfile_from_data(check_kind="vector", data=spec)
with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
lib.call_module("meca", arg_str)
46 changes: 41 additions & 5 deletions pygmt/tests/test_meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pytest
from pygmt import Figure
from pygmt.exceptions import GMTError
from pygmt.helpers import GMTTempFile


Expand All @@ -17,7 +18,7 @@ def test_meca_spec_dictionary():
fig = Figure()
# Right lateral strike slip focal mechanism
fig.meca(
dict(strike=0, dip=90, rake=0, magnitude=5),
spec=dict(strike=0, dip=90, rake=0, magnitude=5),
longitude=0,
latitude=5,
depth=0,
Expand All @@ -41,7 +42,7 @@ def test_meca_spec_dict_list():
strike=[330, 350], dip=[30, 50], rake=[90, 90], magnitude=[3, 2]
)
fig.meca(
focal_mechanisms,
spec=focal_mechanisms,
longitude=[-124.3, -124.4],
latitude=[48.1, 48.2],
depth=[12.0, 11.0],
Expand All @@ -52,6 +53,39 @@ def test_meca_spec_dict_list():
return fig


def test_meca_spec_unequal_sized_lists_fails():
    """
    Check that meca raises GMTError when a dictionary spec is combined with
    lists of mismatched length, either among the coordinates
    (longitude/latitude/depth) or among the focal mechanism parameters
    (strike/dip/rake/magnitude).
    """
    fig = Figure()

    # Each entry is one set of keyword arguments that must be rejected.
    failing_inputs = [
        # Unequal sized coordinates (longitude/latitude/depth)
        {
            "spec": {
                "strike": [330, 350],
                "dip": [30, 50],
                "rake": [90, 90],
                "magnitude": [3, 2],
            },
            "longitude": [-124.3],
            "latitude": [48.1, 48.2],
            "depth": [12.0],
        },
        # Unequal sized focal mechanisms (strike/dip/rake/magnitude)
        {
            "spec": {
                "strike": [330],
                "dip": [30, 50],
                "rake": [90],
                "magnitude": [3, 2],
            },
            "longitude": [-124.3, -124.4],
            "latitude": [48.1, 48.2],
            "depth": [12.0, 11.0],
        },
    ]

    # The cases run in the same order as before and share one figure.
    for meca_kwargs in failing_inputs:
        with pytest.raises(GMTError):
            fig.meca(scale="2c", **meca_kwargs)


@pytest.mark.mpl_image_compare
def test_meca_spec_dataframe():
"""
Expand All @@ -72,7 +106,9 @@ def test_meca_spec_dataframe():
depth=[12, 11.0],
)
spec_dataframe = pd.DataFrame(data=focal_mechanisms)
fig.meca(spec_dataframe, region=[-125, -122, 47, 49], scale="2c", projection="M14c")
fig.meca(
spec=spec_dataframe, region=[-125, -122, 47, 49], scale="2c", projection="M14c"
)
return fig


Expand Down Expand Up @@ -104,7 +140,7 @@ def test_meca_spec_1d_array():
]
focal_mech_array = np.asarray(focal_mechanism)
fig.meca(
focal_mech_array,
spec=focal_mech_array,
convention="mt",
component="full",
region=[-128, -127, 40, 41],
Expand Down Expand Up @@ -145,7 +181,7 @@ def test_meca_spec_2d_array():
]
focal_mechs_array = np.asarray(focal_mechanisms)
fig.meca(
focal_mechs_array,
spec=focal_mechs_array,
convention="gcmt",
region=[-128, -127, 40, 41],
scale="2c",
Expand Down