diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 3015a86d..c971f7fe 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -91,6 +91,7 @@ async def start_local( download_dest_dir: Optional[str] = None, ui: bool = False, runtime: Optional[temporalio.runtime.Runtime] = None, + search_attributes: Sequence[temporalio.common.SearchAttributeKey] = [], dev_server_existing_path: Optional[str] = None, dev_server_database_filename: Optional[str] = None, dev_server_log_format: str = "pretty", @@ -138,6 +139,8 @@ async def start_local( needed. If unset, this is the system's temporary directory. ui: If ``True``, will start a UI in the dev server. runtime: Specific runtime to use or default if unset. + search_attributes: Search attributes to register with the dev + server. dev_server_existing_path: Existing path to the CLI binary. If present, no download will be attempted to fetch the binary. dev_server_database_filename: Path to the Sqlite database to use @@ -167,6 +170,14 @@ async def start_local( dev_server_log_level = "error" else: dev_server_log_level = "fatal" + # Add search attributes + if search_attributes: + new_args = [] + for attr in search_attributes: + new_args.append("--search-attribute") + new_args.append(f"{attr.name}={attr._metadata_type}") + new_args += dev_server_extra_args + dev_server_extra_args = new_args # Start CLI dev server runtime = runtime or temporalio.runtime.Runtime.default() server = await temporalio.bridge.testing.EphemeralServer.start_dev_server( diff --git a/tests/testing/test_workflow.py b/tests/testing/test_workflow.py index 22bdd818..7538cb29 100644 --- a/tests/testing/test_workflow.py +++ b/tests/testing/test_workflow.py @@ -12,11 +12,17 @@ Client, Interceptor, OutboundInterceptor, + RPCError, StartWorkflowInput, WorkflowFailureError, WorkflowHandle, ) -from temporalio.common import RetryPolicy +from temporalio.common import ( + RetryPolicy, + SearchAttributeKey, + SearchAttributePair, + TypedSearchAttributes, +) from temporalio.exceptions import ( ActivityError, ApplicationError, @@ -245,6 +251,70 @@ def assert_proper_error(err: Optional[BaseException]) -> None: assert_proper_error(err.value.cause) +async def test_search_attributes_on_dev_server( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Only testing for local dev server") + + # Search attributes + sa_prefix = f"{uuid.uuid4()}_" + text_attr = SearchAttributeKey.for_text(f"{sa_prefix}text") + keyword_attr = SearchAttributeKey.for_keyword(f"{sa_prefix}keyword") + keyword_list_attr = SearchAttributeKey.for_keyword_list(f"{sa_prefix}keyword_list") + int_attr = SearchAttributeKey.for_int(f"{sa_prefix}int") + float_attr = SearchAttributeKey.for_float(f"{sa_prefix}double") + bool_attr = SearchAttributeKey.for_bool(f"{sa_prefix}bool") + datetime_attr = SearchAttributeKey.for_datetime(f"{sa_prefix}datetime") + attrs = TypedSearchAttributes( + [ + SearchAttributePair(text_attr, "text1"), + SearchAttributePair(keyword_attr, "keyword1"), + SearchAttributePair( + keyword_list_attr, + ["keywordlist1", "keywordlist2"], + ), + SearchAttributePair(int_attr, 123), + SearchAttributePair(float_attr, 456.78), + SearchAttributePair(bool_attr, True), + SearchAttributePair( + datetime_attr, datetime(2001, 2, 3, 4, 5, 6, tzinfo=timezone.utc) + ), + ] + ) + + # Confirm that we can't start a workflow on existing environment + with pytest.raises(RPCError) as err: + await client.start_workflow( + "some-workflow", + id=f"wf-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + search_attributes=attrs, + ) + assert "no mapping defined" in str(err.value) + + # But we can in a new environment with the attrs set + async with await WorkflowEnvironment.start_local( + search_attributes=[ + text_attr, + keyword_attr, + keyword_list_attr, + int_attr, + float_attr, + bool_attr, + datetime_attr, + ] + ) as env: + handle = await env.client.start_workflow( + "some-workflow", + id=f"wf-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + search_attributes=attrs, + ) + desc = await handle.describe() + assert attrs == desc.typed_search_attributes + + def assert_timestamp_from_now( ts: Union[datetime, float], expected_from_now: float, max_delta: float = 30 ) -> None: