Skip to content

Commit

Permalink
port enum env var support from #4248 (#4251)
Browse files Browse the repository at this point in the history
* port enum env var support from #4248

* add some tests for interpret env var functions
  • Loading branch information
benedikt-bartscher authored and masenf committed Oct 29, 2024
1 parent cdbe7f8 commit 2e100e3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
26 changes: 26 additions & 0 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

import dataclasses
import enum
import importlib
import inspect
import os
import sys
import urllib.parse
Expand Down Expand Up @@ -221,6 +223,28 @@ def interpret_path_env(value: str, field_name: str) -> Path:
return path


def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any:
"""Interpret an enum environment variable value.
Args:
value: The environment variable value.
field_type: The field type.
field_name: The field name.
Returns:
The interpreted value.
Raises:
EnvironmentVarValueError: If the value is invalid.
"""
try:
return field_type(value)
except ValueError as ve:
raise EnvironmentVarValueError(
f"Invalid enum value: {value} for {field_name}"
) from ve


def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str
) -> Any:
Expand Down Expand Up @@ -252,6 +276,8 @@ def interpret_env_var_value(
return interpret_int_env(value, field_name)
elif field_type is Path:
return interpret_path_env(value, field_name)
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
return interpret_enum_env(value, field_type, field_name)

else:
raise ValueError(
Expand Down
26 changes: 22 additions & 4 deletions tests/units/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

import reflex as rx
import reflex.config
from reflex.config import environment
from reflex.constants import Endpoint
from reflex.config import (
environment,
interpret_boolean_env,
interpret_enum_env,
interpret_int_env,
)
from reflex.constants import Endpoint, Env


def test_requires_app_name():
Expand Down Expand Up @@ -208,11 +213,11 @@ def test_replace_defaults(
assert getattr(c, key) == value


def reflex_dir_constant():
def reflex_dir_constant() -> Path:
return environment.REFLEX_DIR


def test_reflex_dir_env_var(monkeypatch, tmp_path):
def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant.
Args:
Expand All @@ -224,3 +229,16 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path):
mp_ctx = multiprocessing.get_context(method="spawn")
with mp_ctx.Pool(processes=1) as pool:
assert pool.apply(reflex_dir_constant) == tmp_path


def test_interpret_enum_env() -> None:
assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD


def test_interpret_int_env() -> None:
assert interpret_int_env("3001", "FRONTEND_PORT") == 3001


@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)])
def test_interpret_bool_env(value: str, expected: bool) -> None:
assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected

0 comments on commit 2e100e3

Please sign in to comment.