Skip to content

Commit

Permalink
Session.virtualfile_to_dataset: Add 'strings' output type for the arr…
Browse files Browse the repository at this point in the history
…ay of trailing texts (#3157)
  • Loading branch information
seisman authored Apr 6, 2024
1 parent b490b0f commit 6193938
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 25 deletions.
29 changes: 23 additions & 6 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,7 +1775,7 @@ def read_virtualfile(
def virtualfile_to_dataset(
self,
vfname: str,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
column_names: list[str] | None = None,
dtype: type | dict[str, type] | None = None,
index_col: str | int | None = None,
Expand All @@ -1796,6 +1796,7 @@ def virtualfile_to_dataset(
- ``"pandas"`` will return a :class:`pandas.DataFrame` object.
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
- ``"file"`` means the result was saved to a file and will return ``None``.
- ``"strings"`` will return the trailing text only as an array of strings.
column_names
The column names for the :class:`pandas.DataFrame` output.
dtype
Expand Down Expand Up @@ -1841,6 +1842,16 @@ def virtualfile_to_dataset(
... assert result is None
... assert Path(outtmp.name).stat().st_size > 0
...
... # strings output
... with Session() as lib:
... with lib.virtualfile_out(kind="dataset") as vouttbl:
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... outstr = lib.virtualfile_to_dataset(
... vfname=vouttbl, output_type="strings"
... )
... assert isinstance(outstr, np.ndarray)
... assert outstr.dtype.kind in ("S", "U")
...
... # numpy output
... with Session() as lib:
... with lib.virtualfile_out(kind="dataset") as vouttbl:
Expand Down Expand Up @@ -1869,6 +1880,9 @@ def virtualfile_to_dataset(
... column_names=["col1", "col2", "col3", "coltext"],
... )
... assert isinstance(outpd2, pd.DataFrame)
>>> outstr
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> outnp
array([[1.0, 2.0, 3.0, 'TEXT1 TEXT23'],
[4.0, 5.0, 6.0, 'TEXT4 TEXT567'],
Expand All @@ -1890,11 +1904,14 @@ def virtualfile_to_dataset(
if output_type == "file": # Already written to file, so return None
return None

# Read the virtual file as a GMT dataset and convert to pandas.DataFrame
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
column_names=column_names,
dtype=dtype,
index_col=index_col,
# Read the virtual file as a _GMT_DATASET object
result = self.read_virtualfile(vfname, kind="dataset").contents

if output_type == "strings": # strings output
return result.to_strings()

result = result.to_dataframe(
column_names=column_names, dtype=dtype, index_col=index_col
)
if output_type == "numpy": # numpy.ndarray output
return result.to_numpy()
Expand Down
44 changes: 25 additions & 19 deletions pygmt/datatypes/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801
("hidden", ctp.c_void_p),
]

def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]:
"""
Convert the trailing text column to an array of strings.
"""
textvector = []
for table in self.table[: self.n_tables]:
for segment in table.contents.segment[: table.contents.n_segments]:
if segment.contents.text:
textvector.extend(segment.contents.text[: segment.contents.n_rows])
return np.char.decode(textvector) if textvector else np.array([], dtype=str)

def to_dataframe(
self,
column_names: pd.Index | None = None,
Expand Down Expand Up @@ -194,7 +205,11 @@ def to_dataframe(
... with lib.virtualfile_out(kind="dataset") as vouttbl:
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
... text = ds.contents.to_strings()
... df = ds.contents.to_dataframe()
>>> text
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> df
0 1 2 3
0 1.0 2.0 3.0 TEXT1 TEXT23
Expand All @@ -207,28 +222,19 @@ def to_dataframe(
vectors = []
# Deal with numeric columns
for icol in range(self.n_columns):
colvector = []
for itbl in range(self.n_tables):
dtbl = self.table[itbl].contents
for iseg in range(dtbl.n_segments):
dseg = dtbl.segment[iseg].contents
colvector.append(
np.ctypeslib.as_array(dseg.data[icol], shape=(dseg.n_rows,))
)
colvector = [
np.ctypeslib.as_array(
seg.contents.data[icol], shape=(seg.contents.n_rows,)
)
for tbl in self.table[: self.n_tables]
for seg in tbl.contents.segment[: tbl.contents.n_segments]
]
vectors.append(pd.Series(data=np.concatenate(colvector)))

# Deal with trailing text column
textvector = []
for itbl in range(self.n_tables):
dtbl = self.table[itbl].contents
for iseg in range(dtbl.n_segments):
dseg = dtbl.segment[iseg].contents
if dseg.text:
textvector.extend(dseg.text[: dseg.n_rows])
if textvector:
vectors.append(
pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
)
textvector = self.to_strings()
if len(textvector) != 0:
vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))

if len(vectors) == 0:
# Return an empty DataFrame if no columns are found.
Expand Down

0 comments on commit 6193938

Please sign in to comment.