Skip to content

Commit

Permalink
Fix price repair tests, remove unrelated changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ValueRaider committed Jan 29, 2023
1 parent 685f2ec commit a4f11b0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 81 deletions.
67 changes: 45 additions & 22 deletions tests/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,38 @@ def test_weekly_2rows_fix(self):
df = dat.history(start=start, interval="1wk")
self.assertTrue((df.index.weekday == 0).all())

class TestPriceRepair(unittest.TestCase):
session = None

@classmethod
def setUpClass(cls):
cls.session = requests_cache.CachedSession(backend='memory')

@classmethod
def tearDownClass(cls):
if cls.session is not None:
cls.session.close()

def test_reconstruct_2m(self):
# 2m repair requires 1m data.
# Yahoo restricts 1m fetches to 7 days max within last 30 days.
# Need to test that '_reconstruct_intervals_batch()' can handle this.

tkrs = ["BHP.AX", "IMP.JO", "BP.L", "PNL.L", "INTC"]

dt_now = _pd.Timestamp.utcnow()
td_7d = _dt.timedelta(days=7)
td_60d = _dt.timedelta(days=60)

# Round time for 'requests_cache' reuse
dt_now = dt_now.ceil("1h")

for tkr in tkrs:
dat = yf.Ticker(tkr, session=self.session)
end_dt = dt_now
start_dt = end_dt - td_60d
df = dat.history(start=start_dt, end=end_dt, interval="2m", repair=True)

def test_repair_100x_weekly(self):
# Setup:
tkr = "PNL.L"
Expand Down Expand Up @@ -452,38 +484,29 @@ def test_repair_zeroes_hourly(self):
dat = yf.Ticker(tkr, session=self.session)
tz_exchange = dat.info["exchangeTimezoneName"]

df_bad = _pd.DataFrame(data={"Open": [29.68, 29.49, 29.545, _np.nan, 29.485],
"High": [29.68, 29.625, 29.58, _np.nan, 29.49],
"Low": [29.46, 29.4, 29.45, _np.nan, 29.31],
"Close": [29.485, 29.545, 29.485, _np.nan, 29.325],
"Adj Close": [29.485, 29.545, 29.485, _np.nan, 29.325],
"Volume": [3258528, 2140195, 1621010, 0, 0]},
index=_pd.to_datetime([_dt.datetime(2022,11,25, 9,30),
_dt.datetime(2022,11,25, 10,30),
_dt.datetime(2022,11,25, 11,30),
_dt.datetime(2022,11,25, 12,30),
_dt.datetime(2022,11,25, 13,00)]))
df_bad = df_bad.sort_index()
df_bad.index.name = "Date"
df_bad.index = df_bad.index.tz_localize(tz_exchange)
correct_df = dat.history(period="1wk", interval="1h", auto_adjust=False, repair=True)

df_bad = correct_df.copy()
bad_idx = correct_df.index[10]
df_bad.loc[bad_idx, "Open"] = _np.nan
df_bad.loc[bad_idx, "High"] = _np.nan
df_bad.loc[bad_idx, "Low"] = _np.nan
df_bad.loc[bad_idx, "Close"] = _np.nan
df_bad.loc[bad_idx, "Adj Close"] = _np.nan
df_bad.loc[bad_idx, "Volume"] = 0

repaired_df = dat._fix_zeroes(df_bad, "1h", tz_exchange, prepost=False)

correct_df = df_bad.copy()
idx = _pd.Timestamp(2022,11,25, 12,30).tz_localize(tz_exchange)
correct_df.loc[idx, "Open"] = 29.485001
correct_df.loc[idx, "High"] = 29.49
correct_df.loc[idx, "Low"] = 29.43
correct_df.loc[idx, "Close"] = 29.455
correct_df.loc[idx, "Adj Close"] = 29.455
correct_df.loc[idx, "Volume"] = 609164
for c in ["Open", "Low", "High", "Close"]:
try:
self.assertTrue(_np.isclose(repaired_df[c], correct_df[c], rtol=1e-7).all())
except:
print("COLUMN", c)
print("- repaired_df")
print(repaired_df)
print("- correct_df[c]:")
print(correct_df[c])
print("- diff:")
print(repaired_df[c] - correct_df[c])
raise

Expand Down
5 changes: 4 additions & 1 deletion yfinance/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,13 +852,16 @@ def _reconstruct_intervals_batch(self, df, interval, prepost, tag=-1, silent=Fal
f_recent = df.index >= min_dt
f_repair_rows = f_repair_rows & f_recent
if not f_repair_rows.any():
# print("data too old to repair")
if debug:
print("data too old to repair")
return df

dts_to_repair = df.index[f_repair_rows]
indices_to_repair = _np.where(f_repair_rows)[0]

if len(dts_to_repair) == 0:
if debug:
print("dts_to_repair[] is empty")
return df

df_v2 = df.copy()
Expand Down
74 changes: 16 additions & 58 deletions yfinance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}


def TypeCheckSeries(var, varName):
if not isinstance(var, _pd.Series) or isinstance(var, _pd.DataFrame):
raise TypeError(f"'{varName}' must be _pd.Series not {type(var)}")


# From https://stackoverflow.com/a/59128615
from types import FunctionType
from inspect import getmembers
Expand Down Expand Up @@ -485,63 +480,26 @@ def fix_Yahoo_returning_live_separate(quotes, interval, tz_exchange):

if last_rows_same_interval:
# Last two rows are within same interval
ia = quotes.index[n - 2]
ib = quotes.index[n - 1]
quotes.loc[ia] = merge_two_prices_intervals(quotes.loc[ia], quotes.loc[ib])
quotes = quotes.drop(ib)
idx1 = quotes.index[n - 1]
idx2 = quotes.index[n - 2]
if _np.isnan(quotes.loc[idx2, "Open"]):
quotes.loc[idx2, "Open"] = quotes["Open"][n - 1]
# Note: nanmax() & nanmin() ignores NaNs
quotes.loc[idx2, "High"] = _np.nanmax([quotes["High"][n - 1], quotes["High"][n - 2]])
quotes.loc[idx2, "Low"] = _np.nanmin([quotes["Low"][n - 1], quotes["Low"][n - 2]])
quotes.loc[idx2, "Close"] = quotes["Close"][n - 1]
if "Adj High" in quotes.columns:
quotes.loc[idx2, "Adj High"] = _np.nanmax([quotes["Adj High"][n - 1], quotes["Adj High"][n - 2]])
if "Adj Low" in quotes.columns:
quotes.loc[idx2, "Adj Low"] = _np.nanmin([quotes["Adj Low"][n - 1], quotes["Adj Low"][n - 2]])
if "Adj Close" in quotes.columns:
quotes.loc[idx2, "Adj Close"] = quotes["Adj Close"][n - 1]
quotes.loc[idx2, "Volume"] += quotes["Volume"][n - 1]
quotes = quotes.drop(quotes.index[n - 1])

return quotes


def merge_two_prices_intervals(i1, i2):
TypeCheckSeries(i1, "i1")
TypeCheckSeries(i2, "i2")

price_cols = ["Open", "High", "Low", "Close"]
na1 = i1[price_cols].isna().all()
na2 = i2[price_cols].isna().all()
if na1 and na2:
return i1
elif na1:
return i2
elif na2:
return i1

# First check if two intervals are almost identical. If yes, keep 2nd
ratio = _np.mean(i2[price_cols+["Volume"]] / i1[price_cols+["Volume"]])
if ratio > 0.99 and ratio < 1.01:
return i2

m = i1.copy()

if _np.isnan(m["Open"]):
m["Open"] = i2["Open"]
if "Adj Open" in m.index:
m["Adj Open"] = i2["Adj Open"]

# Note: nanmax() & nanmin() ignores NaNs
m["High"] = _np.nanmax([i2["High"], i1["High"]])
m["Low"] = _np.nanmin([i2["Low"], i1["Low"]])
if not _np.isnan(i2["Close"]):
m["Close"] = i2["Close"]

if "Adj High" in m.index:
m["Adj High"] = _np.nanmax([i2["Adj High"], i1["Adj High"]])
if "Adj Low" in m.index:
m["Adj Low"] = _np.nanmin([i2["Adj Low"], i1["Adj Low"]])
if "Adj Close" in m.index:
m["Adj Close"] = i2["Adj Close"]

if _np.isnan(m["Volume"]):
m["Volume"] = i2["Volume"]
elif _np.isnan(i2["Volume"]):
pass
else:
m["Volume"] += i2["Volume"]

return m


def safe_merge_dfs(df_main, df_sub, interval):
# Carefully merge 'df_sub' onto 'df_main'
# If naive merge fails, try again with reindexing df_sub:
Expand Down

0 comments on commit a4f11b0

Please sign in to comment.