Skip to content

Commit

Permalink
Merge pull request #1709 from ranaroussi/feature/tz-cache-lazy-load
Browse files Browse the repository at this point in the history
Feature/tz cache lazy load
  • Loading branch information
ValueRaider committed Oct 1, 2023
2 parents 13acc3d + cc1ac7b commit 38f8ccd
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 27 deletions.
38 changes: 34 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import os


class TestUtils(unittest.TestCase):
session = None

class TestCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.tempCacheDir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -51,9 +49,41 @@ def test_setTzCacheLocation(self):

self.assertTrue(os.path.exists(os.path.join(self.tempCacheDir.name, "tkr-tz.db")))


class TestCacheNoPermission(unittest.TestCase):
@classmethod
def setUpClass(cls):
yf.set_tz_cache_location("/root/yf-cache")

def test_tzCacheRootStore(self):
# Test that if cache path in read-only filesystem, no exception.
tkr = 'AMZN'
tz1 = "America/New_York"

# During attempt to store, will discover cannot write
yf.utils.get_tz_cache().store(tkr, tz1)

# Handling the store failure replaces cache with a dummy
cache = yf.utils.get_tz_cache()
self.assertTrue(cache.dummy)
cache.store(tkr, tz1)

def test_tzCacheRootLookup(self):
# Test that if cache path in read-only filesystem, no exception.
tkr = 'AMZN'
# During attempt to lookup, will discover cannot write
yf.utils.get_tz_cache().lookup(tkr)

# Handling the lookup failure replaces cache with a dummy
cache = yf.utils.get_tz_cache()
self.assertTrue(cache.dummy)
cache.lookup(tkr)


def suite():
suite = unittest.TestSuite()
suite.addTest(TestUtils('Test utils'))
suite.addTest(TestCache('Test cache'))
suite.addTest(TestCacheNoPermission('Test cache no permission'))
return suite


Expand Down
78 changes: 55 additions & 23 deletions yfinance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,21 +934,15 @@ class _TzCacheManager:
_tz_cache = None

@classmethod
def get_tz(cls):
def get_tz_cache(cls):
if cls._tz_cache is None:
with _cache_init_lock:
cls._initialise()
return cls._tz_cache

@classmethod
def _initialise(cls, cache_dir=None):
try:
cls._tz_cache = _TzCache()
except _TzCacheException as err:
get_yf_logger().info(f"Failed to create TzCache, reason: {err}. "
"TzCache will not be used. "
"Tip: You can direct cache to use a different location with 'set_tz_cache_location(mylocation)'")
cls._tz_cache = _TzCacheDummy()
cls._tz_cache = _TzCache()


class _DBManager:
Expand Down Expand Up @@ -1008,34 +1002,78 @@ def get_location(cls):
_atexit.register(_DBManager.close_db)


db_proxy = _peewee.Proxy()
class _KV(_peewee.Model):
key = _peewee.CharField(primary_key=True)
value = _peewee.CharField(null=True)

class Meta:
try:
database = _DBManager.get_database()
except Exception:
# This code runs at import, so Logger won't be ready yet, so must discard exception.
database = None
database = db_proxy
without_rowid = True


class _TzCache:
def __init__(self):
db = _DBManager.get_database()
self.initialised = -1
self.db = None
self.dummy = False

def get_db(self):
if self.db is not None:
return self.db

try:
self.db = _DBManager.get_database()
except _TzCacheException as err:
get_yf_logger().info(f"Failed to create TzCache, reason: {err}. "
"TzCache will not be used. "
"Tip: You can direct cache to use a different location with 'set_tz_cache_location(mylocation)'")
self.dummy = True
return None
return self.db

def initialise(self):
if self.initialised != -1:
return

db = self.get_db()
if db is None:
self.initialised = 0 # failure
return

db.connect()
db_proxy.initialize(db)
db.create_tables([_KV])

self.initialised = 1 # success

def lookup(self, key):
if self.dummy:
return None

if self.initialised == -1:
self.initialise()

if self.initialised == 0: # failure
return None

try:
return _KV.get(_KV.key == key).value
except _KV.DoesNotExist:
return None

def store(self, key, value):
db = _DBManager.get_database()
if self.dummy:
return

if self.initialised == -1:
self.initialise()

if self.initialised == 0: # failure
return

db = self.get_db()
if db is None:
return
try:
if value is None:
q = _KV.delete().where(_KV.key == key)
Expand All @@ -1054,13 +1092,7 @@ def store(self, key, value):


def get_tz_cache():
"""
Get the timezone cache, initializes it and creates cache folder if needed on first call.
If folder cannot be created for some reason it will fall back to initialize a
dummy cache with same interface as real cash.
"""
# as this can be called from multiple threads, protect it.
return _TzCacheManager.get_tz()
return _TzCacheManager.get_tz_cache()


def set_tz_cache_location(cache_dir: str):
Expand Down

0 comments on commit 38f8ccd

Please sign in to comment.