Skip to content

Commit

Permalink
Session.virtualfile_to_dataset: Add 'strings' output type for the arr…
Browse files Browse the repository at this point in the history
…ay of trailing texts
  • Loading branch information
seisman committed Apr 3, 2024
1 parent d35741c commit d9d2f9b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
29 changes: 23 additions & 6 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,7 +1775,7 @@ def read_virtualfile(
def virtualfile_to_dataset(
self,
vfname: str,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
column_names: list[str] | None = None,
dtype: type | dict[str, type] | None = None,
index_col: str | int | None = None,
Expand All @@ -1796,6 +1796,7 @@ def virtualfile_to_dataset(
- ``"pandas"`` will return a :class:`pandas.DataFrame` object.
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
- ``"file"`` means the result was saved to a file and will return ``None``.
- ``"strings"`` will return the trailing text only as a list of strings.
column_names
The column names for the :class:`pandas.DataFrame` output.
dtype
Expand Down Expand Up @@ -1841,6 +1842,16 @@ def virtualfile_to_dataset(
... assert result is None
... assert Path(outtmp.name).stat().st_size > 0
...
... # strings output
... with Session() as lib:
... with lib.virtualfile_out(kind="dataset") as vouttbl:
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... outstr = lib.virtualfile_to_dataset(
... vfname=vouttbl, output_type="strings"
... )
... assert isinstance(outstr, np.ndarray)
... assert outstr.dtype.kind in ("S", "U")
...
... # numpy output
... with Session() as lib:
... with lib.virtualfile_out(kind="dataset") as vouttbl:
Expand Down Expand Up @@ -1869,6 +1880,9 @@ def virtualfile_to_dataset(
... column_names=["col1", "col2", "col3", "coltext"],
... )
... assert isinstance(outpd2, pd.DataFrame)
>>> outstr
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> outnp
array([[1.0, 2.0, 3.0, 'TEXT1 TEXT23'],
[4.0, 5.0, 6.0, 'TEXT4 TEXT567'],
Expand All @@ -1890,11 +1904,14 @@ def virtualfile_to_dataset(
if output_type == "file": # Already written to file, so return None
return None

# Read the virtual file as a GMT dataset and convert to pandas.DataFrame
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
column_names=column_names,
dtype=dtype,
index_col=index_col,
# Read the virtual file as a GMT dataset
result = self.read_virtualfile(vfname, kind="dataset").contents

if output_type == "strings": # strings output
return result.to_strings()

result = result.to_dataframe(
column_names=column_names, dtype=dtype, index_col=index_col
)
if output_type == "numpy": # numpy.ndarray output
return result.to_numpy()
Expand Down
33 changes: 22 additions & 11 deletions pygmt/datatypes/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801
("hidden", ctp.c_void_p),
]

def to_strings(self) -> np.ndarray:
"""
Convert the trailing text column to an array of strings.
"""
textvector = []
for itbl in range(self.n_tables):
dtbl = self.table[itbl].contents
for iseg in range(dtbl.n_segments):
dseg = dtbl.segment[iseg].contents
if dseg.text:
textvector.extend(dseg.text[: dseg.n_rows])
if textvector:
textvector = np.char.decode(textvector)
return textvector

def to_dataframe(
self,
column_names: pd.Index | None = None,
Expand Down Expand Up @@ -194,7 +209,11 @@ def to_dataframe(
... with lib.virtualfile_out(kind="dataset") as vouttbl:
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
... text = ds.contents.to_strings()
... df = ds.contents.to_dataframe()
>>> text
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> df
0 1 2 3
0 1.0 2.0 3.0 TEXT1 TEXT23
Expand All @@ -218,17 +237,9 @@ def to_dataframe(
vectors.append(pd.Series(data=np.concatenate(colvector)))

# Deal with trailing text column
textvector = []
for itbl in range(self.n_tables):
dtbl = self.table[itbl].contents
for iseg in range(dtbl.n_segments):
dseg = dtbl.segment[iseg].contents
if dseg.text:
textvector.extend(dseg.text[: dseg.n_rows])
if textvector:
vectors.append(
pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
)
textvector = self.to_strings()
if len(textvector) != 0:
vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))

if len(vectors) == 0:
# Return an empty DataFrame if no columns are found.
Expand Down

0 comments on commit d9d2f9b

Please sign in to comment.