Skip to content

Commit

Permalink
230 arima failing (#231)
Browse files Browse the repository at this point in the history
* Alternative method to annoy ARIMA #230
  • Loading branch information
IanGrimstead authored and thanasions committed Apr 5, 2019
1 parent f73bafc commit c3d2bb7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
16 changes: 10 additions & 6 deletions scripts/algorithms/arima.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings

import numpy as np
from numpy import clip, inf
from sklearn.metrics import mean_squared_error
Expand Down Expand Up @@ -44,12 +45,15 @@ def __evaluate_arima_model(self, X, arima_order, ground_truth_in_history=False):
def __arima_model_predict(self, X, arima_order, steps_ahead):
# make predictions
predictions = list()
for t in range(steps_ahead):
model = ARIMA(X, order=arima_order)
model_fit = model.fit(disp=0)
yhat = model_fit.forecast()[0][0]
predictions.append(yhat)
X= np.append(X, yhat)
try:
for t in range(steps_ahead):
model = ARIMA(X, order=arima_order)
model_fit = model.fit(disp=0)
yhat = model_fit.forecast()[0][0]
predictions.append(yhat)
X = np.append(X, yhat)
except:
predictions.extend([np.nan] * (steps_ahead - len(predictions)))

return predictions

Expand Down
20 changes: 20 additions & 0 deletions tests/algorithms/test_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ def test_static_sequence(self):

np_test.assert_almost_equal(actual_prediction, expected_prediction, decimal=4)

def test_linear_sequence(self):
time_series = [1.0, 2.0, 3.0, 4.0, 5.0]
num_predicted_periods = 3
expected_prediction = [6.0, 7.0, 8.0]
arima = ARIMAForecast(time_series, num_predicted_periods)

actual_prediction = arima.predict_counts()

np_test.assert_almost_equal(actual_prediction, expected_prediction, decimal=4)

def test_flakey_sequence(self):
time_series = [20.0, -20.0]
num_predicted_periods = 3
expected_prediction = [np.nan] * 3
arima = ARIMAForecast(time_series, num_predicted_periods)

actual_prediction = arima.predict_counts()

np_test.assert_almost_equal(actual_prediction, expected_prediction, decimal=1)

def test_linearly_increasing_sequence_fuel_cell(self):
time_series = pd.read_csv(os.path.join('tests','data', 'fuel_cell_quarterly.csv')).values.tolist()
time_series = [item for sublist in time_series for item in sublist]
Expand Down

0 comments on commit c3d2bb7

Please sign in to comment.