From 18f7078441e7d3fe02efb51f1175aa1595ab2df3 Mon Sep 17 00:00:00 2001 From: Josh Usiskin <56369778+jusiskin@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:08:43 -0500 Subject: [PATCH] feat!: job/step/task API functions and session log assertions (#150) Signed-off-by: Josh Usiskin <56369778+jusiskin@users.noreply.github.com> BREAKING CHANGE: Job.lifecycle_status changed to enum and TaskStatus.UNKNOWN removed --- src/deadline_test_fixtures/__init__.py | 28 +- .../deadline/__init__.py | 38 +- .../deadline/resources.py | 553 ++++++++++- test/unit/deadline/test_resources.py | 919 +++++++++++++++++- 4 files changed, 1499 insertions(+), 39 deletions(-) diff --git a/src/deadline_test_fixtures/__init__.py b/src/deadline_test_fixtures/__init__.py index 9b9db10..ee6eae6 100644 --- a/src/deadline_test_fixtures/__init__.py +++ b/src/deadline_test_fixtures/__init__.py @@ -17,6 +17,10 @@ PipInstall, Queue, QueueFleetAssociation, + Session, + SessionLog, + Step, + Task, TaskStatus, ) from .fixtures import ( @@ -49,30 +53,34 @@ "CloudWatchLogEvent", "CodeArtifactRepositoryInfo", "CommandResult", - "DeadlineResources", "DeadlineClient", + "DeadlineResources", "DeadlineWorker", "DeadlineWorkerConfiguration", "DockerContainerWorker", "EC2InstanceWorker", - "WindowsInstanceWorkerBase", - "WindowsInstanceBuildWorker", - "PosixInstanceWorkerBase", - "PosixInstanceBuildWorker", "Farm", "Fleet", "Job", - "JobAttachmentSettings", "JobAttachmentManager", + "JobAttachmentSettings", "JobRunAsUser", + "OperatingSystem", "PipInstall", + "PosixInstanceBuildWorker", + "PosixInstanceWorkerBase", "PosixSessionUser", - "S3Object", - "ServiceModel", - "OperatingSystem", "Queue", "QueueFleetAssociation", + "S3Object", + "ServiceModel", + "Session", + "SessionLog", + "Step", + "Task", "TaskStatus", + "WindowsInstanceBuildWorker", + "WindowsInstanceWorkerBase", "bootstrap_resources", "codeartifact", "deadline_client", @@ -81,6 +89,6 @@ "install_service_model", "service_model", "version", - "worker", "worker_config", + "worker", ] diff --git a/src/deadline_test_fixtures/deadline/__init__.py b/src/deadline_test_fixtures/deadline/__init__.py index b2a5dca..e191e6e 100644 --- a/src/deadline_test_fixtures/deadline/__init__.py +++ b/src/deadline_test_fixtures/deadline/__init__.py @@ -1,15 +1,29 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+from .client import DeadlineClient from .resources import ( CloudWatchLogEvent, + DependencyCounts, Farm, Fleet, + FloatTaskParameterValue, + IntTaskParameterValue, Job, + JobLifecycleStatus, + LogConfiguration, + PathTaskParameterValue, Queue, QueueFleetAssociation, + Session, + SessionLifecycleStatus, + SessionLog, + Step, + StepLifecycleStatus, + StringTaskParameterValue, + Task, + TaskParameterValue, TaskStatus, ) -from .client import DeadlineClient from .worker import ( CommandResult, DeadlineWorker, @@ -29,17 +43,31 @@ "DeadlineClient", "DeadlineWorker", "DeadlineWorkerConfiguration", + "DependencyCounts", "DockerContainerWorker", "EC2InstanceWorker", - "WindowsInstanceWorkerBase", - "WindowsInstanceBuildWorker", - "PosixInstanceWorkerBase", - "PosixInstanceBuildWorker", "Farm", "Fleet", + "FloatTaskParameterValue", + "IntTaskParameterValue", "Job", + "JobLifecycleStatus", + "LogConfiguration", + "PathTaskParameterValue", "PipInstall", + "PosixInstanceBuildWorker", + "PosixInstanceWorkerBase", "Queue", "QueueFleetAssociation", + "Session", + "SessionLifecycleStatus", + "SessionLog", + "Step", + "StepLifecycleStatus", + "StringTaskParameterValue", + "Task", + "TaskParameterValue", "TaskStatus", + "WindowsInstanceBuildWorker", + "WindowsInstanceWorkerBase", ] diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py index e9ffa66..6210709 100644 --- a/src/deadline_test_fixtures/deadline/resources.py +++ b/src/deadline_test_fixtures/deadline/resources.py @@ -4,9 +4,12 @@ import datetime import json import logging +import re +import time from dataclasses import asdict, dataclass, fields +from datetime import timedelta from enum import Enum -from typing import Any, Callable, Literal, TYPE_CHECKING, Optional +from typing import Any, Callable, Generator, Literal, TYPE_CHECKING, Optional, Union from botocore.client import BaseClient @@ -299,8 +302,34 @@ class StrEnum(str, Enum): pass +class JobLifecycleStatus(StrEnum): + ARCHIVED = "ARCHIVED" + CREATE_COMPLETE = "CREATE_COMPLETE" + CREATE_FAILED = "CREATE_FAILED" + CREATE_IN_PROGRESS = "CREATE_IN_PROGRESS" + UPDATE_FAILED = "UPDATE_FAILED" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_SUCCEEDED = "UPDATE_SUCCEEDED" + UPLOAD_FAILED = "UPLOAD_FAILED" + UPLOAD_IN_PROGRESS = "UPLOAD_IN_PROGRESS" + + +class StepLifecycleStatus(StrEnum): + CREATE_COMPLETE = "CREATE_COMPLETE" + UPDATE_FAILED = "UPDATE_FAILED" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_SUCCEEDED = "UPDATE_SUCCEEDED" + + +class SessionLifecycleStatus(StrEnum): + ENDED = "ENDED" + STARTED = "STARTED" + UPDATE_FAILED = "UPDATE_FAILED" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_SUCCEEDED = "UPDATE_SUCCEEDED" + + class TaskStatus(StrEnum): - UNKNOWN = "UNKNOWN" PENDING = "PENDING" READY = "READY" RUNNING = "RUNNING" @@ -324,6 +353,37 @@ class TaskStatus(StrEnum): ) +@dataclass +class IpAddresses: + ip_v4_addresses: list[str] | None = None + ip_v6_addresses: list[str] | None = None + + @staticmethod + def from_api_response(response: dict[str, Any]) -> IpAddresses: + return IpAddresses( + ip_v4_addresses=response.get("ipV4Addresses", None), + ip_v6_addresses=response.get("ipV6Addresses", None), + ) + + +@dataclass +class WorkerHostProperties: + ec2_instance_arn: str | None = None + ec2_instance_type: str | None = None + host_name: str | None = None + ip_addresses: IpAddresses | None = None + + @staticmethod + def from_api_response(response: dict[str, Any]) -> WorkerHostProperties: + ip_addresses = 
response.get("ipAddresses", None)
+        return WorkerHostProperties(
+            ec2_instance_arn=response.get("ec2InstanceArn", None),
+            ec2_instance_type=response.get("ec2InstanceType", None),
+            host_name=response.get("hostName", None),
+            ip_addresses=IpAddresses.from_api_response(ip_addresses) if ip_addresses else None,
+        )
+
+
 @dataclass
 class Job:
     id: str
@@ -332,7 +392,7 @@ class Job:
     template: dict
 
     name: str
-    lifecycle_status: str
+    lifecycle_status: JobLifecycleStatus
     lifecycle_status_message: str
     priority: int
     created_at: datetime.datetime
@@ -439,7 +499,7 @@ def get_optional_field(
         return {
             "id": response["jobId"],
             "name": response["name"],
-            "lifecycle_status": response["lifecycleStatus"],
+            "lifecycle_status": JobLifecycleStatus(response["lifecycleStatus"]),
             "lifecycle_status_message": response["lifecycleStatusMessage"],
             "priority": response["priority"],
             "created_at": response["createdAt"],
@@ -603,6 +663,114 @@ def _is_job_complete():
             max_retries=max_retries,
         )
 
+    def list_steps(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+    ) -> Generator[Step, None, None]:
+        list_steps_paginator: Paginator = deadline_client.get_paginator("list_steps")
+        list_steps_pages: PageIterator = call_api(
+            description=f"Listing steps for job {self.id}",
+            fn=lambda: list_steps_paginator.paginate(
+                farmId=self.farm.id,
+                queueId=self.queue.id,
+                jobId=self.id,
+            ),
+        )
+
+        for page in list_steps_pages:
+            for step in page["steps"]:
+                dependency_counts = step.get("dependencyCounts", None)
+                target_task_run_status = step.get("targetTaskRunStatus", None)
+                yield Step(
+                    farm=self.farm,
+                    queue=self.queue,
+                    job=self,
+                    id=step["stepId"],
+                    name=step["name"],
+                    created_at=step["createdAt"],
+                    created_by=step["createdBy"],
+                    lifecycle_status=StepLifecycleStatus(step["lifecycleStatus"]),
+                    task_run_status=TaskStatus(step["taskRunStatus"]),
+                    task_run_status_counts={
+                        TaskStatus(key): value for key, value in step["taskRunStatusCounts"].items()
+                    },
+                    lifecycle_status_message=step.get("lifecycleStatusMessage", None),
+                    target_task_run_status=(
+                        TaskStatus(target_task_run_status) if target_task_run_status else None
+                    ),
+                    updated_at=step.get("updatedAt", None),
+                    updated_by=step.get("updatedBy", None),
+                    started_at=step.get("startedAt", None),
+                    ended_at=step.get("endedAt", None),
+                    dependency_counts=(
+                        DependencyCounts.from_api_response(dependency_counts)
+                        if dependency_counts is not None
+                        else None
+                    ),
+                )
+
+    def assert_single_task_log_contains(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+        logs_client: BaseClient,
+        expected_pattern: re.Pattern | str,
+        assert_fail_msg: str = "Expected message not found in session log",
+        retries: int = 4,
+        backoff_factor: timedelta = timedelta(milliseconds=300),
+    ) -> None:
+        """
+        Asserts that the expected regular expression pattern exists in the job's session log.
+
+        This method is intended for jobs with a single step and task. It checks the logs of the
+        last run session for the single task.
+
+        The method accounts for the eventual consistency of CloudWatch log delivery and
+        availability through the CloudWatch APIs by retrying with exponential back-off, up to a
+        configurable number of times, if the pattern is not initially found.
+
+        Args:
+            deadline_client (deadline_test_fixtures.client.DeadlineClient): Deadline boto client
+            logs_client (botocore.client.BaseClient): CloudWatch logs boto client
+            expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a
+                pre-compiled regular expression Pattern object.
This pattern is searched against
+                the session's log events, concatenated into a multi-line string joined by
+                single newline characters (\\n).
+            assert_fail_msg (str): The assertion message to raise if the pattern is not found after
+                the configured exponential back-off attempts. The CloudWatch log group name is
+                appended to the end of this message to assist with diagnosis. The default is
+                "Expected message not found in session log".
+            retries (int): The number of retries with exponential back-off to attempt while the
+                expected pattern is not found. Default is 4.
+            backoff_factor (datetime.timedelta): A multiple used for exponential back-off delay
+                between attempts when the expected pattern is not found. The formula used is:
+
+                delay = backoff_factor * 2 ** i
+
+                where i is the 0-based retry number
+
+                Default is 300ms
+        """
+        # Coerce Regex str patterns to a re.Pattern
+        if isinstance(expected_pattern, str):
+            expected_pattern = re.compile(expected_pattern)
+
+        # Assert there is a single step and task
+        steps = list(self.list_steps(deadline_client=deadline_client))
+        assert len(steps) == 1, "Job contains multiple steps"
+        step = steps[0]
+        tasks = list(step.list_tasks(deadline_client=deadline_client))
+        assert len(tasks) == 1, "Job contains multiple tasks"
+        task = tasks[0]
+
+        session = task.get_last_session(deadline_client=deadline_client)
+        session.assert_log_contains(
+            logs_client=logs_client,
+            expected_pattern=expected_pattern,
+            assert_fail_msg=assert_fail_msg,
+            backoff_factor=backoff_factor,
+            retries=retries,
+        )
+
     @property
     def complete(self) -> bool:  # pragma: no cover
         return self.task_run_status in COMPLETE_TASK_STATUSES
@@ -652,11 +820,388 @@ def __str__(self) -> str:  # pragma: no cover
         )
 
 
+@dataclass
+class DependencyCounts:
+    consumers_resolved: int
+    consumers_unresolved: int
+    dependencies_resolved: int
+    dependencies_unresolved: int
+
+    @staticmethod
+    def from_api_response(response: dict[str, Any]) -> DependencyCounts:
+        return DependencyCounts(
+            consumers_resolved=response["consumersResolved"],
+            consumers_unresolved=response["consumersUnresolved"],
+            dependencies_resolved=response["dependenciesResolved"],
+            dependencies_unresolved=response["dependenciesUnresolved"],
+        )
+
+
+@dataclass
+class Step:
+    farm: Farm
+    queue: Queue
+    job: Job
+    id: str
+
+    name: str
+    created_at: datetime.datetime
+    created_by: str
+    lifecycle_status: StepLifecycleStatus
+    task_run_status: TaskStatus
+    task_run_status_counts: dict[TaskStatus, int]
+    lifecycle_status_message: str | None = None
+    target_task_run_status: TaskStatus | None = None
+    updated_at: datetime.datetime | None = None
+    updated_by: str | None = None
+    started_at: datetime.datetime | None = None
+    ended_at: datetime.datetime | None = None
+    dependency_counts: DependencyCounts | None = None
+
+    def list_tasks(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+    ) -> Generator[Task, None, None]:
+        list_tasks_paginator: Paginator = deadline_client.get_paginator("list_tasks")
+        list_tasks_pages: PageIterator = call_api(
+            description=f"Listing tasks for step {self.id} of job {self.job.id}",
+            fn=lambda: list_tasks_paginator.paginate(
+                farmId=self.farm.id,
+                queueId=self.queue.id,
+                jobId=self.job.id,
+                stepId=self.id,
+            ),
+        )
+        for page in list_tasks_pages:
+            for task in page["tasks"]:
+                target_task_run_status = task.get("targetTaskRunStatus", None)
+                yield Task(
+                    farm=self.farm,
+                    queue=self.queue,
+                    job=self.job,
+                    step=self,
+                    id=task["taskId"],
+                    created_at=task["createdAt"],
+                    created_by=task["createdBy"],
run_status=TaskStatus(task["runStatus"]),
+                    failure_retry_count=task["failureRetryCount"],
+                    latest_session_action_id=task.get("latestSessionActionId", None),
+                    parameters=task.get("parameters", None),
+                    target_task_run_status=(
+                        TaskStatus(target_task_run_status) if target_task_run_status else None
+                    ),
+                    updated_at=task.get("updatedAt", None),
+                    updated_by=task.get("updatedBy", None),
+                    started_at=task.get("startedAt", None),
+                    ended_at=task.get("endedAt", None),
+                )
+
+
+@dataclass
+class FloatTaskParameterValue:
+    float: str
+
+
+@dataclass
+class IntTaskParameterValue:
+    int: str
+
+
+@dataclass
+class PathTaskParameterValue:
+    path: str
+
+
+@dataclass
+class StringTaskParameterValue:
+    string: str
+
+
+TaskParameterValue = Union[
+    FloatTaskParameterValue, IntTaskParameterValue, PathTaskParameterValue, StringTaskParameterValue
+]
+
+
+@dataclass
+class Task:
+    farm: Farm
+    queue: Queue
+    job: Job
+    step: Step
+    id: str
+
+    created_at: datetime.datetime
+    created_by: str
+    run_status: TaskStatus
+    ended_at: datetime.datetime | None = None
+    failure_retry_count: int | None = None
+    latest_session_action_id: str | None = None
+    parameters: dict[str, TaskParameterValue] | None = None
+    started_at: datetime.datetime | None = None
+    target_task_run_status: TaskStatus | None = None
+    updated_at: datetime.datetime | None = None
+    updated_by: str | None = None
+
+    def get_last_session(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+    ) -> Session:
+        if not self.latest_session_action_id:
+            raise ValueError(f"No latest session action ID for {self.id}")
+        match = re.search(
+            r"^sessionaction-(?P<session_id_hex>[a-f0-9]{32})-\d+$", self.latest_session_action_id
+        )
+        if not match:
+            raise ValueError(
+                f"Latest session action ID for task {self.id} ({self.latest_session_action_id}) does not match the expected pattern."
+            )
+        session_id_hex = match.group("session_id_hex")
+        session_id = f"session-{session_id_hex}"
+        session = deadline_client.get_session(
+            farmId=self.farm.id,
+            queueId=self.queue.id,
+            jobId=self.job.id,
+            sessionId=session_id,
+        )
+        host_properties = session.get("hostProperties", None)
+        return Session(
+            farm=self.farm,
+            queue=self.queue,
+            job=self.job,
+            fleet=Fleet(session["fleetId"], farm=self.farm),
+            id=session["sessionId"],
+            lifecycle_status=SessionLifecycleStatus(session["lifecycleStatus"]),
+            worker_log=LogConfiguration.from_api_response(session["workerLog"]),
+            host_properties=(
+                WorkerHostProperties.from_api_response(host_properties) if host_properties else None
+            ),
+            logs=LogConfiguration.from_api_response(session["log"]),
+            started_at=session.get("startedAt", None),
+            ended_at=session.get("endedAt", None),
+            target_lifecycle_status=session.get("targetLifecycleStatus", None),
+            updated_at=session.get("updatedAt", None),
+            updated_by=session.get("updatedBy", None),
+            worker_id=session["workerId"],
+        )
+
+    def list_sessions(self, *, deadline_client: DeadlineClient) -> Generator[Session, None, None]:
+        list_sessions_paginator: Paginator = deadline_client.get_paginator("list_sessions")
+        list_sessions_pages: PageIterator = call_api(
+            description=f"Listing sessions for job {self.job.id}",
+            fn=lambda: list_sessions_paginator.paginate(
+                farmId=self.farm.id,
+                queueId=self.queue.id,
+                jobId=self.job.id,
+            ),
+        )
+        for page in list_sessions_pages:
+            for session in page["sessions"]:
+                host_properties = session.get("hostProperties", None)
+                worker_log_config = session.get("workerLog", None)
+                yield Session(
+                    farm=self.farm,
+                    queue=self.queue,
+                    job=self.job,
+                    fleet=Fleet(session["fleetId"], farm=self.farm),
+                    id=session["sessionId"],
+                    lifecycle_status=SessionLifecycleStatus(session["lifecycleStatus"]),
+                    host_properties=(
+                        WorkerHostProperties.from_api_response(host_properties)
+                        if host_properties
+                        else None
+                    ),
+                    logs=LogConfiguration.from_api_response(session["log"]),
+                    started_at=session.get("startedAt", None),
+                    ended_at=session.get("endedAt", None),
+                    target_lifecycle_status=session.get("targetLifecycleStatus", None),
+                    updated_at=session.get("updatedAt", None),
+                    updated_by=session.get("updatedBy", None),
+                    worker_id=session["workerId"],
+                    worker_log=(
+                        LogConfiguration.from_api_response(worker_log_config)
+                        if worker_log_config
+                        else None
+                    ),
+                )
+
+
 @dataclass
 class JobLogs:
     log_group_name: str
     logs: dict[str, list[CloudWatchLogEvent]]
 
+    @property
+    def session_logs(self) -> dict[str, SessionLog]:
+        return {
+            session_id: SessionLog(session_id=session_id, logs=logs)
+            for session_id, logs in self.logs.items()
+        }
+
+
+@dataclass
+class LogConfiguration:
+    log_driver: Literal["awslogs"]
+    error: str | None = None
+    options: dict[str, str] | None = None
+    parameters: dict[str, str] | None = None
+
+    @staticmethod
+    def from_api_response(response: dict[str, Any]) -> LogConfiguration:
+        return LogConfiguration(
+            log_driver=response["logDriver"],
+            error=response.get("error", None),
+            options=response.get("options", None),
+            parameters=response.get("parameters", None),
+        )
+
+
+@dataclass
+class Session:
+    farm: Farm
+    queue: Queue
+    job: Job
+    fleet: Fleet
+    id: str
+
+    lifecycle_status: SessionLifecycleStatus
+    logs: LogConfiguration
+    started_at: datetime.datetime
+    worker_id: str
+    ended_at: datetime.datetime | None = None
+    host_properties: WorkerHostProperties | None = None
+    target_lifecycle_status: Literal["ENDED"] | None = None
+    updated_at: datetime.datetime | None = None
+    updated_by: str | None = None
worker_log: LogConfiguration | None = None
+
+    def get_session_log(self, *, logs_client: BaseClient) -> SessionLog:
+        if not (log_driver := self.logs.log_driver):
+            raise ValueError('No "logDriver" key in session API response')
+        elif log_driver != "awslogs":
+            raise NotImplementedError(f'Unsupported log driver "{log_driver}"')
+        if not (session_log_config_options := self.logs.options):
+            raise ValueError('No "options" key in session "log" API response')
+        if not (log_group_name := session_log_config_options.get("logGroupName", None)):
+            raise ValueError('No "logGroupName" key in session "log" -> "options" API response')
+        if not (log_stream_name := session_log_config_options.get("logStreamName", None)):
+            raise ValueError('No "logStreamName" key in session "log" -> "options" API response')
+
+        filter_log_events_paginator: Paginator = logs_client.get_paginator("filter_log_events")
+        filter_log_events_pages: PageIterator = call_api(
+            description=f"Fetching log events for session {self.id} in log group {log_group_name}",
+            fn=lambda: filter_log_events_paginator.paginate(
+                logGroupName=log_group_name,
+                logStreamNames=[log_stream_name],
+            ),
+        )
+        log_events = filter_log_events_pages.build_full_result()
+        log_events = [CloudWatchLogEvent.from_api_response(e) for e in log_events["events"]]
+
+        return SessionLog(session_id=self.id, logs=log_events)
+
+    def assert_log_contains(
+        self,
+        *,
+        logs_client: BaseClient,
+        expected_pattern: re.Pattern | str,
+        assert_fail_msg: str = "Expected message not found in session log",
+        retries: int = 4,
+        backoff_factor: timedelta = timedelta(milliseconds=300),
+    ) -> None:
+        """
+        Asserts that the expected regular expression pattern exists in the session's log.
+
+        This method accounts for the eventual consistency of CloudWatch log delivery and
+        availability through the CloudWatch APIs by retrying a configurable number of times with
+        exponential back-off if the pattern is not initially found.
+
+        Args:
+            logs_client (botocore.client.BaseClient): CloudWatch logs boto client
+            expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a
+                pre-compiled regular expression Pattern object. This pattern is searched against
+                the session's log events, concatenated into a multi-line string joined by
+                single newline characters (\\n).
+            assert_fail_msg (str): The assertion message to raise if the pattern is not found after
+                the configured exponential back-off attempts. The CloudWatch log group name is
+                appended to the end of this message to assist with diagnosis. The default is
+                "Expected message not found in session log".
+            retries (int): The number of retries with exponential back-off to attempt while the
+                expected pattern is not found. Default is 4.
+            backoff_factor (datetime.timedelta): A multiple used for exponential back-off delay
+                between attempts when the expected pattern is not found.
The formula used is:
+
+                delay = backoff_factor * 2 ** i
+
+                where i is the 0-based retry number
+
+                Default is 300ms
+        """
+        # Coerce Regex str patterns to a re.Pattern
+        if isinstance(expected_pattern, str):
+            expected_pattern = re.compile(expected_pattern)
+
+        if not (session_log_config_options := self.logs.options):
+            raise ValueError('No "options" key in session "log" API response')
+        if not (log_group_name := session_log_config_options.get("logGroupName", None)):
+            raise ValueError('No "logGroupName" key in session "log" -> "options" API response')
+
+        for i in range(retries + 1):
+            session_log = self.get_session_log(logs_client=logs_client)
+
+            try:
+                session_log.assert_pattern_in_log(
+                    expected_pattern=expected_pattern,
+                    failure_msg=f"{assert_fail_msg}. Logs are in CloudWatch log group: {log_group_name}",
+                )
+            except AssertionError:
+                if i == retries:
+                    raise
+                else:
+                    delay: timedelta = (2**i) * backoff_factor
+                    LOG.warning(
+                        f"Expected pattern not found in session log {self.id}, delaying {delay} before retrying."
+                    )
+                    time.sleep(delay.total_seconds())
+            else:
+                return
+
+
+@dataclass
+class SessionLog:
+    session_id: str
+    logs: list[CloudWatchLogEvent]
+
+    def assert_pattern_in_log(
+        self,
+        *,
+        expected_pattern: re.Pattern | str,
+        failure_msg: str,
+    ) -> None:
+        """
+        Asserts that a pattern is found in the session log.
+
+        Args:
+            expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a
+                pre-compiled regular expression Pattern object. This pattern is searched against
+                the session's log events, concatenated into a multi-line string joined by
+                single newline characters (\\n).
+            failure_msg (str): A message to be raised in an AssertionError if the expected pattern
+                is not found
+
+        Raises:
+            AssertionError
+                Raised when the expected pattern is not found in the session log. The argument to
+                the AssertionError is the value of the failure_msg argument
+        """
+        # Coerce Regex str patterns to a re.Pattern
+        if isinstance(expected_pattern, str):
+            expected_pattern = re.compile(expected_pattern)
+
+        full_session_log = "\n".join(le.message for le in self.logs)
+        assert expected_pattern.search(full_session_log), failure_msg
+
 
 @dataclass
 class CloudWatchLogEvent:
diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py
index 319915e..297d177 100644
--- a/test/unit/deadline/test_resources.py
+++ b/test/unit/deadline/test_resources.py
@@ -1,7 +1,10 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+from __future__ import annotations + import datetime import json +import re from dataclasses import asdict, replace from typing import Any, Generator, cast from unittest.mock import MagicMock, call, patch @@ -16,6 +19,9 @@ QueueFleetAssociation, Job, JobAttachmentSettings, + Session, + Step, + Task, TaskStatus, ) from deadline_test_fixtures.deadline import resources as mod @@ -48,12 +54,24 @@ def queue(farm: Farm) -> Queue: @pytest.fixture -def fleet(farm: Farm) -> Fleet: - return Fleet(id="fleet-123", farm=farm) +def fleet_id() -> str: + return "fleet-123" @pytest.fixture -def qfa(farm: Farm, queue: Queue, fleet: Fleet) -> QueueFleetAssociation: +def fleet( + farm: Farm, + fleet_id: str, +) -> Fleet: + return Fleet(id=fleet_id, farm=farm) + + +@pytest.fixture +def qfa( + farm: Farm, + queue: Queue, + fleet: Fleet, +) -> QueueFleetAssociation: return QueueFleetAssociation( farm=farm, queue=queue, @@ -61,6 +79,285 @@ def qfa(farm: Farm, queue: Queue, fleet: Fleet) -> QueueFleetAssociation: ) +@pytest.fixture +def job_id() -> str: + return "job-123" + + +@pytest.fixture +def job( + farm: Farm, + queue: Queue, + job_id: str, +) -> Job: + return Job( + id=job_id, + farm=farm, + queue=queue, + template={}, + name="Job Name", + lifecycle_status=mod.JobLifecycleStatus.CREATE_COMPLETE, + lifecycle_status_message="Nice", + priority=1, + created_at=datetime.datetime.now(), + created_by="test-user", + ) + + +@pytest.fixture +def step_id() -> str: + return "step-123" + + +@pytest.fixture +def step( + farm: Farm, + queue: Queue, + job: Job, + step_id: str, +) -> Step: + return Step( + id=step_id, + farm=farm, + queue=queue, + job=job, + name="Step Name", + created_at=datetime.datetime.now(), + created_by="test-user", + lifecycle_status=mod.StepLifecycleStatus.CREATE_COMPLETE, + task_run_status=TaskStatus.SUCCEEDED, + task_run_status_counts={ + TaskStatus.ASSIGNED: 0, + TaskStatus.CANCELED: 0, + TaskStatus.FAILED: 0, + TaskStatus.INTERRUPTING: 0, + TaskStatus.NOT_COMPATIBLE: 0, + TaskStatus.PENDING: 0, + TaskStatus.READY: 0, + TaskStatus.RUNNING: 0, + TaskStatus.SCHEDULED: 0, + TaskStatus.STARTING: 0, + TaskStatus.SUCCEEDED: 0, + TaskStatus.SUSPENDED: 0, + }, + ) + + +@pytest.fixture +def session_id_hex() -> str: + return "00001111222233334444555566667777" + + +@pytest.fixture +def session_id(session_id_hex: str) -> str: + return f"session-{session_id_hex}" + + +@pytest.fixture +def session_action_id(session_id_hex: str) -> str: + return f"sessionaction-{session_id_hex}-0" + + +@pytest.fixture +def task_id() -> str: + return "task-a12" + + +@pytest.fixture +def task_created_at() -> datetime.datetime: + return datetime.datetime.now() + + +@pytest.fixture +def task_created_by() -> str: + return "taskcreator" + + +@pytest.fixture +def task_run_status() -> TaskStatus: + return TaskStatus.SUCCEEDED + + +@pytest.fixture +def task( + farm: Farm, + queue: Queue, + job: Job, + step: Step, + session_action_id: str, + task_id: str, + task_created_at: datetime.datetime, + task_created_by: str, + task_run_status: TaskStatus, +) -> Task: + return Task( + farm=farm, + queue=queue, + job=job, + step=step, + id=task_id, + created_at=task_created_at, + created_by=task_created_by, + run_status=task_run_status, + latest_session_action_id=session_action_id, + ) + + +@pytest.fixture +def session_log_driver() -> str: + return "awslogs" + + +@pytest.fixture +def session_log_group_name() -> str: + return "sessionLogGroup" + + +@pytest.fixture +def session_log_stream_name() -> str: + return "sessionLogStream" + + 
+@pytest.fixture +def session_log_config( + session_log_driver: str, + session_log_group_name: str, + session_log_stream_name: str, +) -> mod.LogConfiguration: + return mod.LogConfiguration( + log_driver=session_log_driver, # type: ignore[arg-type] + options={ + "logGroupName": session_log_group_name, + "logStreamName": session_log_stream_name, + }, + parameters={}, + ) + + +@pytest.fixture +def worker_log_driver() -> str: + return "awslogs" + + +@pytest.fixture +def worker_log_group_name() -> str: + return "workerLogGroup" + + +@pytest.fixture +def worker_log_stream_name() -> str: + return "workerLogStream" + + +@pytest.fixture +def worker_log_config( + worker_log_driver: str, + worker_log_group_name: str, + worker_log_stream_name: str, +) -> mod.LogConfiguration: + return mod.LogConfiguration( + log_driver=worker_log_driver, # type: ignore[arg-type] + options={ + "logGroupName": worker_log_group_name, + "logStreamName": worker_log_stream_name, + }, + parameters={}, + ) + + +@pytest.fixture +def worker_id() -> str: + return "worker-abc" + + +@pytest.fixture +def session_started_at() -> datetime.datetime: + return datetime.datetime.now() + + +@pytest.fixture +def session_lifecycle_status() -> mod.SessionLifecycleStatus: + return mod.SessionLifecycleStatus.ENDED + + +@pytest.fixture +def ip_v4_addresses() -> list[str]: + return ["192.168.0.100"] + + +@pytest.fixture +def ip_v6_addresses() -> list[str]: + return ["::1"] + + +@pytest.fixture +def ip_addresses( + ip_v4_addresses: list[str], + ip_v6_addresses: list[str], +) -> mod.IpAddresses: + return mod.IpAddresses( + ip_v4_addresses=ip_v4_addresses, + ip_v6_addresses=ip_v6_addresses, + ) + + +@pytest.fixture +def ec2_instance_arn() -> str: + return "ec2_instance_arn" + + +@pytest.fixture +def ec2_instance_type() -> str: + return "t3.micro" + + +@pytest.fixture +def host_name() -> str: + return "hostname" + + +@pytest.fixture +def worker_host_properties( + ec2_instance_arn: str, + ec2_instance_type: str, + host_name: str, + ip_addresses: mod.IpAddresses, +) -> mod.WorkerHostProperties: + return mod.WorkerHostProperties( + ec2_instance_arn=ec2_instance_arn, + ec2_instance_type=ec2_instance_type, + ip_addresses=ip_addresses, + host_name=host_name, + ) + + +@pytest.fixture +def session( + farm: Farm, + queue: Queue, + job: Job, + fleet: Fleet, + session_id: str, + session_lifecycle_status: mod.SessionLifecycleStatus, + session_log_config: mod.LogConfiguration, + session_started_at: datetime.datetime, + worker_id: str, + worker_log_config: mod.LogConfiguration, +) -> Session: + return Session( + farm=farm, + queue=queue, + fleet=fleet, + job=job, + id=session_id, + lifecycle_status=session_lifecycle_status, + logs=session_log_config, + started_at=session_started_at, + worker_id=worker_id, + worker_log=worker_log_config, + ) + + class TestFarm: def test_create(self) -> None: # GIVEN @@ -354,21 +651,6 @@ def task_run_status_counts( "SUCCEEDED": succeeded, } - @pytest.fixture - def job(self, farm: Farm, queue: Queue) -> Job: - return Job( - id="job-123", - farm=farm, - queue=queue, - template={}, - name="Job Name", - lifecycle_status="CREATE_COMPLETE", - lifecycle_status_message="Nice", - priority=1, - created_at=datetime.datetime.now(), - created_by="test-user", - ) - def test_submit( self, farm: Farm, @@ -468,7 +750,7 @@ def test_submit( assert job.queue is queue assert job.template == template assert job.name == "Test Job" - assert job.lifecycle_status == "CREATE_COMPLETE" + assert job.lifecycle_status == mod.JobLifecycleStatus.CREATE_COMPLETE 
assert job.lifecycle_status_message == "Nice" assert job.priority == priority assert job.created_at == created_at @@ -570,7 +852,7 @@ def test_refresh_job_info(self, job: Job) -> None: get_job_response = { "jobId": job.id, "name": job.name, - "lifecycleStatus": job.lifecycle_status, + "lifecycleStatus": job.lifecycle_status.value, "lifecycleStatusMessage": job.lifecycle_status_message, "createdAt": job.created_at, "createdBy": job.created_by, @@ -742,3 +1024,600 @@ def test_get_logs(self, job: Job) -> None: assert session_log_map["session-2"] == [ CloudWatchLogEvent.from_api_response(le) for le in log_events[1]["events"] ] + + def test_assert_single_task_log_contains_success(self, job: Job, session: Session) -> None: + # GIVEN + deadline_client = MagicMock() + logs_client = MagicMock() + step = MagicMock() + task = MagicMock() + step.list_tasks.return_value = [task] + task.get_last_session.return_value = session + expected_pattern = re.compile(r"a message") + + with ( + patch.object(job, "list_steps", return_value=[step]) as mock_list_steps, + patch.object(session, "assert_log_contains") as mock_session_assert_log_contains, + ): + + # WHEN + job.assert_single_task_log_contains( + deadline_client=deadline_client, + logs_client=logs_client, + expected_pattern=expected_pattern, + ) + + # THEN + # This test is only to confirm that no assertion is raised, since the expected message + # is in the logs + mock_session_assert_log_contains.assert_called_once_with( + logs_client=logs_client, + expected_pattern=expected_pattern, + assert_fail_msg="Expected message not found in session log", + backoff_factor=datetime.timedelta(milliseconds=300), + retries=4, + ) + mock_list_steps.assert_called_once_with(deadline_client=deadline_client) + step.list_tasks.assert_called_once_with(deadline_client=deadline_client) + task.get_last_session.assert_called_once_with(deadline_client=deadline_client) + + def test_assert_single_task_log_contains_multi_step(self, job: Job) -> None: + # GIVEN + deadline_client = MagicMock() + logs_client = MagicMock() + step = MagicMock() + expected_pattern = re.compile(r"a message") + + with (patch.object(job, "list_steps", return_value=[step, step]) as mock_list_steps,): + + # WHEN + def when(): + job.assert_single_task_log_contains( + deadline_client=deadline_client, + logs_client=logs_client, + expected_pattern=expected_pattern, + ) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + + print(raise_ctx.value) + + assert raise_ctx.match("Job contains multiple steps") + mock_list_steps.assert_called_once_with(deadline_client=deadline_client) + step.list_tasks.assert_not_called() + + def test_assert_single_task_log_contains_multi_task(self, job: Job, session: Session) -> None: + # GIVEN + deadline_client = MagicMock() + logs_client = MagicMock() + step = MagicMock() + task = MagicMock() + step.list_tasks.return_value = [task, task] + task.get_last_session.return_value = session + expected_pattern = re.compile(r"a message") + + with (patch.object(job, "list_steps", return_value=[step]) as mock_list_steps,): + + # WHEN + def when(): + job.assert_single_task_log_contains( + deadline_client=deadline_client, + logs_client=logs_client, + expected_pattern=expected_pattern, + ) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + + print(raise_ctx.value) + + assert raise_ctx.match("Job contains multiple tasks") + mock_list_steps.assert_called_once_with(deadline_client=deadline_client) + 
step.list_tasks.assert_called_once_with(deadline_client=deadline_client)
+            task.get_last_session.assert_not_called()
+
+    def test_list_steps(
+        self,
+        job: Job,
+    ) -> None:
+        # GIVEN
+        step_id = "step-97f70ac0e02d4dc0acb589b9bd890981"
+        step_name = "a step"
+        created_at = datetime.datetime(2024, 9, 3)
+        created_by = "username"
+        lifecycle_status = "CREATE_COMPLETE"
+        task_run_status = "ASSIGNED"
+        lifecycle_status_message = "a message"
+        target_task_run_status = "READY"
+        updated_at = datetime.datetime(2024, 9, 3)
+        updated_by = "someone"
+        started_at = datetime.datetime(2024, 9, 3)
+        ended_at = datetime.datetime(2024, 9, 3)
+        task_run_status_counts = {
+            "PENDING": 0,
+            "READY": 0,
+            "ASSIGNED": 0,
+            "STARTING": 0,
+            "SCHEDULED": 0,
+            "INTERRUPTING": 0,
+            "RUNNING": 0,
+            "SUSPENDED": 0,
+            "CANCELED": 0,
+            "FAILED": 0,
+            "SUCCEEDED": 0,
+            "NOT_COMPATIBLE": 0,
+        }
+        dependency_counts = {
+            "consumersResolved": 0,
+            "consumersUnresolved": 0,
+            "dependenciesResolved": 0,
+            "dependenciesUnresolved": 0,
+        }
+
+        deadline_client = MagicMock()
+        deadline_client.get_paginator.return_value.paginate.return_value = [
+            {
+                "steps": [
+                    {
+                        "stepId": step_id,
+                        "name": step_name,
+                        "createdAt": created_at,
+                        "createdBy": created_by,
+                        "lifecycleStatus": lifecycle_status,
+                        "taskRunStatus": task_run_status,
+                        "taskRunStatusCounts": task_run_status_counts,
+                        "lifecycleStatusMessage": lifecycle_status_message,
+                        "targetTaskRunStatus": target_task_run_status,
+                        "updatedAt": updated_at,
+                        "updatedBy": updated_by,
+                        "startedAt": started_at,
+                        "endedAt": ended_at,
+                        "dependencyCounts": dependency_counts,
+                    },
+                ],
+            }
+        ]
+        result = job.list_steps(deadline_client=deadline_client)
+
+        # WHEN
+        result_list = list(result)
+
+        # THEN
+        assert len(result_list) == 1
+        step = result_list[0]
+        assert step.id == step_id
+        assert step.name == step_name
+        assert step.created_at == created_at
+        assert step.created_by == created_by
+        assert step.lifecycle_status == mod.StepLifecycleStatus(lifecycle_status)
+        assert step.task_run_status == TaskStatus(task_run_status)
+        assert step.lifecycle_status_message == lifecycle_status_message
+        assert step.target_task_run_status == TaskStatus(target_task_run_status)
+        assert step.updated_at == updated_at
+        assert step.updated_by == updated_by
+        assert step.started_at == started_at
+        assert step.ended_at == ended_at
+
+        for status in step.task_run_status_counts:
+            assert step.task_run_status_counts[status] == task_run_status_counts[status.value]
+
+        assert step.dependency_counts is not None
+        assert step.dependency_counts.consumers_resolved == dependency_counts["consumersResolved"]
+        assert (
+            step.dependency_counts.consumers_unresolved == dependency_counts["consumersUnresolved"]
+        )
+        assert (
+            step.dependency_counts.dependencies_resolved
+            == dependency_counts["dependenciesResolved"]
+        )
+        assert (
+            step.dependency_counts.dependencies_unresolved
+            == dependency_counts["dependenciesUnresolved"]
+        )
+
+
+class TestStep:
+    def test_list_tasks(
+        self,
+        step: Step,
+        session_action_id: str,
+    ) -> None:
+        # GIVEN
+        deadline_client = MagicMock()
+        mock_get_paginator: MagicMock = deadline_client.get_paginator
+        mock_paginate: MagicMock = mock_get_paginator.return_value.paginate
+        task_id = "task-b73de3af607f472687cafb16def7664e"
+        created_at = datetime.datetime.now()
+        created_by = "someone"
+        run_status = "READY"
+        failure_retry_count = 5
+        target_task_run_status = "RUNNING"
+        updated_at = datetime.datetime.now()
+        updated_by = "someoneelse"
+        started_at = datetime.datetime.now()
+        ended_at = 
datetime.datetime.now() + mock_paginate.return_value = [ + { + "tasks": [ + { + "taskId": task_id, + "createdAt": created_at, + "createdBy": created_by, + "runStatus": run_status, + "failureRetryCount": failure_retry_count, + "latestSessionActionId": session_action_id, + "parameters": {}, + "targetTaskRunStatus": target_task_run_status, + "updatedAt": updated_at, + "updatedBy": updated_by, + "startedAt": started_at, + "endedAt": ended_at, + }, + ], + }, + ] + generator = step.list_tasks(deadline_client=deadline_client) + + # WHEN + result = list(generator) + + # THEN + assert len(result) == 1 + task = result[0] + assert task.id == task_id + assert task.created_at == created_at + assert task.created_by == created_by + assert task.run_status == TaskStatus(run_status) + assert task.failure_retry_count == failure_retry_count + assert task.latest_session_action_id == session_action_id + assert task.parameters == {} + assert task.target_task_run_status == TaskStatus(target_task_run_status) + assert task.updated_at == updated_at + assert task.updated_by == updated_by + assert task.started_at == started_at + assert task.ended_at == ended_at + mock_get_paginator.assert_called_once_with("list_tasks") + mock_paginate.assert_called_once_with( + farmId=step.farm.id, + queueId=step.queue.id, + jobId=step.job.id, + stepId=step.id, + ) + + +class TestTask: + def test_get_last_session( + self, + fleet_id: str, + task: Task, + session_id: str, + session_lifecycle_status: mod.SessionLifecycleStatus, + session_log_config: mod.LogConfiguration, + worker_id: str, + worker_log_config: mod.LogConfiguration, + ec2_instance_arn: str, + ec2_instance_type: str, + host_name: str, + ip_v4_addresses: list[str], + ip_v6_addresses: list[str], + ) -> None: + # GIVEN + deadline_client = MagicMock() + mock_get_session: MagicMock = deadline_client.get_session + started_at = datetime.datetime.now() + ended_at = datetime.datetime.now() + updated_at = datetime.datetime.now() + updated_by = "taskupdater" + mock_get_session.return_value = { + "sessionId": session_id, + "fleetId": fleet_id, + "lifecycleStatus": session_lifecycle_status.value, + "log": { + "logDriver": session_log_config.log_driver, + "options": session_log_config.options, + "parameters": session_log_config.parameters, + }, + "hostProperties": { + "ec2InstanceArn": ec2_instance_arn, + "ec2InstanceType": ec2_instance_type, + "hostName": host_name, + "ipAddresses": { + "ipV4Addresses": ip_v4_addresses, + "ipV6Addresses": ip_v6_addresses, + }, + }, + "startedAt": started_at, + "endedAt": ended_at, + "updatedAt": updated_at, + "updatedBy": updated_by, + "workerId": worker_id, + "workerLog": { + "logDriver": worker_log_config.log_driver, + "options": worker_log_config.options, + "parameters": worker_log_config.parameters, + }, + } + + # WHEN + returned_session = task.get_last_session(deadline_client=deadline_client) + + # THEN + assert isinstance(returned_session, Session) + assert returned_session.id == session_id + assert returned_session.fleet.id == fleet_id + assert returned_session.worker_id == worker_id + assert returned_session.lifecycle_status == session_lifecycle_status + assert returned_session.started_at == started_at + assert returned_session.ended_at == ended_at + assert returned_session.updated_at == updated_at + assert returned_session.updated_by == updated_by + + assert isinstance(returned_session.host_properties, mod.WorkerHostProperties) + assert returned_session.host_properties.ec2_instance_arn == ec2_instance_arn + assert 
returned_session.host_properties.ec2_instance_type == ec2_instance_type
+        assert returned_session.host_properties.host_name == host_name
+
+        assert isinstance(returned_session.host_properties.ip_addresses, mod.IpAddresses)
+        assert returned_session.host_properties.ip_addresses.ip_v4_addresses == ip_v4_addresses
+        assert returned_session.host_properties.ip_addresses.ip_v6_addresses == ip_v6_addresses
+
+        assert isinstance(returned_session.logs, mod.LogConfiguration)
+        assert returned_session.logs.log_driver == session_log_config.log_driver
+        assert returned_session.logs.options == session_log_config.options
+        assert returned_session.logs.parameters == session_log_config.parameters
+
+        assert isinstance(returned_session.worker_log, mod.LogConfiguration)
+        assert returned_session.worker_log.log_driver == worker_log_config.log_driver
+        assert returned_session.worker_log.options == worker_log_config.options
+        assert returned_session.worker_log.parameters == worker_log_config.parameters
+
+
+class TestSession:
+    @pytest.mark.parametrize(
+        argnames=("expected_pattern", "log_messages"),
+        argvalues=(
+            pytest.param("PATTERN", ["PATTERN"], id="exact-match"),
+            pytest.param("PATTERN", ["PATTERN at beginning"], id="match-beginning"),
+            pytest.param("PATTERN", ["ends with PATTERN"], id="match-end"),
+            pytest.param("PATTERN", ["multiline with", "the PATTERN"], id="match-multiline"),
+            pytest.param(
+                re.compile(r"This is\na multiline pattern", re.MULTILINE),
+                ["extra lines", "This is", "a multiline pattern", "embedded"],
+                id="multi-line-pattern",
+            ),
+            pytest.param(
+                re.compile(r"^anchored\nmultiline pattern", re.MULTILINE),
+                ["extra lines", "anchored", "multiline pattern", "trailing line"],
+                id="anchored-multi-line-pattern",
+            ),
+        ),
+    )
+    def test_assert_logs_success(
+        self,
+        session: Session,
+        expected_pattern: str | re.Pattern,
+        log_messages: list[str],
+    ) -> None:
+        # GIVEN
+        logs_client = MagicMock()
+        logs = mod.SessionLog(
+            session_id=session.id,
+            logs=[
+                mod.CloudWatchLogEvent(
+                    ingestion_time=i,
+                    message=message,
+                    timestamp=i,
+                )
+                for i, message in enumerate(log_messages)
+            ],
+        )
+
+        with (
+            patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log,
+            # Speed up tests
+            patch.object(mod.time, "sleep") as mock_time_sleep,
+        ):
+
+            # WHEN
+            session.assert_log_contains(
+                logs_client=logs_client,
+                expected_pattern=expected_pattern,
+            )
+
+            # THEN
+            # (no exception is raised)
+            mock_get_session_log.assert_called_once_with(logs_client=logs_client)
+            mock_time_sleep.assert_not_called()
+
+    @pytest.mark.parametrize(
+        argnames="assert_fail_msg",
+        argvalues=(
+            pytest.param(None, id="default"),
+            pytest.param("message to raise", id="provided-assert-msg"),
+        ),
+    )
+    @pytest.mark.parametrize(
+        argnames="retries",
+        argvalues=(pytest.param(None, id="retries[default]"), pytest.param(3, id="retries[3]")),
+    )
+    @pytest.mark.parametrize(
+        argnames="backoff_factor",
+        argvalues=(
+            pytest.param(None, id="backoff_factor[default]"),
+            pytest.param(datetime.timedelta(seconds=10), id="backoff_factor[10s]"),
+        ),
+    )
+    def test_assert_logs_contains_fail(
+        self,
+        session: Session,
+        assert_fail_msg: str | None,
+        retries: int | None,
+        backoff_factor: datetime.timedelta | None,
+        session_log_group_name: str,
+    ) -> None:
+        # GIVEN
+        logs_client = MagicMock()
+        session_id = "session-5815c7b8054c4548837c2538f0139661"
+        logs = mod.SessionLog(
+            session_id=session_id,
+            logs=[
+                mod.CloudWatchLogEvent(
+                    ingestion_time=i,
+                    message=message,
+                    timestamp=i,
+                )
+                for i,
message in enumerate( + [ + "this is not the expected message", + ] + ) + ], + ) + expected_assertion_msg = ( + f"{assert_fail_msg or 'Expected message not found in session log'}." + f" Logs are in CloudWatch log group: {session_log_group_name}" + ) + expected_retries = retries if retries is not None else 4 + expected_backoff_factor = ( + backoff_factor if backoff_factor is not None else datetime.timedelta(milliseconds=300) + ) + + with ( + patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log, + # Speed up tests + patch.object(mod.time, "sleep") as mock_time_sleep, + ): + # WHEN + def when(): + kwargs: dict[str, Any] = { + "logs_client": logs_client, + "expected_pattern": re.compile("a message"), + } + if assert_fail_msg is not None: + kwargs["assert_fail_msg"] = assert_fail_msg + if retries is not None: + kwargs["retries"] = retries + if backoff_factor is not None: + kwargs["backoff_factor"] = backoff_factor + session.assert_log_contains(**kwargs) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + + assert raise_ctx.value.args[0] == expected_assertion_msg + mock_get_session_log.assert_has_calls( + [call(logs_client=logs_client)] * (expected_retries + 1) + ) + assert mock_get_session_log.call_count == (expected_retries + 1) + mock_time_sleep.assert_has_calls( + [ + call((expected_backoff_factor * (2**i)).total_seconds()) + for i in range(expected_retries) + ] + ) + assert mock_time_sleep.call_count == expected_retries + + @pytest.mark.parametrize( + argnames="retries_before_success", + argvalues=(1, 2), + ) + def test_assert_logs_contain_cw_eventual_consistency( + self, + session: Session, + retries_before_success: int, + ) -> None: + # GIVEN + logs_client = MagicMock() + session_id = "session-5815c7b8054c4548837c2538f0139661" + message_only_in_complete_logs = "message only in complete logs" + partial_log = mod.SessionLog( + session_id=session_id, + logs=[ + mod.CloudWatchLogEvent( + ingestion_time=0, + message="this contains partial logs", + timestamp=0, + ), + ], + ) + complete_log = mod.SessionLog( + session_id=session_id, + logs=[ + mod.CloudWatchLogEvent( + ingestion_time=0, + message="this contains partial logs", + timestamp=0, + ), + mod.CloudWatchLogEvent( + ingestion_time=0, + message=message_only_in_complete_logs, + timestamp=0, + ), + ], + ) + + with ( + patch.object( + session, + "get_session_log", + side_effect=( + # Return partial log before performing the specified number of retries + ([partial_log] * retries_before_success) + # The complete log + + [complete_log] + ), + ) as mock_get_session_log, + patch.object(mod.time, "sleep") as mock_time_sleep, + ): + + # WHEN + session.assert_log_contains( + logs_client=logs_client, + expected_pattern=re.compile(re.escape(message_only_in_complete_logs)), + ) + + # THEN + mock_get_session_log.assert_has_calls( + [call(logs_client=logs_client)] * (retries_before_success + 1) + ) + mock_time_sleep.assert_has_calls( + [call(0.3 * (2.0**i)) for i in range(retries_before_success)] + ) + + @pytest.mark.parametrize( + argnames="log_event_messages", + argvalues=( + pytest.param(["a", "b"], id="2events"), + pytest.param(["a", "b", "c"], id="3events"), + ), + ) + def test_get_session_log( + self, + session: Session, + log_event_messages: list[str], + ) -> None: + # GIVEN + logs_client = MagicMock() + logs_client.get_paginator.return_value.paginate.return_value.build_full_result.return_value = { + "events": [ + { + "ingestionTime": i, + "message": message, + "timestamp": i, + } + for i, message 
in enumerate(log_event_messages) + ] + } + + # WHEN + result = session.get_session_log(logs_client=logs_client) + + # THEN + assert isinstance(result, mod.SessionLog) + assert result.session_id == session.id + for log_event, expected_message in zip(result.logs, log_event_messages): + assert log_event.message == expected_message
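
Example: asserting on a single-task job's session log from an integration test.
A minimal sketch; the `job` and `deadline_client` fixtures and the boto3 wiring
are assumed here, not part of this change; only the method names come from this
patch.

    import re

    import boto3

    from deadline_test_fixtures import DeadlineClient, Job


    def test_job_prints_expected_message(job: Job, deadline_client: DeadlineClient) -> None:
        # CloudWatch Logs client used to fetch the session's log events
        logs_client = boto3.client("logs")

        # Raises AssertionError (with the CloudWatch log group name appended)
        # if the pattern is still absent after the default 4 retries with
        # exponential back-off starting at 300ms.
        job.assert_single_task_log_contains(
            deadline_client=deadline_client,
            logs_client=logs_client,
            expected_pattern=re.compile(r"expected output message"),
        )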
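
Example: for jobs with more than one step or task, the same check can be
composed from the lower-level generators. A sketch under the same assumed
fixtures as above.

    for step in job.list_steps(deadline_client=deadline_client):
        for task in step.list_tasks(deadline_client=deadline_client):
            # Resolve the session that last ran this task, then assert on its log.
            # A plain string pattern is coerced to re.Pattern internally.
            session = task.get_last_session(deadline_client=deadline_client)
            session.assert_log_contains(
                logs_client=logs_client,
                expected_pattern=r"Task completed",
                retries=6,  # tolerate slower CloudWatch log delivery
            )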
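
Worked example of the documented back-off schedule, delay = backoff_factor * 2 ** i:
with the defaults (retries=4, backoff_factor=300ms), the delays between the five
log fetches are 0.3s, 0.6s, 1.2s, and 2.4s (4.5s of sleeping in total).

    from datetime import timedelta

    retries = 4
    backoff_factor = timedelta(milliseconds=300)
    delays = [backoff_factor * 2**i for i in range(retries)]
    assert sum(delays, timedelta()) == timedelta(seconds=4.5)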