Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/tz cache lazy load #1709

Merged
merged 3 commits into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import os


class TestUtils(unittest.TestCase):
session = None

class TestCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.tempCacheDir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -51,9 +49,41 @@ def test_setTzCacheLocation(self):

self.assertTrue(os.path.exists(os.path.join(self.tempCacheDir.name, "tkr-tz.db")))


class TestCacheNoPermission(unittest.TestCase):
@classmethod
def setUpClass(cls):
yf.set_tz_cache_location("/root/yf-cache")

def test_tzCacheRootStore(self):
# Test that if cache path in read-only filesystem, no exception.
tkr = 'AMZN'
tz1 = "America/New_York"

# During attempt to store, will discover cannot write
yf.utils.get_tz_cache().store(tkr, tz1)

# Handling the store failure replaces cache with a dummy
cache = yf.utils.get_tz_cache()
self.assertTrue(cache.dummy)
cache.store(tkr, tz1)

def test_tzCacheRootLookup(self):
# Test that if cache path in read-only filesystem, no exception.
tkr = 'AMZN'
# During attempt to lookup, will discover cannot write
yf.utils.get_tz_cache().lookup(tkr)

# Handling the lookup failure replaces cache with a dummy
cache = yf.utils.get_tz_cache()
self.assertTrue(cache.dummy)
cache.lookup(tkr)


def suite():
suite = unittest.TestSuite()
suite.addTest(TestUtils('Test utils'))
suite.addTest(TestCache('Test cache'))
suite.addTest(TestCacheNoPermission('Test cache no permission'))
return suite


Expand Down
78 changes: 55 additions & 23 deletions yfinance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,21 +934,15 @@ class _TzCacheManager:
_tz_cache = None

@classmethod
def get_tz(cls):
def get_tz_cache(cls):
if cls._tz_cache is None:
with _cache_init_lock:
cls._initialise()
return cls._tz_cache

@classmethod
def _initialise(cls, cache_dir=None):
try:
cls._tz_cache = _TzCache()
except _TzCacheException as err:
get_yf_logger().info(f"Failed to create TzCache, reason: {err}. "
"TzCache will not be used. "
"Tip: You can direct cache to use a different location with 'set_tz_cache_location(mylocation)'")
cls._tz_cache = _TzCacheDummy()
cls._tz_cache = _TzCache()


class _DBManager:
Expand Down Expand Up @@ -1008,34 +1002,78 @@ def get_location(cls):
_atexit.register(_DBManager.close_db)


db_proxy = _peewee.Proxy()
class _KV(_peewee.Model):
key = _peewee.CharField(primary_key=True)
value = _peewee.CharField(null=True)

class Meta:
try:
database = _DBManager.get_database()
except Exception:
# This code runs at import, so Logger won't be ready yet, so must discard exception.
database = None
database = db_proxy
without_rowid = True


class _TzCache:
def __init__(self):
db = _DBManager.get_database()
self.initialised = -1
self.db = None
self.dummy = False

def get_db(self):
if self.db is not None:
return self.db

try:
self.db = _DBManager.get_database()
except _TzCacheException as err:
get_yf_logger().info(f"Failed to create TzCache, reason: {err}. "
"TzCache will not be used. "
"Tip: You can direct cache to use a different location with 'set_tz_cache_location(mylocation)'")
self.dummy = True
return None
return self.db

def initialise(self):
if self.initialised != -1:
return

db = self.get_db()
if db is None:
self.initialised = 0 # failure
return

db.connect()
db_proxy.initialize(db)
db.create_tables([_KV])

self.initialised = 1 # success

def lookup(self, key):
if self.dummy:
return None

if self.initialised == -1:
self.initialise()

if self.initialised == 0: # failure
return None

try:
return _KV.get(_KV.key == key).value
except _KV.DoesNotExist:
return None

def store(self, key, value):
db = _DBManager.get_database()
if self.dummy:
return

if self.initialised == -1:
self.initialise()

if self.initialised == 0: # failure
return

db = self.get_db()
if db is None:
return
try:
if value is None:
q = _KV.delete().where(_KV.key == key)
Expand All @@ -1054,13 +1092,7 @@ def store(self, key, value):


def get_tz_cache():
"""
Get the timezone cache, initializes it and creates cache folder if needed on first call.
If folder cannot be created for some reason it will fall back to initialize a
dummy cache with same interface as real cash.
"""
# as this can be called from multiple threads, protect it.
return _TzCacheManager.get_tz()
return _TzCacheManager.get_tz_cache()


def set_tz_cache_location(cache_dir: str):
Expand Down