From 493d4c38d2ed361e702bcac0a2581ff93bb71fd7 Mon Sep 17 00:00:00 2001 From: Luis Trinidad Date: Thu, 17 Oct 2024 17:19:44 -0400 Subject: [PATCH] feat(connections): use the new wizard for setting up table groups When in the connection screen and no table groups exists, users will be presented with a two-step wizard to create the table group and (optionally) run the profiling. --- testgen/ui/components/frontend/css/shared.css | 56 ++- .../frontend/js/components/button.js | 74 ++- testgen/ui/components/widgets/__init__.py | 2 + testgen/ui/components/widgets/button.py | 10 +- .../components/widgets/testgen_component.py | 2 +- testgen/ui/components/widgets/wizard.py | 213 ++++++++ testgen/ui/forms.py | 117 +++++ testgen/ui/queries/table_group_queries.py | 8 +- testgen/ui/services/connection_service.py | 2 +- testgen/ui/services/table_group_service.py | 4 +- testgen/ui/session.py | 27 +- testgen/ui/views/connections.py | 456 ------------------ testgen/ui/views/connections/__init__.py | 3 + testgen/ui/views/connections/forms.py | 250 ++++++++++ testgen/ui/views/connections/models.py | 8 + testgen/ui/views/connections/page.py | 444 +++++++++++++++++ testgen/ui/views/table_groups/__init__.py | 2 + testgen/ui/views/table_groups/forms.py | 170 +++++++ .../{table_groups.py => table_groups/page.py} | 4 +- testgen/utils/singleton.py | 4 +- 20 files changed, 1344 insertions(+), 512 deletions(-) create mode 100644 testgen/ui/components/widgets/wizard.py create mode 100644 testgen/ui/forms.py delete mode 100644 testgen/ui/views/connections.py create mode 100644 testgen/ui/views/connections/__init__.py create mode 100644 testgen/ui/views/connections/forms.py create mode 100644 testgen/ui/views/connections/models.py create mode 100644 testgen/ui/views/connections/page.py create mode 100644 testgen/ui/views/table_groups/__init__.py create mode 100644 testgen/ui/views/table_groups/forms.py rename testgen/ui/views/{table_groups.py => table_groups/page.py} (99%) diff --git a/testgen/ui/components/frontend/css/shared.css b/testgen/ui/components/frontend/css/shared.css index 3284332..04aab9a 100644 --- a/testgen/ui/components/frontend/css/shared.css +++ b/testgen/ui/components/frontend/css/shared.css @@ -24,8 +24,9 @@ body { --primary-text-color: #000000de; --secondary-text-color: #0000008a; --disabled-text-color: #00000042; - --caption-text-color: rgba(49, 51, 63, 0.6); - /* Match Streamlit's caption color */ + --caption-text-color: rgba(49, 51, 63, 0.6); /* Match Streamlit's caption color */ + --border-color: rgba(0, 0, 0, .12); + --dk-card-background: #fff; --sidebar-background-color: white; --sidebar-item-hover-color: #f5f5f5; @@ -34,22 +35,28 @@ body { --field-underline-color: #9e9e9e; - --button-text-color: var(--primary-text-color); - - --button-hover-state-background: var(--primary-color); --button-hover-state-opacity: 0.12; - --button-basic-text-color: var(--primary-color); --button-basic-background: transparent; + --button-basic-text-color: rgba(0, 0, 0, .54); + --button-basic-hover-state-background: rgba(0, 0, 0, .54); - --button-flat-text-color: rgba(255, 255, 255); - --button-flat-background: rgba(0, 0, 0, .54); + --button-basic-flat-text-color: rgba(0, 0, 0); + --button-basic-flat-background: rgba(0, 0, 0, .54); - --button-stroked-text-color: var(--primary-color); - --button-stroked-background: transparent; - --button-stroked-border: 1px solid rgba(0, 0, 0, .12); + --button-basic-stroked-text-color: rgba(0, 0, 0, .54); + --button-basic-stroked-background: transparent; - --dk-card-background: #fff; + --button-primary-background: transparent; + --button-primary-text-color: var(--primary-color); + --button-primary-hover-state-background: var(--primary-color); + + --button-primary-flat-text-color: rgba(255, 255, 255); + --button-primary-flat-background: var(--primary-color); + + --button-primary-stroked-text-color: var(--primary-color); + --button-primary-stroked-background: transparent; + --button-stroked-border: 1px solid var(--border-color); } @media (prefers-color-scheme: dark) { @@ -57,8 +64,9 @@ body { --primary-text-color: rgba(255, 255, 255); --secondary-text-color: rgba(255, 255, 255, .7); --disabled-text-color: rgba(255, 255, 255, .5); - --caption-text-color: rgba(250, 250, 250, .6); - /* Match Streamlit's caption color */ + --caption-text-color: rgba(250, 250, 250, .6); /* Match Streamlit's caption color */ + --border-color: rgba(255, 255, 255, .25); + --dk-card-background: #14181f; --sidebar-background-color: #14181f; --sidebar-item-hover-color: #10141b; @@ -66,13 +74,17 @@ body { --sidebar-active-item-border-color: #b4e3c9; --dk-text-value-background: unset; - --button-text-color: var(--primary-text-color); - - --button-flat-background: rgba(255, 255, 255, .54); - - --button-stroked-border: 1px solid rgba(255, 255, 255, .12); - - --dk-card-background: #14181f; + --button-basic-background: transparent; + --button-basic-text-color: rgba(255, 255, 255); + --button-basic-hover-state-background: rgba(255, 255, 255, .54); + + --button-basic-flat-text-color: rgba(255, 255, 255); + --button-basic-flat-background: rgba(255, 255, 255, .54); + + --button-basic-stroked-text-color: rgba(255, 255, 255, .85); + --button-basic-stroked-background: transparent; + + --button-stroked-border: 1px solid var(--border-color); } } @@ -441,4 +453,4 @@ body { .pl-7 { padding-left: 40px; } -/* */ \ No newline at end of file +/* */ diff --git a/testgen/ui/components/frontend/js/components/button.js b/testgen/ui/components/frontend/js/components/button.js index e3670e3..893a1b1 100644 --- a/testgen/ui/components/frontend/js/components/button.js +++ b/testgen/ui/components/frontend/js/components/button.js @@ -2,6 +2,7 @@ * @typedef Properties * @type {object} * @property {(string)} type + * @property {(string|null)} color * @property {(string|null)} label * @property {(string|null)} icon * @property {(string|null)} tooltip @@ -21,6 +22,11 @@ const BUTTON_TYPE = { ICON: 'icon', STROKED: 'stroked', }; +const BUTTON_COLOR = { + BASIC: 'basic', + PRIMARY: 'primary', +}; + const Button = (/** @type Properties */ props) => { loadStylesheet('button', stylesheet); @@ -32,6 +38,10 @@ const Button = (/** @type Properties */ props) => { if (isIconOnly) { // Force a 40px width for the parent iframe & handle window resizing enforceElementWidth(window.frameElement, 40); } + + if (props.width?.val) { + enforceElementWidth(window.frameElement, props.width?.val); + } } if (props.tooltip) { @@ -42,10 +52,10 @@ const Button = (/** @type Properties */ props) => { const onClickHandler = props.onclick || (() => emitEvent('ButtonClicked')); return button( { - class: `tg-button tg-${props.type.val}-button ${props.type.val !== 'icon' && isIconOnly ? 'tg-icon-button' : ''}`, - style: props.style?.val, + class: `tg-button tg-${props.type.val}-button tg-${props.color?.val ?? 'basic'}-button ${props.type.val !== 'icon' && isIconOnly ? 'tg-icon-button' : ''}`, + style: () => `width: ${props.width?.val ?? '100%'}; ${props.style?.val}`, onclick: onClickHandler, - disabled: !!props.disabled?.val, + disabled: props.disabled, }, span({class: 'tg-button-focus-state-indicator'}, ''), props.icon ? i({class: 'material-symbols-rounded'}, props.icon) : undefined, @@ -56,7 +66,6 @@ const Button = (/** @type Properties */ props) => { const stylesheet = new CSSStyleSheet(); stylesheet.replace(` button.tg-button { - width: 100%; height: 40px; position: relative; @@ -75,8 +84,6 @@ button.tg-button { cursor: pointer; font-size: 14px; - color: var(--button-text-color); - background: var(--button-basic-background); } button.tg-button .tg-button-focus-state-indicator::before { @@ -89,21 +96,9 @@ button.tg-button .tg-button-focus-state-indicator::before { position: absolute; pointer-events: none; border-radius: inherit; - background: var(--button-hover-state-background); -} - -button.tg-button.tg-basic-button { - color: var(--button-basic-text-color); -} - -button.tg-button.tg-flat-button { - color: var(--button-flat-text-color); - background: var(--button-flat-background); } button.tg-button.tg-stroked-button { - color: var(--button-stroked-text-color); - background: var(--button-stroked-background); border: var(--button-stroked-border); } @@ -135,6 +130,49 @@ button.tg-button > i:has(+ span) { button.tg-button:hover:not([disabled]) .tg-button-focus-state-indicator::before { opacity: var(--button-hover-state-opacity); } + + +/* Basic button colors */ +button.tg-button.tg-basic-button { + color: var(--button-basic-text-color); + background: var(--button-basic-background); +} + +button.tg-button.tg-basic-button .tg-button-focus-state-indicator::before { + background: var(--button-basic-hover-state-background); +} + +button.tg-button.tg-basic-button.tg-flat-button { + color: var(--button-basic-flat-text-color); + background: var(--button-basic-flat-background); +} + +button.tg-button.tg-basic-button.tg-stroked-button { + color: var(--button-basic-stroked-text-color); + background: var(--button-basic-stroked-background); +} +/* ... */ + +/* Primary button colors */ +button.tg-button.tg-primary-button { + color: var(--button-primary-text-color); + background: var(--button-primary-background); +} + +button.tg-button.tg-primary-button .tg-button-focus-state-indicator::before { + background: var(--button-primary-hover-state-background); +} + +button.tg-button.tg-primary-button.tg-flat-button { + color: var(--button-primary-flat-text-color); + background: var(--button-primary-flat-background); +} + +button.tg-button.tg-primary-button.tg-stroked-button { + color: var(--button-primary-stroked-text-color); + background: var(--button-primary-stroked-background); +} +/* ... */ `); export { Button }; diff --git a/testgen/ui/components/widgets/__init__.py b/testgen/ui/components/widgets/__init__.py index c847d35..d58047e 100644 --- a/testgen/ui/components/widgets/__init__.py +++ b/testgen/ui/components/widgets/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 +from testgen.ui.components.utils.component import component from testgen.ui.components.widgets.breadcrumbs import breadcrumbs from testgen.ui.components.widgets.button import button from testgen.ui.components.widgets.card import card @@ -23,3 +24,4 @@ from testgen.ui.components.widgets.sorting_selector import sorting_selector from testgen.ui.components.widgets.summary_bar import summary_bar from testgen.ui.components.widgets.testgen_component import testgen_component +from testgen.ui.components.widgets.wizard import wizard, WizardStep diff --git a/testgen/ui/components/widgets/button.py b/testgen/ui/components/widgets/button.py index 4b0a2d0..3c32630 100644 --- a/testgen/ui/components/widgets/button.py +++ b/testgen/ui/components/widgets/button.py @@ -3,17 +3,20 @@ from testgen.ui.components.utils.component import component ButtonType = typing.Literal["basic", "flat", "icon", "stroked"] +ButtonColor = typing.Literal["basic", "primary"] TooltipPosition = typing.Literal["left", "right"] def button( type_: ButtonType = "basic", + color: ButtonColor = "primary", label: str | None = None, icon: str | None = None, tooltip: str | None = None, tooltip_position: TooltipPosition = "left", on_click: typing.Callable[..., None] | None = None, disabled: bool = False, + width: str | int | float | None = None, style: str | None = None, key: str | None = None, ) -> typing.Any: @@ -26,7 +29,7 @@ def button( :param on_click: click handler for this button """ - props = {"type": type_, "disabled": disabled} + props = {"type": type_, "disabled": disabled, "color": color} if type_ != "icon": if not label: raise ValueError(f"A label is required for {type_} buttons") @@ -38,6 +41,11 @@ def button( if tooltip: props.update({"tooltip": tooltip, "tooltipPosition": tooltip_position}) + if width: + props.update({"width": width}) + if isinstance(width, (int, float,)): + props.update({"width": f"{width}px"}) + if style: props.update({"style": style}) diff --git a/testgen/ui/components/widgets/testgen_component.py b/testgen/ui/components/widgets/testgen_component.py index 447686e..7fb2be2 100644 --- a/testgen/ui/components/widgets/testgen_component.py +++ b/testgen/ui/components/widgets/testgen_component.py @@ -6,7 +6,7 @@ def testgen_component( - component_id: typing.Literal["profiling_runs", "test_runs"], + component_id: typing.Literal["profiling_runs", "test_runs", "database_flavor_selector"], props: dict, event_handlers: dict | None, ) -> dict | None: diff --git a/testgen/ui/components/widgets/wizard.py b/testgen/ui/components/widgets/wizard.py new file mode 100644 index 0000000..8a055f2 --- /dev/null +++ b/testgen/ui/components/widgets/wizard.py @@ -0,0 +1,213 @@ +import dataclasses +import logging +import inspect +import typing + +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +from testgen.ui.components import widgets as testgen +from testgen.ui.navigation.router import Router +from testgen.ui.session import temp_value + +ResultsType = typing.TypeVar("ResultsType", bound=typing.Any | None) +StepResults = tuple[typing.Any, bool] +logger = logging.getLogger("testgen") + + +def wizard( + *, + key: str, + steps: list[typing.Callable[..., StepResults] | "WizardStep"], + on_complete: typing.Callable[..., bool], + complete_label: str = "Complete", + navigate_to: str | None = None, + navigate_to_args: dict | None = None, +) -> None: + """ + Creates a Wizard with the provided steps and handles the session for + each step internally. + + For each step callable instances of WizardStep for the current step + and previous steps are optionally provided as keyword arguments with + specific names. + + Optional arguments that can be accessed as follows: + + ``` + def step_fn(current_step: WizardStep = ..., step_0: WizardStep = ...) + ... + ``` + + For the `on_complete` callable, on top of passing each WizardStep, a + Streamlit DeltaGenerator is also passed to allow rendering content + inside the step's body. + + ``` + def on_complete(container: DeltaGenerator, step_0: WizardStep = ..., step_1: WizardStep = ...): + ... + ``` + + After the `on_complete` callback returns, the wizard state is reset. + + :param key: used to cache current step and results of each step + :param steps: a list of WizardStep instances or callable objects + :param on_complete: callable object to execute after the last step. + should return true to trigger a Streamlit rerun + :param complete_label: customize the label for the complete button + + :return: None + """ + + if navigate_to: + Router().navigate(navigate_to, navigate_to_args or {}) + + current_step_idx = 0 + wizard_state = st.session_state.get(key) + if isinstance(wizard_state, int): + current_step_idx = wizard_state + + instance = Wizard( + key=key, + steps=[ + WizardStep( + key=f"{key}:{idx}", + body=step, + results=st.session_state.get(f"{key}:{idx}", None), + ) if not isinstance(step, WizardStep) else dataclasses.replace( + step, + key=f"{key}:{idx}", + results=st.session_state.get(f"{key}:{idx}", None), + ) + for idx, step in enumerate(steps) + ], + current_step=current_step_idx, + on_complete=on_complete, + ) + + current_step = instance.current_step + current_step_index = instance.current_step_index + testgen.caption( + f"Step {current_step_index + 1} of {len(steps)}{': ' + current_step.title if current_step.title else ''}" + ) + + step_body_container = st.empty() + with step_body_container.container(): + was_complete_button_clicked, set_complete_button_clicked = temp_value(f"{key}:complete-button") + + if was_complete_button_clicked(): + instance.complete(step_body_container) + else: + instance.render() + button_left_column, _, button_right_column = st.columns([0.30, 0.40, 0.30]) + with button_left_column: + if not instance.is_first_step(): + testgen.button( + type_="stroked", + color="basic", + label="Previous", + on_click=lambda: instance.previous(), + key=f"{key}:button-prev", + ) + + with button_right_column: + next_button_label = complete_label if instance.is_last_step() else "Next" + + testgen.button( + type_="stroked" if not instance.is_last_step() else "flat", + label=next_button_label, + on_click=lambda: set_complete_button_clicked(instance.next() or instance.is_last_step()), + key=f"{key}:button-next", + disabled=not current_step.is_valid, + ) + + +class Wizard: + def __init__( + self, + *, + key: str, + steps: list["WizardStep"], + on_complete: typing.Callable[..., bool] | None = None, + current_step: int = 0, + ) -> None: + self._key = key + self._steps = steps + self._current_step = current_step + self._on_complete = on_complete + + @property + def current_step(self) -> "WizardStep": + return self._steps[self._current_step] + + @property + def current_step_index(self) -> int: + return self._current_step + + def next(self) -> None: + next_step = self._current_step + 1 + if not self.is_last_step(): + st.session_state[self._key] = next_step + return + + def previous(self) -> None: + previous_step = self._current_step - 1 + if previous_step > -1: + st.session_state[self._key] = previous_step + + def is_first_step(self) -> bool: + return self._current_step == 0 + + def is_last_step(self) -> bool: + return self._current_step == len(self._steps) - 1 + + def complete(self, container: DeltaGenerator) -> None: + if self._on_complete: + signature = inspect.signature(self._on_complete) + accepted_params = [param.name for param in signature.parameters.values()] + kwargs: dict = { + key: step for idx, step in enumerate(self._steps) + if (key := f"step_{idx}") and key in accepted_params + } + if "container" in accepted_params: + kwargs["container"] = container + + do_rerun = self._on_complete(**kwargs) + self._reset() + if do_rerun: + st.rerun() + + def _reset(self) -> None: + del st.session_state[self._key] + for step_idx in range(len(self._steps)): + del st.session_state[f"{self._key}:{step_idx}"] + + def render(self) -> None: + step = self._steps[self._current_step] + + extra_args = {"current_step": step} + extra_args.update({f"step_{idx}": step for idx, step in enumerate(self._steps)}) + + signature = inspect.signature(step.body) + step_accepted_params = [param.name for param in signature.parameters.values() if param.name in extra_args] + extra_args = {key: value for key, value in extra_args.items() if key in step_accepted_params} + + try: + results, is_valid = step.body(**extra_args) + except TypeError as error: + logger.exception("Error on wizard step %s", self._current_step, exc_info=True, stack_info=True) + results, is_valid = None, True + + step.results = results + step.is_valid = is_valid + + st.session_state[f"{self._key}:{self._current_step}"] = step.results + + +@dataclasses.dataclass(kw_only=True, slots=True) +class WizardStep[ResultsType]: + body: typing.Callable[..., StepResults] + results: ResultsType = dataclasses.field(default=None) + title: str = dataclasses.field(default="") + key: str | None = dataclasses.field(default=None) + is_valid: bool = dataclasses.field(default=True) diff --git a/testgen/ui/forms.py b/testgen/ui/forms.py new file mode 100644 index 0000000..61a7120 --- /dev/null +++ b/testgen/ui/forms.py @@ -0,0 +1,117 @@ +import typing + +import streamlit as st +from pydantic import BaseModel, Field +from pydantic.json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode +from streamlit.delta_generator import DeltaGenerator +from streamlit_pydantic.ui_renderer import InputUI + + +class BaseForm(BaseModel): + def __init__(self, /, **data: typing.Any) -> None: + super().__init__(**data) + + @classmethod + def empty(cls) -> typing.Self: + non_validated_instance = cls.model_construct() + non_validated_instance.model_post_init(None) + + return non_validated_instance + + @property + def _disabled_fields(self) -> typing.Set[str]: + if not getattr(self, "_disabled_fields_set", None): + self._disabled_fields_set = set() + return self._disabled_fields_set + + def disable(self, field: str) -> None: + self._disabled_fields.add(field) + + def enable(self, field) -> None: + self._disabled_fields.remove(field) + + @classmethod + def model_json_schema( + self_or_cls, # type: ignore + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + mode: JsonSchemaMode = 'validation', + ) -> dict[str, typing.Any]: + schema = super().model_json_schema( + by_alias=by_alias, + ref_template=ref_template, + schema_generator=schema_generator, + mode=mode, + ) + + schema_properties: dict[str, dict] = schema.get("properties", {}) + disabled_fields: set[str] = getattr(self_or_cls, "_disabled_fields_set", set()) + for property_name, property_schema in schema_properties.items(): + if property_name in disabled_fields and not property_schema.get("readOnly"): + property_schema["readOnly"] = True + + return schema + + @classmethod + def get_field_label(cls, field_name: str) -> str: + schema = cls.model_json_schema() + schema_properties = schema.get("properties", {}) + field_schema = schema_properties[field_name] + return field_schema.get("st_kwargs_label") or field_schema.get("title") + + +class ManualRender: + @property + def input_ui(self): + if not getattr(self, "_input_ui", None): + self._input_ui = InputUI( + self.form_key(), + self, # type: ignore + group_optional_fields="no", # type: ignore + lowercase_labels=False, + ignore_empty_values=False, + return_model=False, + ) + return self._input_ui + + def form_key(self): + raise NotImplementedError() + + def render_input_ui(self, container: DeltaGenerator, session_state: dict) -> typing.Self: + raise NotImplementedError() + + def render_field(self, field_name: str, container: DeltaGenerator | None = None) -> typing.Any: + streamlit_container = container or self.input_ui._streamlit_container + model_property = self.input_ui._schema_properties[field_name] + initial_value = getattr(self, field_name, None) or self.input_ui._get_value(field_name) + is_disabled = field_name in getattr(self, "_disabled_fields", set()) + + if is_disabled: + model_property["readOnly"] = True + + if model_property.get("type") != "boolean" and initial_value not in [None, ""]: + model_property["init_value"] = initial_value + + new_value = self.input_ui._render_property(streamlit_container, field_name, model_property) + self.update_field_value(field_name, new_value) + + return new_value + + def update_field_value(self, field_name: str, value: typing.Any) -> typing.Any: + self.input_ui._store_value(field_name, value) + setattr(self, field_name, value) + return value + + def get_field_value(self, field_name: str, latest: bool = False) -> typing.Any: + if latest: + return st.session_state.get(self.get_field_key(field_name)) + return self.input_ui._get_value(field_name) + + def reset_cache(self) -> None: + for field_name in typing.cast(type[BaseForm], type(self)).model_fields.keys(): + st.session_state.pop(self.get_field_key(field_name), None) + st.session_state.pop(self.form_key() + "-data", None) + + def get_field_key(self, field_name: str) -> typing.Any: + return str(self.input_ui._session_state.run_id) + "-" + str(self.input_ui._key) + "-" + field_name diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py index 0663a6f..c13e62a 100644 --- a/testgen/ui/queries/table_group_queries.py +++ b/testgen/ui/queries/table_group_queries.py @@ -1,3 +1,5 @@ +import uuid + import streamlit as st import testgen.ui.services.database_service as db @@ -108,7 +110,8 @@ def edit(schema, table_group): st.cache_data.clear() -def add(schema, table_group): +def add(schema, table_group) -> str: + new_table_group_id = str(uuid.uuid4()) sql = f"""INSERT INTO {schema}.table_groups (id, project_code, @@ -132,7 +135,7 @@ def add(schema, table_group): source_process, stakeholder_group) SELECT - gen_random_uuid(), + '{new_table_group_id}', '{table_group["project_code"]}', '{table_group["connection_id"]}', '{table_group["table_groups_name"]}', @@ -155,6 +158,7 @@ def add(schema, table_group): ;""" db.execute_sql(sql) st.cache_data.clear() + return new_table_group_id def delete(schema, table_group_ids): diff --git a/testgen/ui/services/connection_service.py b/testgen/ui/services/connection_service.py index 394c82a..27ebf7e 100644 --- a/testgen/ui/services/connection_service.py +++ b/testgen/ui/services/connection_service.py @@ -207,7 +207,7 @@ def form_overwritten_connection_url(connection): "dbname": connection["project_db"], "url": None, "connect_by_url": None, - "connect_by_key": connection["connect_by_key"], + "connect_by_key": connection.get("connect_by_key"), "private_key": None, "private_key_passphrase": "", "dbschema": "", diff --git a/testgen/ui/services/table_group_service.py b/testgen/ui/services/table_group_service.py index 57ea6bd..f51d360 100644 --- a/testgen/ui/services/table_group_service.py +++ b/testgen/ui/services/table_group_service.py @@ -21,9 +21,9 @@ def edit(table_group): table_group_queries.edit(schema, table_group) -def add(table_group): +def add(table_group: dict) -> str: schema = st.session_state["dbschema"] - table_group_queries.add(schema, table_group) + return table_group_queries.add(schema, table_group) def cascade_delete(table_group_names, dry_run=False): diff --git a/testgen/ui/session.py b/testgen/ui/session.py index 0802132..bb198a8 100644 --- a/testgen/ui/session.py +++ b/testgen/ui/session.py @@ -1,16 +1,20 @@ -import typing +from typing import Any, Callable, Literal, TypeVar import streamlit as st from streamlit.runtime.state import SessionStateProxy from testgen.utils.singleton import Singleton +T = TypeVar("T") +TempValueGetter = Callable[..., T] +TempValueSetter = Callable[[T], None] + class TestgenSession(Singleton): cookies_ready: int logging_in: bool logging_out: bool - page_pending_cookies: st.Page + page_pending_cookies: st.Page # type: ignore page_pending_login: str page_pending_sidebar: str page_args_pending_router: dict @@ -23,7 +27,7 @@ class TestgenSession(Singleton): name: str username: str authentication_status: bool - auth_role: typing.Literal["admin", "edit", "read"] + auth_role: Literal["admin", "edit", "read"] project: str add_project: bool @@ -34,13 +38,13 @@ class TestgenSession(Singleton): def __init__(self, state: SessionStateProxy) -> None: super().__setattr__("_state", state) - def __getattr__(self, key: str) -> typing.Any: + def __getattr__(self, key: str) -> Any: state = object.__getattribute__(self, "_state") if key not in state: return None return state[key] - def __setattr__(self, key: str, value: typing.Any) -> None: + def __setattr__(self, key: str, value: Any) -> None: object.__getattribute__(self, "_state")[key] = value def __delattr__(self, key: str) -> None: @@ -49,4 +53,17 @@ def __delattr__(self, key: str) -> None: del state[key] +def temp_value(session_key: str, *, default: T | None = None) -> tuple[TempValueGetter[T | None], TempValueSetter[T]]: + scoped_session_key = f"tg-session:tmp-value:{session_key}" + + def getter() -> T | None: + if scoped_session_key not in st.session_state: + return default + return st.session_state.pop(scoped_session_key, None) + + def setter(value: T): + st.session_state[scoped_session_key] = value + + return getter, setter + session: TestgenSession = TestgenSession(st.session_state) diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py deleted file mode 100644 index 33df711..0000000 --- a/testgen/ui/views/connections.py +++ /dev/null @@ -1,456 +0,0 @@ -import dataclasses -import logging -import os -import time -import typing - -import streamlit as st - -import testgen.ui.services.database_service as db -from testgen.commands.run_setup_profiling_tools import get_setup_profiling_tools_queries -from testgen.common.database.database_service import empty_cache -from testgen.ui.components import widgets as testgen -from testgen.ui.navigation.menu import MenuItem -from testgen.ui.navigation.page import Page -from testgen.ui.services import authentication_service, connection_service -from testgen.ui.session import session - -LOG = logging.getLogger("testgen") - - -class ConnectionsPage(Page): - path = "connections" - can_activate: typing.ClassVar = [ - lambda: session.authentication_status, - ] - menu_item = MenuItem(icon="database", label="Data Configuration", order=4) - - def render(self, project_code: str, **_kwargs) -> None: - dataframe = connection_service.get_connections(project_code) - connection = dataframe.iloc[0] - - testgen.page_header( - "Connection", - "https://docs.datakitchen.io/article/dataops-testgen-help/connect-your-database", - ) - - _, actions_column = st.columns([.1, .9]) - testgen.flex_row_end(actions_column) - - enable_table_groups = connection["project_host"] and connection["project_db"] and connection["project_qc_schema"] - - with st.container(border=True): - self.show_connection_form(connection, "edit", project_code) - - if actions_column.button( - "Configure QC Utility Schema", - help="Creates the required Utility schema and related functions in the target database", - ): - self.create_qc_schema_dialog(connection) - - if actions_column.button( - f":{'gray' if not enable_table_groups else 'green'}[Table Groups →]", - help="Create or edit Table Groups for the Connection", - ): - self.router.navigate( - "connections:table-groups", - {"connection_id": connection["connection_id"]}, - ) - - @st.dialog(title="Configure QC Utility Schema") - def create_qc_schema_dialog(self, selected_connection): - connection_id = selected_connection["connection_id"] - project_qc_schema = selected_connection["project_qc_schema"] - sql_flavor = selected_connection["sql_flavor"] - user = selected_connection["project_user"] - - create_qc_schema = st.toggle("Create QC Utility Schema", value=True) - grant_privileges = st.toggle("Grant access privileges to TestGen user", value=True) - - user_role = None - - # TODO ALEX: This textbox may be needed if we want to grant permissions to user role - # if sql_flavor == "snowflake": - # user_role_textbox_label = f"Primary role for database user {user}" - # user_role = st.text_input(label=user_role_textbox_label, max_chars=100) - - admin_credentials_expander = st.expander("Admin credential options", expanded=True) - with admin_credentials_expander: - admin_connection_option_index = 0 - admin_connection_options = ["Do not use admin credentials", "Use admin credentials with Password"] - if sql_flavor == "snowflake": - admin_connection_options.append("Use admin credentials with Key-Pair") - - admin_connection_option = st.radio( - "Admin credential options", - label_visibility="hidden", - options=admin_connection_options, - index=admin_connection_option_index, - horizontal=True, - ) - - st.markdown("

 
", unsafe_allow_html=True) - - db_user = None - db_password = None - admin_private_key_passphrase = None - admin_private_key = None - if admin_connection_option == admin_connection_options[0]: - st.markdown(":orange[User created in the connection dialog will be used.]") - else: - db_user = st.text_input(label="Admin db user", max_chars=40) - if admin_connection_option == admin_connection_options[1]: - db_password = st.text_input( - label="Admin db password", max_chars=40, type="password" - ) - st.markdown(":orange[Note: Admin credentials are not stored, are only used for this operation.]") - - if len(admin_connection_options) > 2 and admin_connection_option == admin_connection_options[2]: - admin_private_key_passphrase = st.text_input( - label="Private Key Passphrase", - key="create-qc-schema-private-key-password", - type="password", - max_chars=200, - help="Passphrase used while creating the private Key (leave empty if not applicable)", - ) - - admin_uploaded_file = st.file_uploader("Upload private key (rsa_key.p8)", key="admin-uploaded-file") - if admin_uploaded_file: - admin_private_key = admin_uploaded_file.getvalue().decode("utf-8") - - st.markdown(":orange[Note: Admin credentials are not stored, are only used for this operation.]") - - submit = st.button("Update Configuration") - - if submit: - empty_cache() - script_expander = st.expander("Script Details") - - operation_status = st.empty() - operation_status.info(f"Configuring QC Utility Schema '{project_qc_schema}'...") - - try: - skip_granting_privileges = not grant_privileges - queries = get_setup_profiling_tools_queries(sql_flavor, create_qc_schema, skip_granting_privileges, project_qc_schema, user, user_role) - with script_expander: - st.code( - os.linesep.join(queries), - language="sql", - line_numbers=True) - - connection_service.create_qc_schema( - connection_id, - create_qc_schema, - db_user if db_user else None, - db_password if db_password else None, - skip_granting_privileges, - admin_private_key_passphrase=admin_private_key_passphrase, - admin_private_key=admin_private_key, - user_role=user_role, - ) - operation_status.empty() - operation_status.success("Operation has finished successfully.") - - except Exception as e: - operation_status.empty() - operation_status.error("Error configuring QC Utility Schema.") - error_message = e.args[0] - st.text_area("Error Details", value=error_message) - - def show_connection_form(self, selected_connection, mode, project_code): - flavor_options = ["redshift", "snowflake", "mssql", "postgresql"] - connection_options = ["Connect by Password", "Connect by Key-Pair"] - - left_column, right_column = st.columns([0.75, 0.25]) - - mid_column = st.columns(1)[0] - url_override_toogle_container = st.container() - bottom_left_column, bottom_right_column = st.columns([0.25, 0.75]) - button_left_column, button_right_column = st.columns([0.20, 0.80]) - connection_status_wrapper = st.container() - - connection_id = selected_connection["connection_id"] if mode == "edit" else None - connection_name = selected_connection["connection_name"] if mode == "edit" else "" - sql_flavor_index = flavor_options.index(selected_connection["sql_flavor"]) if mode == "edit" else 0 - project_port = selected_connection["project_port"] if mode == "edit" else "" - project_host = selected_connection["project_host"] if mode == "edit" else "" - project_db = selected_connection["project_db"] if mode == "edit" else "" - project_user = selected_connection["project_user"] if mode == "edit" else "" - url = selected_connection["url"] if mode == "edit" else "" - project_qc_schema = selected_connection["project_qc_schema"] if mode == "edit" else "qc" - password = selected_connection["password"] if mode == "edit" else "" - max_threads = selected_connection["max_threads"] if mode == "edit" else 4 - max_query_chars = selected_connection["max_query_chars"] if mode == "edit" else 10000 - connect_by_url = selected_connection["connect_by_url"] if mode == "edit" else False - connect_by_key = selected_connection["connect_by_key"] if mode == "edit" else False - connection_option_index = 1 if connect_by_key else 0 - private_key = selected_connection["private_key"] if mode == "edit" else None - private_key_passphrase = selected_connection["private_key_passphrase"] if mode == "edit" else "" - - new_connection = { - "connection_id": connection_id, - "project_code": project_code, - "private_key": private_key, - "private_key_passphrase": private_key_passphrase, - "password": password, - "url": url, - "max_threads": right_column.number_input( - label="Max Threads (Advanced Tuning)", - min_value=1, - max_value=8, - value=max_threads, - help=( - "Maximum number of concurrent threads that run tests. Default values should be retained unless " - "test queries are failing." - ), - key=f"connections:form:max-threads:{connection_id or 0}", - ), - "max_query_chars": right_column.number_input( - label="Max Expression Length (Advanced Tuning)", - min_value=500, - max_value=14000, - value=max_query_chars, - help="Some tests are consolidated into queries for maximum performance. Default values should be retained unless test queries are failing.", - key=f"connections:form:max-length:{connection_id or 0}", - ), - "connection_name": left_column.text_input( - label="Connection Name", - max_chars=40, - value=connection_name, - help="Your name for this connection. Can be any text.", - key=f"connections:form:name:{connection_id or 0}", - ), - "sql_flavor": left_column.selectbox( - label="SQL Flavor", - options=flavor_options, - index=sql_flavor_index, - help="The type of database server that you will connect to. This determines TestGen's drivers and SQL dialect.", - key=f"connections:form:flavor:{connection_id or 0}", - ) - } - - st.session_state.disable_url_widgets = connect_by_url - - new_connection["project_port"] = right_column.text_input( - label="Port", - max_chars=5, - value=project_port, - disabled=st.session_state.disable_url_widgets, - key=f"connections:form:port:{connection_id or 0}", - ) - new_connection["project_host"] = left_column.text_input( - label="Host", - max_chars=250, - value=project_host, - disabled=st.session_state.disable_url_widgets, - key=f"connections:form:host:{connection_id or 0}", - ) - new_connection["project_db"] = left_column.text_input( - label="Database", - max_chars=100, - value=project_db, - help="The name of the database defined on your host where your schemas and tables is present.", - disabled=st.session_state.disable_url_widgets, - key=f"connections:form:database:{connection_id or 0}", - ) - - new_connection["project_user"] = left_column.text_input( - label="User", - max_chars=50, - value=project_user, - help="Username to connect to your database.", - key=f"connections:form:user:{connection_id or 0}", - ) - - new_connection["project_qc_schema"] = right_column.text_input( - label="QC Utility Schema", - max_chars=50, - value=project_qc_schema, - help="The name of the schema on your database that will contain TestGen's profiling functions.", - key=f"connections:form:qcschema:{connection_id or 0}", - ) - - if new_connection["sql_flavor"] == "snowflake": - mid_column.divider() - - connection_option = mid_column.radio( - "Connection options", - options=connection_options, - index=connection_option_index, - horizontal=True, - help="Connection strategy", - key=f"connections:form:type_options:{connection_id or 0}", - ) - - new_connection["connect_by_key"] = connection_option == "Connect by Key-Pair" - password_column = mid_column - else: - new_connection["connect_by_key"] = False - password_column = left_column - - uploaded_file = None - - if new_connection["connect_by_key"]: - new_connection["private_key_passphrase"] = mid_column.text_input( - label="Private Key Passphrase", - type="password", - max_chars=200, - value=private_key_passphrase, - help="Passphrase used while creating the private Key (leave empty if not applicable)", - key=f"connections:form:passphrase:{connection_id or 0}", - ) - - uploaded_file = mid_column.file_uploader("Upload private key (rsa_key.p8)") - else: - new_connection["password"] = password_column.text_input( - label="Password", - max_chars=50, - type="password", - value=password, - help="Password to connect to your database.", - key=f"connections:form:password:{connection_id or 0}", - ) - - mid_column.divider() - - url_override_help_text = "If this switch is set to on, the connection string will be driven by the field below. " - if new_connection["connect_by_key"]: - url_override_help_text += "Only user name will be passed per the relevant fields above." - else: - url_override_help_text += "Only user name and password will be passed per the relevant fields above." - - def on_connect_by_url_change(): - value = st.session_state.connect_by_url_toggle - st.session_state.disable_url_widgets = value - - new_connection["connect_by_url"] = url_override_toogle_container.toggle( - "URL override", - value=connect_by_url, - key="connect_by_url_toggle", - help=url_override_help_text, - on_change=on_connect_by_url_change, - ) - - if new_connection["connect_by_url"]: - connection_string = connection_service.form_overwritten_connection_url(new_connection) - connection_string_beginning, connection_string_end = connection_string.split("@", 1) - connection_string_header = connection_string_beginning + "@" - connection_string_header = connection_string_header.replace("%3E", ">") - connection_string_header = connection_string_header.replace("%3C", "<") - - if not url: - url = connection_string_end - - new_connection["url"] = bottom_right_column.text_input( - label="URL Suffix", - max_chars=200, - value=url, - help="Provide a connection string directly. This will override connection parameters if the 'Connect by URL' switch is set.", - ) - - bottom_left_column.text_input(label="URL Prefix", value=connection_string_header, disabled=True) - - bottom_left_column.markdown("

 
", unsafe_allow_html=True) - - testgen.flex_row_end(button_right_column) - submit = button_right_column.button( - "Save" if mode == "edit" else "Add Connection", - disabled=authentication_service.current_user_has_read_role(), - ) - - if submit: - if not new_connection["password"] and not new_connection["connect_by_key"]: - st.error("Enter a valid password.") - else: - if uploaded_file: - new_connection["private_key"] = uploaded_file.getvalue().decode("utf-8") - - if mode == "edit": - connection_service.edit_connection(new_connection) - else: - connection_service.add_connection(new_connection) - success_message = ( - "Changes have been saved successfully. " - if mode == "edit" - else "New connection added successfully. " - ) - st.success(success_message) - time.sleep(1) - st.rerun() - - test_connection = button_left_column.button("Test Connection") - - if test_connection: - single_element_container = connection_status_wrapper.empty() - single_element_container.info("Connecting ...") - connection_status = self.test_connection(new_connection) - - with single_element_container.container(): - renderer = { - True: st.success, - False: st.error, - }[connection_status.successful] - - renderer(connection_status.message) - if not connection_status.successful and connection_status.details: - st.caption("Connection Error Details") - - with st.container(border=True): - st.markdown(connection_status.details) - else: - # This is needed to fix a strange bug in Streamlit when using dialog + input fields + button - # If an input field is changed and the button is clicked immediately (without unfocusing the input first), - # two fragment reruns happen successively, one for unfocusing the input and the other for clicking the button - # Some or all (it seems random) of the input fields disappear when this happens - time.sleep(0.1) - - def test_connection(self, connection: dict) -> "ConnectionStatus": - if connection["connect_by_key"] and connection["connection_id"] is None: - return ConnectionStatus( - message="Please add the connection before testing it (so that we can get your private key file).", - successful=False, - ) - - empty_cache() - try: - sql_query = "select 1;" - results = db.retrieve_target_db_data( - connection["sql_flavor"], - connection["project_host"], - connection["project_port"], - connection["project_db"], - connection["project_user"], - connection["password"], - connection["url"], - connection["connect_by_url"], - connection["connect_by_key"], - connection["private_key"], - connection["private_key_passphrase"], - sql_query, - ) - connection_successful = len(results) == 1 and results[0][0] == 1 - - if not connection_successful: - return ConnectionStatus(message="Error completing a query to the database server.", successful=False) - - qc_error_message = "The connection was successful, but there is an issue with the QC Utility Schema" - try: - qc_results = connection_service.test_qc_connection(connection["project_code"], connection) - if not all(qc_results): - return ConnectionStatus( - message=qc_error_message, - details=f"QC Utility Schema confirmation failed. details: {qc_results}", - successful=False, - ) - return ConnectionStatus(message="The connection was successful.", successful=True) - except Exception as error: - return ConnectionStatus(message=qc_error_message, details=error.args[0], successful=False) - except Exception as error: - return ConnectionStatus(message="Error attempting the Connection.", details=error.args[0], successful=False) - - -@dataclasses.dataclass(frozen=True, slots=True) -class ConnectionStatus: - message: str - successful: bool - details: str | None = dataclasses.field(default=None) diff --git a/testgen/ui/views/connections/__init__.py b/testgen/ui/views/connections/__init__.py new file mode 100644 index 0000000..76f8c37 --- /dev/null +++ b/testgen/ui/views/connections/__init__.py @@ -0,0 +1,3 @@ +from testgen.ui.views.connections.page import ConnectionsPage +from testgen.ui.views.connections.models import ConnectionStatus +from testgen.ui.views.connections.forms import BaseConnectionForm, PasswordConnectionForm, KeyPairConnectionForm diff --git a/testgen/ui/views/connections/forms.py b/testgen/ui/views/connections/forms.py new file mode 100644 index 0000000..942c42a --- /dev/null +++ b/testgen/ui/views/connections/forms.py @@ -0,0 +1,250 @@ +# type: ignore +import base64 +import typing + +from pydantic import computed_field +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +from testgen.ui.components import widgets as testgen +from testgen.ui.forms import BaseForm, Field, ManualRender +from testgen.ui.services import connection_service + +SQL_FLAVORS = ["redshift", "snowflake", "mssql", "postgresql"] +SQLFlavor = typing.Literal[*SQL_FLAVORS] + + +class BaseConnectionForm(BaseForm, ManualRender): + connection_name: str = Field( + default="", + min_length=3, + max_length=40, + st_kwargs_max_chars=40, + st_kwargs_label="Connection Name", + st_kwargs_help="Your name for this connection. Can be any text.", + ) + project_host: str = Field( + default="", + max_length=250, + st_kwargs_max_chars=250, + st_kwargs_label="Host", + ) + project_port: str = Field(default="", max_length=5, st_kwargs_max_chars=5, st_kwargs_label="Port") + project_db: str = Field( + default="", + max_length=100, + st_kwargs_max_chars=100, + st_kwargs_label="Database", + st_kwargs_help="The name of the database defined on your host where your schemas and tables is present.", + ) + project_user: str = Field( + default="", + max_length=50, + st_kwargs_max_chars=50, + st_kwargs_label="User", + st_kwargs_help="Username to connect to your database.", + ) + connect_by_url: bool = Field( + default=False, + st_kwargs_label="URL override", + st_kwargs_help=( + "If this switch is set to on, the connection string will be driven by the field below. " + "Only user name and password will be passed per the relevant fields above." + ), + ) + url_prefix: str = Field( + default="", + readOnly=True, + st_kwargs_label="URL Prefix", + ) + url: str = Field( + default="", + max_length=200, + st_kwargs_label="URL Suffix", + st_kwargs_max_chars=200, + st_kwargs_help=( + "Provide a connection string directly. This will override connection parameters if " + "the 'Connect by URL' switch is set." + ), + ) + max_threads: int = Field( + default=4, + ge=1, + le=8, + st_kwargs_min_value=1, + st_kwargs_max_value=8, + st_kwargs_label="Max Threads (Advanced Tuning)", + st_kwargs_help=( + "Maximum number of concurrent threads that run tests. Default values should be retained unless " + "test queries are failing." + ), + ) + max_query_chars: int = Field( + default=10000, + ge=500, + le=14000, + st_kwargs_label="Max Expression Length (Advanced Tuning)", + st_kwargs_min_value=500, + st_kwargs_max_value=14000, + st_kwargs_help=( + "Some tests are consolidated into queries for maximum performance. Default values should be retained " + "unless test queries are failing." + ), + ) + project_qc_schema: str = Field( + default="qc", + max_length=50, + st_kwargs_label="QC Utility Schema", + st_kwargs_max_chars=50, + st_kwargs_help="The name of the schema on your database that will contain TestGen's profiling functions.", + ) + + connection_id: int | None = Field(default=None) + + sql_flavor: SQLFlavor = Field( + ..., + st_kwargs_label="SQL Flavor", + st_kwargs_options=SQL_FLAVORS, + st_kwargs_help=( + "The type of database server that you will connect to. This determines TestGen's drivers and SQL dialect." + ), + ) + + def form_key(self): + return f"connection_form:{self.connection_id or 'new'}" + + def render_input_ui(self, container: DeltaGenerator, data: dict) -> typing.Self: + main_fields_container, optional_fields_container = container.columns([0.7, 0.3]) + + if self.get_field_value("connect_by_url", latest=True): + self.disable("project_host") + self.disable("project_port") + self.disable("project_db") + + self.render_field("sql_flavor", container=main_fields_container) + self.render_field("connection_name", container=main_fields_container) + host_field_container, port_field_container = main_fields_container.columns([0.6, 0.4]) + self.render_field("project_host", container=host_field_container) + self.render_field("project_port", container=port_field_container) + + self.render_field("project_db", container=main_fields_container) + self.render_field("project_user", container=main_fields_container) + self.render_field("project_qc_schema", container=optional_fields_container) + self.render_field("max_threads", container=optional_fields_container) + self.render_field("max_query_chars", container=optional_fields_container) + + self.render_extra(container, main_fields_container, optional_fields_container, data) + + testgen.divider(margin_top=8, margin_bottom=8, container=container) + + self.url_prefix = data.get("url_prefix", "") + self.render_field("connect_by_url") + if self.connect_by_url: + connection_string = connection_service.form_overwritten_connection_url(data) + connection_string_beginning, connection_string_end = connection_string.split("@", 1) + + self.update_field_value( + "url_prefix", + f"{connection_string_beginning}@".replace("%3E", ">").replace("%3C", "<"), + ) + if not data.get("url", ""): + self.update_field_value("url", connection_string_end) + + url_override_left_column, url_override_right_column = st.columns([0.25, 0.75]) + self.render_field("url_prefix", container=url_override_left_column) + self.render_field("url", container=url_override_right_column) + + return self + + def render_extra( + self, + container: DeltaGenerator, + left_fields_container: DeltaGenerator, + right_fields_container: DeltaGenerator, + data: dict, + ) -> None: + ... + + @staticmethod + def for_flavor(flavor: SQLFlavor) -> type["BaseConnectionForm"]: + return { + "redshift": PasswordConnectionForm, + "snowflake": KeyPairConnectionForm, + "mssql": PasswordConnectionForm, + "postgresql": PasswordConnectionForm, + }[flavor] + + +class PasswordConnectionForm(BaseConnectionForm): + password: str = Field( + default="", + max_length=50, + writeOnly=True, + st_kwargs_label="Password", + st_kwargs_max_chars=50, + st_kwargs_help="Password to connect to your database.", + ) + + def render_extra( + self, + container: DeltaGenerator, + left_fields_container: DeltaGenerator, + right_fields_container: DeltaGenerator, + data: dict, + ) -> None: + self.render_field("password", left_fields_container) + + +class KeyPairConnectionForm(PasswordConnectionForm): + connect_by_key: bool = Field(default=None) + private_key_passphrase: str = Field( + default="", + max_length=200, + writeOnly=True, + st_kwargs_max_chars=200, + st_kwargs_help=( + "Passphrase used while creating the private Key (leave empty if not applicable)" + ), + st_kwargs_label="Private Key Passphrase", + ) + private_key_inner: str = Field( + default="", + format="base64", + st_kwargs_label="Upload private key (rsa_key.p8)", + ) + + @computed_field + @property + def private_key(self) -> str: + if not self.private_key_inner: + return "" + return base64.b64decode(self.private_key_inner).decode("utf-8") + + def render_extra( + self, + container: DeltaGenerator, + left_fields_container: DeltaGenerator, + right_fields_container: DeltaGenerator, + data: dict, + ) -> None: + testgen.divider(margin_top=8, margin_bottom=8, container=container) + + connect_by_key = self.connect_by_key + if connect_by_key is None: + connect_by_key = self.get_field_value("connect_by_key") + + connection_option: typing.Literal["Connect by Password", "Connect by Key-Pair"] = container.radio( + "Connection options", + options=["Connect by Password", "Connect by Key-Pair"], + index=1 if connect_by_key else 0, + horizontal=True, + help="Connection strategy", + key=self.get_field_key("connection_option"), + ) + self.update_field_value("connect_by_key", connection_option == "Connect by Key-Pair") + + if connection_option == "Connect by Password": + self.render_field("password", container) + else: + self.render_field("private_key_passphrase", container) + self.render_field("private_key_inner", container) diff --git a/testgen/ui/views/connections/models.py b/testgen/ui/views/connections/models.py new file mode 100644 index 0000000..90f16ca --- /dev/null +++ b/testgen/ui/views/connections/models.py @@ -0,0 +1,8 @@ +import dataclasses + + +@dataclasses.dataclass(frozen=True, slots=True) +class ConnectionStatus: + message: str + successful: bool + details: str | None = dataclasses.field(default=None) diff --git a/testgen/ui/views/connections/page.py b/testgen/ui/views/connections/page.py new file mode 100644 index 0000000..770b764 --- /dev/null +++ b/testgen/ui/views/connections/page.py @@ -0,0 +1,444 @@ +from functools import partial +import logging +import os +import time +import typing + +from pydantic import ValidationError +import streamlit as st +from streamlit.delta_generator import DeltaGenerator +import streamlit_pydantic as sp + +import testgen.ui.services.database_service as db +from testgen.ui.services import table_group_service +from testgen.commands.run_setup_profiling_tools import get_setup_profiling_tools_queries +from testgen.commands.run_profiling_bridge import run_profiling_in_background +from testgen.common.database.database_service import empty_cache +from testgen.ui.components import widgets as testgen +from testgen.ui.views.connections.forms import BaseConnectionForm +from testgen.ui.views.table_groups.forms import TableGroupForm +from testgen.ui.navigation.menu import MenuItem +from testgen.ui.navigation.page import Page +from testgen.ui.services import connection_service +from testgen.ui.session import session, temp_value +from testgen.ui.views.connections.models import ConnectionStatus + +LOG = logging.getLogger("testgen") + + +class ConnectionsPage(Page): + path = "connections" + can_activate: typing.ClassVar = [ + lambda: session.authentication_status, + ] + menu_item = MenuItem(icon="database", label="Data Configuration", order=4) + + def render(self, project_code: str, **_kwargs) -> None: + dataframe = connection_service.get_connections(project_code) + connection = dataframe.iloc[1] + has_table_groups = ( + len(connection_service.get_table_group_names_by_connection([connection["connection_id"]]) or []) > 0 + ) + + testgen.page_header( + "Connection", + "https://docs.datakitchen.io/article/dataops-testgen-help/connect-your-database", + ) + + _, actions_column = st.columns([.1, .9]) + testgen.flex_row_end(actions_column) + + with st.container(border=True): + self.show_connection_form(connection.to_dict(), "edit", project_code) + + if has_table_groups: + with actions_column: + testgen.link( + href="connections:table-groups", + params={"connection_id": str(connection["connection_id"])}, + label="Table Groups", + right_icon="chevron_right", + style="margin-left: auto;", + ) + else: + with actions_column: + testgen.button( + type_="stroked", + color="basic", + label="Setup Table Groups", + style="background: white;", + width=200, + on_click=lambda: self.setup_data_configuration(project_code, connection.to_dict()), + ) + + def show_connection_form(self, selected_connection: dict, mode, project_code) -> None: + connection = selected_connection or {} + connection_id = connection.get("connection_id", None) + sql_flavor = connection.get("sql_flavor", "postgresql") + data = {} + + try: + form = BaseConnectionForm.for_flavor(sql_flavor).model_construct(sql_flavor=sql_flavor) + if connection: + connection["password"] = connection["password"] or "" + form = BaseConnectionForm.for_flavor(sql_flavor)(**connection) + + sql_flavor = form.get_field_value("sql_flavor", latest=True) or sql_flavor + if form.sql_flavor != sql_flavor: + form = BaseConnectionForm.for_flavor(sql_flavor)(sql_flavor=sql_flavor) + + form_errors_container = st.empty() + data = sp.pydantic_input( + key=f"connection_form:{connection_id or 'new'}", + model=form, # type: ignore + ) + data.update({ + "project_code": project_code, + }) + if "private_key" not in data: + data.update({ + "connect_by_key": False, + "private_key_passphrase": None, + "private_key": None, + }) + + try: + BaseConnectionForm.for_flavor(sql_flavor).model_validate(data) + except ValidationError as error: + form_errors_container.warning("\n".join([ + f"- {field_label}: {err['msg']}" for err in error.errors() + if (field_label := TableGroupForm.get_field_label(str(err['loc'][0]))) + ])) + except Exception: + LOG.exception("unexpected form validation error") + st.error("Unexpected error displaying the form. Try again") + + test_button_column, config_qc_column, _, save_button_column = st.columns([.2, .2, .4, .2]) + is_submitted, set_submitted = temp_value(f"connection_form-{connection_id or 'new'}:submit") + get_connection_status, set_connection_status = temp_value( + f"connection_form-{connection_id or 'new'}:test_conn" + ) + + with save_button_column: + testgen.button( + type_="flat", + label="Save", + key=f"connection_form:{connection_id or 'new'}:submit", + on_click=lambda: set_submitted(True), + ) + + with test_button_column: + testgen.button( + type_="stroked", + color="basic", + label="Test Connection", + key=f"connection_form:{connection_id or 'new'}:test", + on_click=lambda: set_connection_status(self.test_connection(data)), + ) + + with config_qc_column: + testgen.button( + type_="stroked", + color="basic", + label="Configure QC Utility Schema", + key=f"connection_form:{connection_id or 'new'}:config-qc-schema", + tooltip="Creates the required Utility schema and related functions in the target database", + on_click=lambda: self.create_qc_schema_dialog(connection) + ) + + if (connection_status := get_connection_status()): + single_element_container = st.empty() + single_element_container.info("Connecting ...") + + with single_element_container.container(): + renderer = { + True: st.success, + False: st.error, + }[connection_status.successful] + + renderer(connection_status.message) + if not connection_status.successful and connection_status.details: + st.caption("Connection Error Details") + + with st.container(border=True): + st.markdown(connection_status.details) + + connection_status = None + else: + # This is needed to fix a strange bug in Streamlit when using dialog + input fields + button + # If an input field is changed and the button is clicked immediately (without unfocusing the input first), + # two fragment reruns happen successively, one for unfocusing the input and the other for clicking the button + # Some or all (it seems random) of the input fields disappear when this happens + time.sleep(0.1) + + if is_submitted(): + if not data.get("password") and not data.get("connect_by_key"): + st.error("Enter a valid password.") + else: + if data.get("private_key"): + data["private_key"] = data["private_key"].getvalue().decode("utf-8") + + connection_service.edit_connection(data) + st.success("Changes have been saved successfully.") + time.sleep(1) + st.rerun() + + def test_connection(self, connection: dict) -> "ConnectionStatus": + if connection["connect_by_key"] and connection["connection_id"] is None: + return ConnectionStatus( + message="Please add the connection before testing it (so that we can get your private key file).", + successful=False, + ) + + empty_cache() + try: + sql_query = "select 1;" + results = db.retrieve_target_db_data( + connection["sql_flavor"], + connection["project_host"], + connection["project_port"], + connection["project_db"], + connection["project_user"], + connection["password"], + connection["url"], + connection["connect_by_url"], + connection["connect_by_key"], + connection["private_key"], + connection["private_key_passphrase"], + sql_query, + ) + connection_successful = len(results) == 1 and results[0][0] == 1 + + if not connection_successful: + return ConnectionStatus(message="Error completing a query to the database server.", successful=False) + + qc_error_message = "The connection was successful, but there is an issue with the QC Utility Schema" + try: + qc_results = connection_service.test_qc_connection(connection["project_code"], connection) + if not all(qc_results): + return ConnectionStatus( + message=qc_error_message, + details=f"QC Utility Schema confirmation failed. details: {qc_results}", + successful=False, + ) + return ConnectionStatus(message="The connection was successful.", successful=True) + except Exception as error: + return ConnectionStatus(message=qc_error_message, details=error.args[0], successful=False) + except Exception as error: + return ConnectionStatus(message="Error attempting the Connection.", details=error.args[0], successful=False) + + @st.dialog(title="Configure QC Utility Schema") + def create_qc_schema_dialog(self, selected_connection): + connection_id = selected_connection["connection_id"] + project_qc_schema = selected_connection["project_qc_schema"] + sql_flavor = selected_connection["sql_flavor"] + user = selected_connection["project_user"] + + create_qc_schema = st.toggle("Create QC Utility Schema", value=True) + grant_privileges = st.toggle("Grant access privileges to TestGen user", value=True) + + user_role = None + + # TODO ALEX: This textbox may be needed if we want to grant permissions to user role + # if sql_flavor == "snowflake": + # user_role_textbox_label = f"Primary role for database user {user}" + # user_role = st.text_input(label=user_role_textbox_label, max_chars=100) + + admin_credentials_expander = st.expander("Admin credential options", expanded=True) + with admin_credentials_expander: + admin_connection_option_index = 0 + admin_connection_options = ["Do not use admin credentials", "Use admin credentials with Password"] + if sql_flavor == "snowflake": + admin_connection_options.append("Use admin credentials with Key-Pair") + + admin_connection_option = st.radio( + "Admin credential options", + label_visibility="hidden", + options=admin_connection_options, + index=admin_connection_option_index, + horizontal=True, + ) + + st.markdown("

 
", unsafe_allow_html=True) + + db_user = None + db_password = None + admin_private_key_passphrase = None + admin_private_key = None + if admin_connection_option == admin_connection_options[0]: + st.markdown(":orange[User created in the connection dialog will be used.]") + else: + db_user = st.text_input(label="Admin db user", max_chars=40) + if admin_connection_option == admin_connection_options[1]: + db_password = st.text_input( + label="Admin db password", max_chars=40, type="password" + ) + st.markdown(":orange[Note: Admin credentials are not stored, are only used for this operation.]") + + if len(admin_connection_options) > 2 and admin_connection_option == admin_connection_options[2]: + admin_private_key_passphrase = st.text_input( + label="Private Key Passphrase", + key="create-qc-schema-private-key-password", + type="password", + max_chars=200, + help="Passphrase used while creating the private Key (leave empty if not applicable)", + ) + + admin_uploaded_file = st.file_uploader("Upload private key (rsa_key.p8)", key="admin-uploaded-file") + if admin_uploaded_file: + admin_private_key = admin_uploaded_file.getvalue().decode("utf-8") + + st.markdown(":orange[Note: Admin credentials are not stored, are only used for this operation.]") + + submit = st.button("Update Configuration") + + if submit: + empty_cache() + script_expander = st.expander("Script Details") + + operation_status = st.empty() + operation_status.info(f"Configuring QC Utility Schema '{project_qc_schema}'...") + + try: + skip_granting_privileges = not grant_privileges + queries = get_setup_profiling_tools_queries(sql_flavor, create_qc_schema, skip_granting_privileges, project_qc_schema, user, user_role) + with script_expander: + st.code( + os.linesep.join(queries), + language="sql", + line_numbers=True) + + connection_service.create_qc_schema( + connection_id, + create_qc_schema, + db_user if db_user else None, + db_password if db_password else None, + skip_granting_privileges, + admin_private_key_passphrase=admin_private_key_passphrase, + admin_private_key=admin_private_key, + user_role=user_role, + ) + operation_status.empty() + operation_status.success("Operation has finished successfully.") + + except Exception as e: + operation_status.empty() + operation_status.error("Error configuring QC Utility Schema.") + error_message = e.args[0] + st.text_area("Error Details", value=error_message) + + @st.dialog(title="Data Configuration Setup") + def setup_data_configuration(self, project_code: str, connection: dict) -> None: + will_run_profiling = st.session_state.get("connection_form-new:run-profiling-toggle", True) + testgen.wizard( + key="connections:setup-wizard", + steps=[ + testgen.WizardStep( + title="Create a Table Group", + body=partial(self.create_table_group_step, project_code, connection), + ), + testgen.WizardStep( + title="Run Profiling", + body=self.run_data_profiling_step, + ), + ], + on_complete=self.execute_setup, + complete_label="Save & Run Profiling" if will_run_profiling else "Finish Setup", + navigate_to=st.session_state.pop("setup_data_config:navigate-to", None), + navigate_to_args=st.session_state.pop("setup_data_config:navigate-to-args", {}), + ) + + def create_table_group_step(self, project_code: str, connection: dict) -> tuple[dict | None, bool]: + is_valid: bool = True + data: dict = {} + + try: + form = TableGroupForm.model_construct() + form_errors_container = st.empty() + data = sp.pydantic_input(key="table_form:new", model=form) # type: ignore + + try: + TableGroupForm.model_validate(data) + form_errors_container.empty() + data.update({"project_code": project_code, "connection_id": connection["connection_id"]}) + except ValidationError as error: + form_errors_container.warning("\n".join([ + f"- {field_label}: {err['msg']}" for err in error.errors() + if (field_label := TableGroupForm.get_field_label(str(err['loc'][0]))) + ])) + is_valid = False + except Exception: + LOG.exception("unexpected form validation error") + st.error("Unexpected error displaying the form. Try again") + is_valid = False + + return data, is_valid + + def run_data_profiling_step(self, step_0: testgen.WizardStep | None = None) -> tuple[bool, bool]: + if not step_0 or not step_0.results: + st.error("A table group is required to complete this step.") + return False, False + + run_profiling = True + profiling_message = "Profiling will be performed in a background process." + table_group = step_0.results + + with st.container(): + run_profiling = st.checkbox( + label=f"Execute profiling for the table group **{table_group['table_groups_name']}**?", + key="connection_form-new:run-profiling-toggle", + value=True, + ) + if not run_profiling: + profiling_message = ( + "Profiling will be skipped. You can run this step later from the Profiling Runs page." + ) + st.markdown(f":material/info: _{profiling_message}_") + + return run_profiling, True + + def execute_setup( + self, + container: DeltaGenerator, + step_0: testgen.WizardStep[dict], + step_1: testgen.WizardStep[bool], + ) -> bool: + table_group = step_0.results + table_group_name: str = table_group["table_groups_name"] + should_run_profiling: bool = step_1.results + + with container.container(): + status_container = st.empty() + + try: + status_container.info(f"Creating table group **{table_group_name.strip()}**.") + table_group_id = table_group_service.add(table_group) + TableGroupForm.model_construct().reset_cache() + except Exception as err: + status_container.error(f"Error creating table group: {err!s}.") + + if should_run_profiling: + try: + status_container.info("Starting profiling run ...") + run_profiling_in_background(table_group_id) + status_container.success(f"Profiling run started for table group **{table_group_name.strip()}**.") + except Exception as err: + status_container.error(f"Profiling run encountered errors: {err!s}.") + + _, link_column = st.columns([.7, .3]) + with link_column: + testgen.button( + type_="stroked", + color="primary", + label="Go to Profiling Runs", + icon="chevron_right", + key="setup_data_config:keys:go-to-runs", + on_click=lambda: ( + st.session_state.__setattr__("setup_data_config:navigate-to", "profiling-runs") + or st.session_state.__setattr__("setup_data_config:navigate-to-args", { + "table_group": table_group_id + }) + ), + ) + + return not should_run_profiling diff --git a/testgen/ui/views/table_groups/__init__.py b/testgen/ui/views/table_groups/__init__.py new file mode 100644 index 0000000..99df82c --- /dev/null +++ b/testgen/ui/views/table_groups/__init__.py @@ -0,0 +1,2 @@ +from testgen.ui.views.table_groups.page import TableGroupsPage +# from testgen.ui.views.table_groups.forms import ... diff --git a/testgen/ui/views/table_groups/forms.py b/testgen/ui/views/table_groups/forms.py new file mode 100644 index 0000000..9087307 --- /dev/null +++ b/testgen/ui/views/table_groups/forms.py @@ -0,0 +1,170 @@ +# type: ignore +import typing + +from streamlit.delta_generator import DeltaGenerator + +from testgen.ui.components import widgets as testgen +from testgen.ui.forms import BaseForm, Field, ManualRender + +SQLFlavor = typing.Literal["redshift", "snowflake", "mssql", "postgresql"] + + +class TableGroupForm(BaseForm, ManualRender): + table_groups_name: str = Field( + default="", + min_length=1, + max_length=40, + st_kwargs_label="Name", + st_kwargs_max_chars=40, + st_kwargs_help="A unique name to describe the table group", + ) + profiling_include_mask: str = Field( + default="%", + max_length=40, + st_kwargs_label="Tables to Include Mask", + st_kwargs_max_chars=40, + st_kwargs_help="A SQL filter supported by your database's LIKE operator for table names to include", + ) + profiling_exclude_mask: str = Field( + default="tmp%", + st_kwargs_label="Tables to Exclude Mask", + st_kwargs_max_chars=40, + st_kwargs_help="A SQL filter supported by your database's LIKE operator for table names to exclude", + ) + profiling_table_set: str = Field( + default="", + st_kwargs_label="Explicit Table List", + st_kwargs_max_chars=2000, + st_kwargs_help="A list of specific table names to include, separated by commas", + ) + table_group_schema: str = Field( + default="", + min_length=1, + max_length=40, + st_kwargs_label="Schema", + st_kwargs_max_chars=40, + st_kwargs_help="The database schema containing the tables in the Table Group", + ) + profile_id_column_mask: str = Field( + default="%_id", + st_kwargs_label="Profiling ID column mask", + st_kwargs_max_chars=40, + st_kwargs_help="A SQL filter supported by your database's LIKE operator representing ID columns (optional)", + ) + profile_sk_column_mask: str = Field( + default="%_sk", + st_kwargs_label="Profiling Surrogate Key column mask", + st_kwargs_max_chars=40, + st_kwargs_help="A SQL filter supported by your database's LIKE operator representing surrogate key columns (optional)", + ) + profiling_delay_days: int = Field( + default=0, + st_kwargs_label="Min Profiling Age, Days", + st_kwargs_min_value=0, + st_kwargs_max_value=999, + st_kwargs_help="The number of days to wait before new profiling will be available to generate tests", + ) + profile_use_sampling: bool = Field( + default=True, + st_kwargs_label="Use profile sampling", + st_kwargs_help="Toggle on to base profiling on a sample of records instead of the full table", + ) + profile_sample_percent: int = Field( + default=30, + st_kwargs_label="Sample percent", + st_kwargs_min_value=1, + st_kwargs_max_value=100, + st_kwargs_help="Percent of records to include in the sample, unless the calculated count falls below the specified minimum.", + ) + profile_sample_min_count: int = Field( + default=15000, + st_kwargs_label="Min Sample Record Count", + st_kwargs_min_value=1, + st_kwargs_max_value=1000000, + st_kwargs_help="The minimum number of records to be included in any sample (if available)", + ) + data_source: str = Field( + default="", + st_kwargs_label="Data Source", + st_kwargs_max_chars=40, + st_kwargs_help="Original source of all tables in this dataset. This can be overridden at the table level. (Optional)", + ) + source_system: str = Field( + default="", + st_kwargs_label="System of Origin", + st_kwargs_max_chars=40, + st_kwargs_help="Enterprise system source for all tables in this dataset. " + "This can be overridden at the table level. (Optional)", + ) + business_domain: str = Field( + default="", + st_kwargs_label="Business Domain", + st_kwargs_max_chars=40, + st_kwargs_help="Business division responsible for all tables in this dataset. " + "e.g. Finance, Sales, Manufacturing. (Optional)", + ) + data_location: str = Field( + default="", + st_kwargs_label="Location", + st_kwargs_max_chars=40, + st_kwargs_help="Physical or virtual location of all tables in this dataset. " + "e.g. Headquarters, Cloud, etc. (Optional)", + ) + transform_level: str = Field( + default="", + st_kwargs_label="Transform Level", + st_kwargs_max_chars=40, + st_kwargs_help="Data warehouse processing layer. " + "Indicates the processing stage: e.g. Raw, Conformed, Processed, Reporting. (Optional)", + ) + source_process: str = Field( + default="", + st_kwargs_label="Source Process", + st_kwargs_max_chars=40, + st_kwargs_help="The process, program or data flow that produced this data. (Optional)", + ) + stakeholder_group: str = Field( + default="", + st_kwargs_label="Stakeholder Group", + st_kwargs_max_chars=40, + st_kwargs_help="Designator for data owners or stakeholders who are responsible for this data. (Optional)", + ) + table_group_id: int | None = Field(default=None) + + def form_key(self): + return f"table_group_form:{self.table_group_id or 'new'}" + + def render_input_ui(self, container: DeltaGenerator, data: dict) -> typing.Self: + left_column, right_column = container.columns([.5, .5]) + + self.render_field("table_groups_name", left_column) + self.render_field("profiling_include_mask", left_column) + self.render_field("profiling_exclude_mask", left_column) + self.render_field("profiling_table_set", left_column) + + self.render_field("table_group_schema", right_column) + self.render_field("profile_id_column_mask", right_column) + self.render_field("profile_sk_column_mask", right_column) + self.render_field("profiling_delay_days", right_column) + + self.render_field("profile_use_sampling", container) + profile_sampling_expander = container.expander("Sampling Parameters", expanded=False) + with profile_sampling_expander: + expander_left_column, expander_right_column = profile_sampling_expander.columns([0.50, 0.50]) + self.render_field("profile_sample_percent", expander_left_column) + self.render_field("profile_sample_min_count", expander_right_column) + + provenance_expander = container.expander("Data Provenance (Optional)", expanded=False) + with provenance_expander: + provenance_left_column, provenance_right_column = provenance_expander.columns([0.50, 0.50]) + + self.render_field("data_source", provenance_left_column) + self.render_field("source_system", provenance_left_column) + self.render_field("business_domain", provenance_left_column) + self.render_field("data_location", provenance_left_column) + + self.render_field("transform_level", provenance_right_column) + self.render_field("source_process", provenance_right_column) + self.render_field("stakeholder_group", provenance_right_column) + + return self diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups/page.py similarity index 99% rename from testgen/ui/views/table_groups.py rename to testgen/ui/views/table_groups/page.py index e62787c..7b9e8a9 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups/page.py @@ -29,7 +29,7 @@ class TableGroupsPage(Page): def render(self, connection_id: str, **_kwargs) -> None: connection = connection_service.get_by_id(connection_id, hide_passwords=False) if not connection: - self.router.navigate_with_warning( + return self.router.navigate_with_warning( f"Connection with ID '{connection_id}' does not exist. Redirecting to list of Connections ...", "connections", ) @@ -40,7 +40,7 @@ def render(self, connection_id: str, **_kwargs) -> None: testgen.page_header( "Table Groups", "https://docs.datakitchen.io/article/dataops-testgen-help/create-a-table-group", - breadcrumbs=[ + breadcrumbs=[ # type: ignore { "label": "Connections", "path": "connections", "params": { "project_code": project_code } }, { "label": connection["connection_name"] }, ], diff --git a/testgen/utils/singleton.py b/testgen/utils/singleton.py index 0c87de3..722f7f2 100644 --- a/testgen/utils/singleton.py +++ b/testgen/utils/singleton.py @@ -2,9 +2,9 @@ class SingletonType(type): - _instances: typing.ClassVar[dict[type, object]] = {} + _instances: typing.ClassVar[dict[type, typing.Any]] = {} - def __call__(cls, *args, **kwargs) -> typing.Any: + def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls]