Skip to content

Commit

Permalink
fix inverse transform
Browse files Browse the repository at this point in the history
  • Loading branch information
sashhhaka committed Jul 25, 2024
1 parent 69bbdab commit c4ce822
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def predict():
predictions = model.predict(df)
print("Predictions:", predictions)

predictions = scaler.inverse_transform(predictions.reshape(-1, 1))

print("Predictions conversed:", predictions)
# Return predictions as JSON
return jsonify({"predictions": predictions.tolist()})

Expand Down
30 changes: 22 additions & 8 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
import requests
import hydra
from data import read_features
import zenml
from hydra import compose, initialize
import numpy as np


with initialize(config_path="../configs", version_base=None):
cfg = compose(config_name="main")


column_name = cfg.data.target_cols[0]
scaler = zenml.load_artifact(f"{column_name}_scaler")


# type: ignore
Expand All @@ -14,7 +25,6 @@ def predict(cfg=None):

example = X.iloc[0, :]
example_target = y[0]

example = json.dumps({
"inputs": example.to_dict()
})
Expand All @@ -27,14 +37,18 @@ def predict(cfg=None):
headers={"Content-Type": "application/json"},
)

# print(predictions)
# predictions = scaler.inverse_transform(predictions.reshape(-1, 1))

# print("Predictions:", predictions)

print(response.json())
print("encoded target labels: ", example_target)
# print("target labels: ", list(cfg.data.labels)[example_target])

# take the value from {'predictions': [-0.467059463262558]}
prediction = response.json()["predictions"]
print("encoded prediction:", prediction)

prediction = np.array(prediction)
prediction = scaler.inverse_transform(prediction.reshape(-1, 1))
print("decoded prediction:", prediction)

print("encoded target labels:", example_target)
print("decoded target labels:", scaler.inverse_transform(np.array([example_target]).reshape(-1, 1)))


if __name__ == "__main__":
Expand Down

0 comments on commit c4ce822

Please sign in to comment.