diff --git a/pygmt/modules.py b/pygmt/modules.py index 7339241002a..94512ec325d 100644 --- a/pygmt/modules.py +++ b/pygmt/modules.py @@ -1,6 +1,7 @@ """ Non-plot GMT modules. """ +import numpy as np import xarray as xr from .clib import Session @@ -55,7 +56,7 @@ def grdinfo(grid, **kwargs): @fmt_docstring @use_alias(C="per_column", I="spacing", T="nearest_multiple") -def info(fname, **kwargs): +def info(table, **kwargs): """ Get information about data tables. @@ -74,8 +75,9 @@ def info(fname, **kwargs): Parameters ---------- - fname : str - The file name of the input data table file. + table : pandas.DataFrame or np.ndarray or str + Either a pandas dataframe, a 1D/2D numpy.ndarray or a file name to an + ASCII data table. per_column : bool Report the min/max values per column in separate columns. spacing : str @@ -88,14 +90,25 @@ def info(fname, **kwargs): Report the min/max of the first (0'th) column to the nearest multiple of dz and output this as the string *-Tzmin/zmax/dz*. """ - if not isinstance(fname, str): - raise GMTInvalidInput("'info' only accepts file names.") + kind = data_kind(table) + with Session() as lib: + if kind == "file": + file_context = dummy_context(table) + elif kind == "matrix": + _table = np.asanyarray(table) + if table.ndim == 1: # 1D arrays need to be 2D and transposed + _table = np.transpose(np.atleast_2d(_table)) + file_context = lib.virtualfile_from_matrix(_table) + else: + raise GMTInvalidInput(f"Unrecognized data type: {type(table)}") - with GMTTempFile() as tmpfile: - arg_str = " ".join([fname, build_arg_string(kwargs), "->" + tmpfile.name]) - with Session() as lib: - lib.call_module("info", arg_str) - return tmpfile.read() + with GMTTempFile() as tmpfile: + with file_context as fname: + arg_str = " ".join( + [fname, build_arg_string(kwargs), "->" + tmpfile.name] + ) + lib.call_module("info", arg_str) + return tmpfile.read() @fmt_docstring diff --git a/pygmt/tests/test_info.py b/pygmt/tests/test_info.py index 3e9da3abf81..b7eadc53649 100644 --- a/pygmt/tests/test_info.py +++ b/pygmt/tests/test_info.py @@ -4,7 +4,9 @@ import os import numpy as np +import pandas as pd import pytest +import xarray as xr from .. import info from ..exceptions import GMTInvalidInput @@ -14,8 +16,8 @@ def test_info(): - "Make sure info works" - output = info(fname=POINTS_DATA) + "Make sure info works on file name inputs" + output = info(table=POINTS_DATA) expected_output = ( f"{POINTS_DATA}: N = 20 " "<11.5309/61.7074> " @@ -25,33 +27,61 @@ def test_info(): assert output == expected_output +def test_info_dataframe(): + "Make sure info works on pandas.DataFrame inputs" + table = pd.read_csv(POINTS_DATA, sep=" ", header=None) + output = info(table=table) + expected_output = ( + ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" + ) + assert output == expected_output + + +def test_info_2d_array(): + "Make sure info works on 2D numpy.ndarray inputs" + table = np.loadtxt(POINTS_DATA) + output = info(table=table) + expected_output = ( + ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" + ) + assert output == expected_output + + +def test_info_1d_array(): + "Make sure info works on 1D numpy.ndarray inputs" + output = info(table=np.arange(20)) + expected_output = ": N = 20 <0/19>\n" + assert output == expected_output + + def test_info_per_column(): "Make sure the per_column option works" - output = info(fname=POINTS_DATA, per_column=True) + output = info(table=POINTS_DATA, per_column=True) assert output == "11.5309 61.7074 -2.9289 7.8648 0.1412 0.9338\n" def test_info_spacing(): "Make sure the spacing option works" - output = info(fname=POINTS_DATA, spacing=0.1) + output = info(table=POINTS_DATA, spacing=0.1) assert output == "-R11.5/61.8/-3/7.9\n" def test_info_per_column_spacing(): "Make sure the per_column and spacing options work together" - output = info(fname=POINTS_DATA, per_column=True, spacing=0.1) + output = info(table=POINTS_DATA, per_column=True, spacing=0.1) assert output == "11.5 61.8 -3 7.9 0.1412 0.9338\n" def test_info_nearest_multiple(): "Make sure the nearest_multiple option works" - output = info(fname=POINTS_DATA, nearest_multiple=0.1) + output = info(table=POINTS_DATA, nearest_multiple=0.1) assert output == "-T11.5/61.8/0.1\n" def test_info_fails(): - "Make sure info raises an exception if not given a file name" - with pytest.raises(GMTInvalidInput): - info(fname=21) + """ + Make sure info raises an exception if not given either a file name, pandas + DataFrame, or numpy ndarray + """ with pytest.raises(GMTInvalidInput): - info(fname=np.arange(20)) + info(table=xr.DataArray(21))