diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 5ca8b411..27ac5e32 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -226,6 +226,7 @@ def predict_xr( chunk_size=None, persist=False, proba=False, + max_proba=True, clean=False, return_input=False, ): @@ -255,6 +256,11 @@ def predict_xr( distributed RAM. proba : bool If True, predict probabilities + max_proba : bool + If True, the probabilities array will be flattened to contain + only the probabiltiy for the "Predictions" class. If False, + the "Probabilities" object will be an array of prediction + probaiblities for each classes clean : bool If True, remove Infs and NaNs from input and output arrays return_input : bool @@ -282,7 +288,7 @@ def predict_xr( input_xr.chunks["y"][0] ) - def _predict_func(model, input_xr, persist, proba, clean, return_input): + def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_input): x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs input_data = [] @@ -330,18 +336,35 @@ def _predict_func(model, input_xr, persist, proba, clean, return_input): print(" probabilities...") out_proba = model.predict_proba(input_data_flattened) - # convert to % - out_proba = da.max(out_proba, axis=1) * 100.0 + # return either one band with the max probability, or the whole probability array + if max_proba == True: + print(" returning single probability band.") + out_proba = da.max(out_proba, axis=1) * 100.0 + out_proba = out_proba.reshape(len(y), len(x)) + out_proba = xr.DataArray( + out_proba, coords={"x": x, "y": y}, dims=["y", "x"] + ) + output_xr["Probabilities"] = out_proba + else: + print(" returning class probability array.") + out_proba = out_proba * 100.0 + class_names = model.classes_ # Get the unique class names from the fitted classifier + + # Loop through each class (band) + probabilities_dataset = xr.Dataset() + for i, class_name in enumerate(class_names): + reshaped_band = out_proba[:, i].reshape(len(y), len(x)) + reshaped_da = xr.DataArray( + reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"] + ) + probabilities_dataset[f"prob_{class_name}"] = reshaped_da + # merge in the probabilities + output_xr = xr.merge([output_xr, probabilities_dataset]) + if clean == True: out_proba = da.where(da.isfinite(out_proba), out_proba, 0) - - out_proba = out_proba.reshape(len(y), len(x)) - - out_proba = xr.DataArray( - out_proba, coords={"x": x, "y": y}, dims=["y", "x"] - ) - output_xr["Probabilities"] = out_proba + if return_input == True: print(" input features...") @@ -391,12 +414,12 @@ def _predict_func(model, input_xr, persist, proba, clean, return_input): model = ParallelPostFit(model) with joblib.parallel_backend("dask"): output_xr = _predict_func( - model, input_xr, persist, proba, clean, return_input + model, input_xr, persist, proba, max_proba, clean, return_input ) else: output_xr = _predict_func( - model, input_xr, persist, proba, clean, return_input + model, input_xr, persist, proba, max_proba, clean, return_input ).compute() return output_xr