Skip to content

Commit

Permalink
feat(flagd-rpc): add caching with tests
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Schrottner <simon.schrottner@dynatrace.com>
  • Loading branch information
aepfli committed Nov 29, 2024
1 parent b62d3d1 commit 79f69bc
Show file tree
Hide file tree
Showing 13 changed files with 432 additions and 12 deletions.
4 changes: 3 additions & 1 deletion providers/openfeature-provider-flagd/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"panzi-json-logic>=1.0.1",
"semver>=3,<4",
"pyyaml>=6.0.1",
"cachebox"
]
requires-python = ">=3.8"

Expand All @@ -46,7 +47,7 @@ pre-install-commands = [
]

[tool.hatch.envs.hatch-test.scripts]
run = "pytest {args:tests}"
run = "pytest -m 'not customCert and not unixsocket and not (sync and not events)' {args:tests}"
run-cov = "coverage run -m pytest {args:tests}"
cov-combine = "coverage combine"
cov-report = [
Expand All @@ -59,6 +60,7 @@ cov = [
"cov-report",
]


[tool.hatch.envs.mypy]
dependencies = [
"mypy[faster-cache]>=1.13.0",
Expand Down
10 changes: 10 additions & 0 deletions providers/openfeature-provider-flagd/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[pytest]
markers =
rpc: tests for rpc mode.
in-process: tests for in-process mode.
customCert: Supports custom certs.
unixsocket: Supports unixsockets.
events: Supports events.
sync: Supports sync.
caching: Supports caching.
offline: Supports offline.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ class ResolverType(Enum):
IN_PROCESS = "in-process"


class CacheType(Enum):
    """Flag-resolution cache strategies supported by the flagd provider."""

    # Least-recently-used in-memory cache.
    LRU = "lru"
    # No caching; every evaluation goes to the resolver.
    DISABLED = "disabled"


# Defaults used by Config when neither a kwarg nor an env var is provided.
DEFAULT_CACHE = CacheType.LRU
DEFAULT_CACHE_SIZE = 1000  # max cached flag entries
DEFAULT_DEADLINE = 500  # presumably milliseconds (cf. FLAGD_DEADLINE_MS) — TODO confirm
DEFAULT_HOST = "localhost"
DEFAULT_KEEP_ALIVE = 0  # 0 presumably disables keep-alive pings — TODO confirm
Expand All @@ -19,12 +26,14 @@ class ResolverType(Enum):
DEFAULT_STREAM_DEADLINE = 600000  # presumably milliseconds (cf. FLAGD_STREAM_DEADLINE_MS) — TODO confirm
DEFAULT_TLS = False  # plaintext connection by default

# Environment variables consulted by Config when explicit kwargs are absent.
# NOTE: the diff residue contained a duplicate ENV_VAR_RESOLVER_TYPE
# assignment ("FLAGD_RESOLVER_TYPE", then "FLAGD_RESOLVER"); only the
# renamed "FLAGD_RESOLVER" value is kept — the first was dead (shadowed).
ENV_VAR_CACHE_SIZE = "FLAGD_MAX_CACHE_SIZE"
ENV_VAR_CACHE_TYPE = "FLAGD_CACHE"
ENV_VAR_DEADLINE_MS = "FLAGD_DEADLINE_MS"
ENV_VAR_HOST = "FLAGD_HOST"
ENV_VAR_KEEP_ALIVE_TIME_MS = "FLAGD_KEEP_ALIVE_TIME_MS"
ENV_VAR_OFFLINE_FLAG_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
ENV_VAR_PORT = "FLAGD_PORT"
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER"
ENV_VAR_RETRY_BACKOFF_MS = "FLAGD_RETRY_BACKOFF_MS"
ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS"
ENV_VAR_TLS = "FLAGD_TLS"
Expand Down Expand Up @@ -57,6 +66,8 @@ def __init__( # noqa: PLR0913
deadline: typing.Optional[int] = None,
stream_deadline_ms: typing.Optional[int] = None,
keep_alive_time: typing.Optional[int] = None,
cache_type: typing.Optional[CacheType] = None,
max_cache_size: typing.Optional[int] = None,
):
self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host

Expand Down Expand Up @@ -125,3 +136,15 @@ def __init__( # noqa: PLR0913
if keep_alive_time is None
else keep_alive_time
)

self.cache_type = (
CacheType(env_or_default(ENV_VAR_CACHE_TYPE, DEFAULT_CACHE))
if cache_type is None
else cache_type
)

self.max_cache_size: int = (
int(env_or_default(ENV_VAR_CACHE_SIZE, DEFAULT_CACHE_SIZE, cast=int))
if max_cache_size is None
else max_cache_size
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from openfeature.provider.metadata import Metadata
from openfeature.provider.provider import AbstractProvider

from .config import Config, ResolverType
from .config import CacheType, Config, ResolverType
from .resolvers import AbstractResolver, GrpcResolver, InProcessResolver

T = typing.TypeVar("T")
Expand All @@ -50,6 +50,8 @@ def __init__( # noqa: PLR0913
offline_flag_source_path: typing.Optional[str] = None,
stream_deadline_ms: typing.Optional[int] = None,
keep_alive_time: typing.Optional[int] = None,
cache_type: typing.Optional[CacheType] = None,
max_cache_size: typing.Optional[int] = None,
):
"""
Create an instance of the FlagdProvider
Expand Down Expand Up @@ -83,6 +85,8 @@ def __init__( # noqa: PLR0913
offline_flag_source_path=offline_flag_source_path,
stream_deadline_ms=stream_deadline_ms,
keep_alive_time=keep_alive_time,
cache_type=cache_type,
max_cache_size=max_cache_size,
)

self.resolver = self.setup_resolver()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

import grpc
from cachebox import BaseCacheImpl, LRUCache
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct

Expand All @@ -18,13 +19,13 @@
ProviderNotReadyError,
TypeMismatchError,
)
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.flag_evaluation import FlagResolutionDetails, Reason
from openfeature.schemas.protobuf.flagd.evaluation.v1 import (
evaluation_pb2,
evaluation_pb2_grpc,
)

from ..config import Config
from ..config import CacheType, Config
from ..flag_type import FlagType

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -57,6 +58,12 @@ def __init__(
self.deadline = config.deadline * 0.001
self.connected = False

self._cache: typing.Optional[BaseCacheImpl] = (
LRUCache(maxsize=self.config.max_cache_size)
if self.config.cache_type == CacheType.LRU
else None
)

def _create_stub(
self,
) -> typing.Tuple[evaluation_pb2_grpc.ServiceStub, grpc.Channel]:
Expand All @@ -71,10 +78,20 @@ def _create_stub(

def initialize(self, evaluation_context: EvaluationContext) -> None:
    """(Re)connect to flagd and reset per-connection resolver state.

    NOTE(review): connect() runs *before* retry_backoff_seconds,
    connected, and the cache are reset — if connect() already starts the
    event stream, handlers could briefly observe stale state; confirm
    the intended ordering.
    """
    self.connect()
    # NOTE(review): hard-codes the backoff to 0.1s, overriding any
    # configured retry backoff — confirm this is intentional.
    self.retry_backoff_seconds = 0.1
    self.connected = False

    # Fresh LRU cache on every (re)initialization; any other cache
    # type disables caching entirely (None).
    self._cache = (
        LRUCache(maxsize=self.config.max_cache_size)
        if self.config.cache_type == CacheType.LRU
        else None
    )

def shutdown(self) -> None:
    """Stop the resolver: mark it inactive, close the gRPC channel,
    and drop any cached flag resolutions."""
    self.active = False
    self.channel.close()
    # `is not None` instead of truthiness: an *empty* cache is falsy,
    # which would skip clear(). Harmless when empty, but the intent is
    # "a cache is configured", not "the cache has entries".
    if self._cache is not None:
        self._cache.clear()

def connect(self) -> None:
self.active = True
Expand All @@ -96,7 +113,6 @@ def connect(self) -> None:

def listen(self) -> None:
retry_delay = self.retry_backoff_seconds

call_args = (
{"timeout": self.streamline_deadline_seconds}
if self.streamline_deadline_seconds > 0
Expand Down Expand Up @@ -148,6 +164,10 @@ def listen(self) -> None:
def handle_changed_flags(self, data: typing.Any) -> None:
    """Evict changed flags from the cache and emit a configuration-changed event.

    :param data: event-stream payload; assumed to carry a "flags"
        mapping keyed by flag key — TODO confirm against the stream schema.
    """
    changed_flags = list(data["flags"].keys())

    # `is not None` instead of truthiness: an empty cache is falsy and
    # would skip eviction even though a cache is configured.
    if self._cache is not None:
        for flag in changed_flags:
            # Explicit default so flags that were never cached do not
            # raise a KeyError from the cache implementation.
            self._cache.pop(flag, None)

    self.emit_provider_configuration_changed(ProviderEventDetails(changed_flags))

def resolve_boolean_details(
Expand Down Expand Up @@ -190,13 +210,18 @@ def resolve_object_details(
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)

def _resolve( # noqa: PLR0915
def _resolve( # noqa: PLR0915 C901
self,
flag_key: str,
flag_type: FlagType,
default_value: T,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[T]:
if self._cache is not None and flag_key in self._cache:
cached_flag: FlagResolutionDetails[T] = self._cache[flag_key]
cached_flag.reason = Reason.CACHED
return cached_flag

context = self._convert_context(evaluation_context)
call_args = {"timeout": self.deadline}
try:
Expand Down Expand Up @@ -249,12 +274,17 @@ def _resolve( # noqa: PLR0915
raise GeneralError(message) from e

# Got a valid flag and valid type. Return it.
return FlagResolutionDetails(
result = FlagResolutionDetails(
value=value,
reason=response.reason,
variant=response.variant,
)

if response.reason == Reason.STATIC and self._cache is not None:
self._cache.insert(flag_key, result)

return result

def _convert_context(
self, evaluation_context: typing.Optional[EvaluationContext]
) -> Struct:
Expand Down
Loading

0 comments on commit 79f69bc

Please sign in to comment.