Skip to content

Commit

Permalink
add multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
ninsbl committed Nov 3, 2022
1 parent 8dd3ce7 commit d4abf62
Showing 1 changed file with 110 additions and 70 deletions.
180 changes: 110 additions & 70 deletions python/grass/temporal/univar_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,89 @@
:authors: Soeren Gebbert
"""
from __future__ import print_function
from multiprocessing import Pool
from copy import deepcopy
from subprocess import PIPE

from .core import SQLDatabaseInterfaceConnection, get_current_mapset
from .factory import dataset_factory
from .open_stds import open_old_stds
import grass.script as gscript
import grass.script as gs
from grass.pygrass.modules import Module

###############################################################################


def compute_univar_stats(row, stats_module, fs, rast_region=False):
    """Compute univariate statistics for a single map of a space time
    raster or raster3d dataset.

    :param row: Dict-like map-list row with at least the keys
                "id", "start_time", "end_time" and (for raster maps)
                "semantic_label"
    :param stats_module: Pre-configured PyGRASS Module ("r.univar" or
                         "r3.univar") used to compute the statistics
    :param fs: Field separator placed between output columns
    :param rast_region: If True, ignore the current region settings
                        and use the raster map's own region for the
                        univar statistical calculation.
                        Only available for strds. Default: False
    :return: Field-separated statistics string (one line per zone when
             the module is configured with zones), or None if no
             statistics could be computed
    """
    string = ""
    map_id = row["id"]
    start = row["start_time"]
    end = row["end_time"]
    # 3D raster maps (r3.univar) do not carry semantic labels
    semantic_label = (
        ""
        if stats_module.name == "r3.univar" or not row["semantic_label"]
        else row["semantic_label"]
    )

    stats_module.inputs.map = map_id
    if rast_region:
        # Run the module with the map's own region instead of the
        # current computational region
        stats_module.env = gs.region_env(raster=map_id)
    stats_module.run()

    univar_stats = stats_module.outputs.stdout

    if not univar_stats:
        # Translate the message template first, then substitute values,
        # so the untranslated template matches the translation catalog
        gs.warning(
            _("Unable to get statistics for {voxel}raster map <{rmap}>").format(
                rmap=map_id, voxel="" if stats_module.name == "r.univar" else "3d "
            )
        )
        return None
    eol = ""

    # With zones configured, the module output contains one
    # ";"-separated section per zone
    for idx, stats_kv in enumerate(univar_stats.split(";")):
        stats = gs.utils.parse_key_val(stats_kv)
        string += (
            f"{map_id}{fs}{semantic_label}{fs}{start}{fs}{end}"
            if stats_module.name == "r.univar"
            else f"{map_id}{fs}{start}{fs}{end}"
        )
        if stats_module.inputs.zones:
            if idx == 0:
                # The leading section only identifies the zone map and
                # carries no statistics; remember the zone id, discard
                # the partial line and move on
                zone = str(stats["zone"])
                string = ""
                continue
            string += f"{fs}{zone}"
            if "zone" in stats:
                # This section ends with the id of the next zone;
                # remember it and terminate the current output line
                zone = str(stats["zone"])
                eol = "\n"
            else:
                eol = ""
        string += f'{fs}{stats["mean"]}{fs}{stats["min"]}'
        string += f'{fs}{stats["max"]}{fs}{stats["mean_of_abs"]}'
        string += f'{fs}{stats["stddev"]}{fs}{stats["variance"]}'
        string += f'{fs}{stats["coeff_var"]}{fs}{stats["sum"]}'
        string += f'{fs}{stats["null_cells"]}{fs}{stats["n"]}'
        string += f'{fs}{stats["n"]}'
        # Extended statistics keys are present only when the module
        # ran with the "e" flag
        if "median" in stats:
            string += f'{fs}{stats["first_quartile"]}{fs}{stats["median"]}'
            string += f'{fs}{stats["third_quartile"]}{fs}{stats["percentile_90"]}'
        string += eol
    return string


def print_gridded_dataset_univar_statistics(
type,
input,
Expand All @@ -38,6 +112,7 @@ def print_gridded_dataset_univar_statistics(
fs="|",
rast_region=False,
zones=None,
nprocs=1,
):
"""Print univariate statistics for a space time raster or raster3d dataset
Expand All @@ -48,17 +123,12 @@ def print_gridded_dataset_univar_statistics(
:param extended: If True compute extended statistics
:param no_header: Suppress the printing of column names
:param fs: Field separator
:param nprocs: Number of cores to use for processing
:param rast_region: If set True ignore the current region settings
and use the raster map regions for univar statistical calculation.
Only available for strds.
:param zones: raster map with zones to calculate statistics for
"""

stats_module = {
"strds": "r.univar",
"str3ds": "r3.univar",
}[type]

# We need a database interface
dbif = SQLDatabaseInterfaceConnection()
dbif.connect()
Expand All @@ -80,7 +150,7 @@ def print_gridded_dataset_univar_statistics(
err = "Space time %(sp)s dataset <%(i)s> is empty"
if where:
err += " or where condition is wrong"
gscript.fatal(
gs.fatal(
_(err) % {"sp": sp.get_new_map_instance(None).get_type(), "i": sp.get_id()}
)

Expand Down Expand Up @@ -116,73 +186,43 @@ def print_gridded_dataset_univar_statistics(
else:
out_file.write(string + "\n")

# Define flags
flag = "g"

if extended is True:
flag += "e"
if type == "strds" and rast_region is True:
flag += "r"

for row in rows:
string = ""
id = row["id"]
start = row["start_time"]
end = row["end_time"]
semantic_label = (
""
if type != "strds" or not row["semantic_label"]
else row["semantic_label"]
)

univar_stats = gscript.read_command(
stats_module, map=id, flags=flag, zones=zones
).rstrip()

if not univar_stats:
if type == "strds":
gscript.warning(
_("Unable to get statistics for raster map " "<%s>") % id
)
elif type == "str3ds":
gscript.warning(
_("Unable to get statistics for 3d raster map" " <%s>") % id
)
continue
eol = ""
# Setup pygrass module to use for computation
univar_module = Module(
"r.univar" if type == "strds" else "r3.univar",
flags=flag,
zones=zones,
stdout_=PIPE,
run_=False,
)

for idx, stats_kv in enumerate(univar_stats.split(";")):
stats = gscript.utils.parse_key_val(stats_kv)
string += (
f"{id}{fs}{semantic_label}{fs}{start}{fs}{end}"
if type == "strds"
else f"{id}{fs}{start}{fs}{end}"
if nprocs == 1:
strings = [
compute_univar_stats(
row,
univar_module,
fs,
)
if zones:
if idx == 0:
zone = str(stats["zone"])
string = ""
continue
string += f"{fs}{zone}"
if "zone" in stats:
zone = str(stats["zone"])
eol = "\n"
else:
eol = ""
string += f'{fs}{stats["mean"]}{fs}{stats["min"]}'
string += f'{fs}{stats["max"]}{fs}{stats["mean_of_abs"]}'
string += f'{fs}{stats["stddev"]}{fs}{stats["variance"]}'
string += f'{fs}{stats["coeff_var"]}{fs}{stats["sum"]}'
string += f'{fs}{stats["null_cells"]}{fs}{stats["n"]}'
string += f'{fs}{stats["n"]}'
if extended is True:
string += f'{fs}{stats["first_quartile"]}{fs}{stats["median"]}'
string += f'{fs}{stats["third_quartile"]}{fs}{stats["percentile_90"]}'
string += eol
for row in rows
]
else:
with Pool(min(nprocs, len(rows))) as pool:
strings = pool.starmap(
compute_univar_stats, [(dict(row), univar_module, fs) for row in rows]
)
pool.close()
pool.join()

if output is None:
print(string)
else:
out_file.write(string + "\n")
if output is None:
print("\n".join(filter(None, strings)))
else:
out_file.write("\n".join(filter(None, strings)))

dbif.close()

Expand Down Expand Up @@ -229,7 +269,7 @@ def print_vector_dataset_univar_statistics(

if sp.is_in_db(dbif) is False:
dbif.close()
gscript.fatal(
gs.fatal(
_("Space time %(sp)s dataset <%(i)s> not found")
% {"sp": sp.get_new_map_instance(None).get_type(), "i": id}
)
Expand All @@ -242,7 +282,7 @@ def print_vector_dataset_univar_statistics(

if not rows:
dbif.close()
gscript.fatal(
gs.fatal(
_("Space time %(sp)s dataset <%(i)s> is empty")
% {"sp": sp.get_new_map_instance(None).get_type(), "i": id}
)
Expand Down Expand Up @@ -316,7 +356,7 @@ def print_vector_dataset_univar_statistics(
if not mylayer:
mylayer = layer

stats = gscript.parse_command(
stats = gs.parse_command(
"v.univar",
map=id,
where=where,
Expand All @@ -329,7 +369,7 @@ def print_vector_dataset_univar_statistics(
string = ""

if not stats:
gscript.warning(_("Unable to get statistics for vector map <%s>") % id)
gs.warning(_("Unable to get statistics for vector map <%s>") % id)
continue

string += str(id) + fs + str(start) + fs + str(end)
Expand Down

0 comments on commit d4abf62

Please sign in to comment.