Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use contextvars to maintain a call stack during the usage calls #1882

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions sdk/python/feast/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import contextvars
import enum
import logging
import os
import sys
import uuid
from collections import defaultdict
from datetime import datetime
from functools import wraps
from os.path import expanduser, join
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import requests

Expand All @@ -31,6 +33,7 @@
_logger = logging.getLogger(__name__)

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
call_stack: contextvars.ContextVar = contextvars.ContextVar("call_stack", default=[])


@enum.unique
Expand All @@ -51,6 +54,8 @@ def __str__(self):
class Usage:
def __init__(self):
self._usage_enabled: bool = False
self._is_test = os.getenv("FEAST_IS_USAGE_TEST", "False") == "True"
self._usage_counter = defaultdict(lambda: 0)
self.check_env_and_configure()

def check_env_and_configure(self):
Expand All @@ -68,9 +73,6 @@ def check_env_and_configure(self):
Path(feast_home_dir).mkdir(exist_ok=True)
usage_filepath = join(feast_home_dir, "usage")

self._is_test = os.getenv("FEAST_IS_USAGE_TEST", "False") == "True"
self._usage_counter = {}

if os.path.exists(usage_filepath):
with open(usage_filepath, "r") as f:
self._usage_id = f.read()
Expand Down Expand Up @@ -106,9 +108,8 @@ def _send_usage_request(self, json):
def log_function(self, function_name: str):
self.check_env_and_configure()
if self._usage_enabled and self.usage_id:
if (
function_name == "get_online_features"
and not self.should_log_for_get_online_features_event(function_name)
if "get_online_features" in call_stack.get() and not self.should_log_for_get_online_features_event(
"get_online_features"
):
return
json = {
Expand All @@ -121,10 +122,10 @@ def log_function(self, function_name: str):
}
self._send_usage_request(json)

def should_log_for_get_online_features_event(self, event_name: str):
if event_name not in self._usage_counter:
self._usage_counter[event_name] = 0
def increment_event_count(self, event_name: Union[UsageEvent, str]):
self._usage_counter[event_name] += 1

def should_log_for_get_online_features_event(self, event_name: str):
if self._usage_counter[event_name] % 10000 != 2:
return False
self._usage_counter[event_name] = 2 # avoid overflow
Expand Down Expand Up @@ -174,6 +175,7 @@ def log_exceptions(func):
@wraps(func)
def exception_logging_wrapper(*args, **kwargs):
try:
call_stack.set(call_stack.get() + [func.__name__])
result = func(*args, **kwargs)
except Exception as e:
error_type = type(e).__name__
Expand All @@ -190,6 +192,9 @@ def exception_logging_wrapper(*args, **kwargs):
tb = tb.tb_next
usage.log_exception(error_type, trace_to_log)
raise
finally:
if len(call_stack.get()) > 0:
call_stack.set(call_stack.get()[:-1])
return result

return exception_logging_wrapper
Expand All @@ -199,6 +204,8 @@ def log_exceptions_and_usage(func):
@wraps(func)
def exception_logging_wrapper(*args, **kwargs):
try:
call_stack.set(call_stack.get() + [func.__name__])
usage.increment_event_count(func.__name__)
result = func(*args, **kwargs)
usage.log_function(func.__name__)
except Exception as e:
Expand All @@ -216,12 +223,16 @@ def exception_logging_wrapper(*args, **kwargs):
tb = tb.tb_next
usage.log_exception(error_type, trace_to_log)
raise
finally:
if len(call_stack.get()) > 0:
call_stack.set(call_stack.get()[:-1])
return result

return exception_logging_wrapper


def log_event(event: UsageEvent):
usage.increment_event_count(event)
usage.log_event(event)


Expand Down