/object/object1/+`
+ return ServiceModelIdentifier("s3")
+
+ # detect S3 requests with "AWS id:key" Auth headers
+ auth_header = request.headers.get("Authorization") or ""
+ if auth_header.startswith("AWS "):
+ return ServiceModelIdentifier("s3")
+
+ if uses_host_addressing(request.headers):
+ # Note: This needs to be the last rule (and therefore is not in the host rules), since it is incredibly greedy
+ return ServiceModelIdentifier("s3")
+
+
+@singleton_factory
+def get_service_catalog() -> ServiceCatalog:
+ """Loads the ServiceCatalog (which contains all the service specs), and potentially re-uses a cached index."""
+ if not os.path.isdir(config.dirs.cache):
+ return ServiceCatalog()
+
+ try:
+ ls_ver = VERSION.replace(".", "_")
+ botocore_ver = botocore.__version__.replace(".", "_")
+ cache_file_name = f"service-catalog-{ls_ver}-{botocore_ver}.pickle"
+ cache_file = os.path.join(config.dirs.cache, cache_file_name)
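+ # e.g. "service-catalog-3_4_0-1_34_11.pickle" (version numbers purely illustrative)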
+
+ if not os.path.exists(cache_file):
+ LOG.debug("building service catalog index cache file %s", cache_file)
+ index = build_service_index_cache(cache_file)
+ else:
+ LOG.debug("loading service catalog index cache file %s", cache_file)
+ index = load_service_index_cache(cache_file)
+
+ return ServiceCatalog(index)
+ except Exception:
+ LOG.exception(
+ "error while processing service catalog index cache, falling back to lazy-loaded index"
+ )
+ return ServiceCatalog()
+
+
+def resolve_conflicts(
+ candidates: Set[ServiceModelIdentifier], request: Request
+) -> ServiceModelIdentifier:
+ """
+ Some service definitions overlap to a point where they are _not_ distinguishable at all
+ (e.g. ``DescribeEndpoints`` in timestream-query and timestream-write).
+ These conflicts need to be resolved manually.
+ """
+ service_name_candidates = {service.name for service in candidates}
+ if service_name_candidates == {"timestream-query", "timestream-write"}:
+ return ServiceModelIdentifier("timestream-query")
+ if service_name_candidates == {"docdb", "neptune", "rds"}:
+ return ServiceModelIdentifier("rds")
+ if service_name_candidates == {"sqs"}:
+ # SQS now has two different specs for the `query` and `json` protocols. With our current parser and
+ # serializer implementation, we need two different service names for them, but they share one provider
+ # implementation. `sqs` represents the `json` protocol spec, and `sqs-query` the `query` protocol
+ # (the default again in botocore starting with 1.32.6).
+ # The `application/x-amz-json-1.0` header is mandatory for requests targeting SQS with the `json` protocol. We
+ # can safely route them to the `sqs` JSON parser/serializer. If not present, route the request to the
+ # sqs-query protocol.
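+ # Illustration (header values assumed, not from a captured request):
+ #   Content-Type: application/x-amz-json-1.0  -> ServiceModelIdentifier("sqs")
+ #   any other (or missing) Content-Type       -> ServiceModelIdentifier("sqs", "query")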
+ content_type = request.headers.get("Content-Type")
+ return (
+ ServiceModelIdentifier("sqs")
+ if content_type == "application/x-amz-json-1.0"
+ else ServiceModelIdentifier("sqs", "query")
+ )
+
+
+def determine_aws_service_model_for_data_plane(
+ request: Request, services: ServiceCatalog = None
+) -> Optional[ServiceModel]:
+ """
+ A stripped-down version of ``determine_aws_service_model`` which only checks hostname indicators for
+ the AWS data plane, such as s3 websites, lambda function URLs, or API gateway routes.
+ """
+ custom_host_match = custom_host_addressing_rules(request.host)
+ if custom_host_match:
+ services = services or get_service_catalog()
+ return services.get(*custom_host_match)
+
+
+def determine_aws_service_model(
+ request: Request, services: ServiceCatalog = None
+) -> Optional[ServiceModel]:
+ """
+ Tries to determine the name of the AWS service an incoming request is targeting.
+ :param request: the request to determine the target service of
+ :param services: service catalog (can be handed in for caching purposes)
+ :return: the service model (or None if the targeted service could not be determined exactly)
+ """
+ services = services or get_service_catalog()
+ signing_name, target_prefix, operation, host, path = _extract_service_indicators(request)
+ candidates = set()
+
+ # 1. check the signing names
+ if signing_name:
+ signing_name_candidates = services.by_signing_name(signing_name)
+ if len(signing_name_candidates) == 1:
+ # a unique signing-name -> service name mapping is the case for ~75% of service operations
+ return services.get(*signing_name_candidates[0])
+
+ # try to find a match with the custom signing name rules
+ custom_match = custom_signing_name_rules(signing_name, path)
+ if custom_match:
+ return services.get(*custom_match)
+
+ # still ambiguous - add the services to the list of candidates
+ candidates.update(signing_name_candidates)
+
+ # 2. check the target prefix
+ if target_prefix and operation:
+ target_candidates = services.by_target_prefix(target_prefix)
+ if len(target_candidates) == 1:
+ # a unique target prefix
+ return services.get(*target_candidates[0])
+
+ # still ambiguous - add the services to the list of candidates
+ candidates.update(target_candidates)
+
+ # exclude services where the operation is not contained in the service spec
+ for service_identifier in list(candidates):
+ service = services.get(*service_identifier)
+ if operation not in service.operation_names:
+ candidates.remove(service_identifier)
+ else:
+ # exclude services which have a target prefix (the current request does not have one)
+ for service_identifier in list(candidates):
+ service = services.get(*service_identifier)
+ if service.metadata.get("targetPrefix") is not None:
+ candidates.remove(service_identifier)
+
+ if len(candidates) == 1:
+ service_identifier = candidates.pop()
+ return services.get(*service_identifier)
+
+ # 3. check the path if it is set and not a trivial root path
+ if path and path != "/":
+ # try to find a match with the custom path rules
+ custom_path_match = custom_path_addressing_rules(path)
+ if custom_path_match:
+ return services.get(*custom_path_match)
+
+ # 4. check the host (custom host addressing rules)
+ if host:
+ # iterate over the service specs' endpoint prefixes
+ for prefix, services_per_prefix in services.endpoint_prefix_index.items():
+ # this prevents a virtual-host-addressed bucket from being wrongly recognized
+ if host.startswith(f"{prefix}.") and ".s3." not in host:
+ if len(services_per_prefix) == 1:
+ return services.get(*services_per_prefix[0])
+ candidates.update(services_per_prefix)
+
+ custom_host_match = custom_host_addressing_rules(host)
+ if custom_host_match:
+ return services.get(*custom_host_match)
+
+ if request.shallow:
+ # from here on we would need access to the request body, which doesn't exist for shallow requests like
+ # WebsocketRequests.
+ return None
+
+ # 5. check the query / form-data
+ try:
+ values = request.values
+ if "Action" in values:
+ # query / ec2 protocol requests always have an action and a version (the action is more significant)
+ query_candidates = [
+ service
+ for service in services.by_operation(values["Action"])
+ if service.protocol in ("ec2", "query")
+ ]
+
+ if len(query_candidates) == 1:
+ return services.get(*query_candidates[0])
+
+ if "Version" in values:
+ for service_identifier in list(query_candidates):
+ service_model = services.get(*service_identifier)
+ if values["Version"] != service_model.api_version:
+ # the Version does not match this service's api_version, so it cannot be the target service
+ query_candidates.remove(service_identifier)
+
+ if len(query_candidates) == 1:
+ return services.get(*query_candidates[0])
+
+ candidates.update(query_candidates)
+
+ except RequestEntityTooLarge:
+ # Some requests can be form-urlencoded but also contain binary data, which makes the form parsing fail (S3 can
+ # do this). In that case, skip this step and continue trying to determine the service name. The exception
+ # raised is RequestEntityTooLarge even if the error is actually due to failed decoding.
+ LOG.debug(
+ "Failed to determine AWS service from request body because the form could not be parsed",
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ # 6. resolve service spec conflicts
+ resolved_conflict = resolve_conflicts(candidates, request)
+ if resolved_conflict:
+ return services.get(*resolved_conflict)
+
+ # 7. check the legacy rules in the end
+ legacy_match = legacy_rules(request)
+ if legacy_match:
+ return services.get(*legacy_match)
+
+ if signing_name:
+ return services.get(name=signing_name)
+ if candidates:
+ return services.get(*candidates.pop())
+ return None
diff --git a/localstack-core/localstack/aws/protocol/validate.py b/localstack-core/localstack/aws/protocol/validate.py
new file mode 100644
index 0000000000000..30d1be4355fb0
--- /dev/null
+++ b/localstack-core/localstack/aws/protocol/validate.py
@@ -0,0 +1,173 @@
+"""Slightly extends the ``botocore.validate`` package to provide better integration with our parser/serializer."""
+
+from typing import Any, Dict, List, NamedTuple
+
+from botocore.model import OperationModel, Shape
+from botocore.validate import ParamValidator as BotocoreParamValidator
+from botocore.validate import ValidationErrors as BotocoreValidationErrors
+from botocore.validate import type_check
+
+from localstack.aws.api import ServiceRequest
+
+
+class Error(NamedTuple):
+ """
+ A wrapper around ``botocore.validate`` error tuples.
+
+ Attributes:
+ reason The error type
+ name The name of the parameter the error occurred at
+ attributes Error type-specific attributes
+ """
+
+ reason: str
+ name: str
+ attributes: Dict[str, Any]
+
+
+class ParameterValidationError(Exception):
+ error: Error
+
+ def __init__(self, error: Error) -> None:
+ self.error = error
+ super().__init__(self.message)
+
+ @property
+ def reason(self):
+ return self.error.reason
+
+ @property
+ def message(self) -> str:
+ """
+ Returns a default message for the error formatted by BotocoreValidationErrors.
+ :return: the exception message.
+ """
+ return BotocoreValidationErrors()._format_error(self.error)
+
+
+class MissingRequiredField(ParameterValidationError):
+ @property
+ def required_name(self) -> str:
+ return self.error.attributes["required_name"]
+
+
+# TODO: extend subclasses with properties from error arguments as needed. see ValidationErrors._format_error for
+# which those are.
+
+
+class UnknownField(ParameterValidationError):
+ pass
+
+
+class InvalidType(ParameterValidationError):
+ pass
+
+
+class InvalidRange(ParameterValidationError):
+ pass
+
+
+class InvalidLength(ParameterValidationError):
+ pass
+
+
+class JsonEncodingError(ParameterValidationError):
+ pass
+
+
+class InvalidDocumentType(ParameterValidationError):
+ pass
+
+
+class MoreThanOneInput(ParameterValidationError):
+ pass
+
+
+class EmptyInput(ParameterValidationError):
+ pass
+
+
+class ValidationErrors(BotocoreValidationErrors):
+ def __init__(self, shape: Shape, params: Dict[str, Any]):
+ super().__init__()
+ self.shape = shape
+ self.params = params
+ self._exceptions: List[ParameterValidationError] = []
+
+ @property
+ def exceptions(self):
+ return self._exceptions
+
+ def raise_first(self):
+ for error in self._exceptions:
+ raise error
+
+ def report(self, name, reason, **kwargs):
+ error = Error(reason, name, kwargs)
+ self._errors.append(error)
+ self._exceptions.append(self.to_exception(error))
+
+ def to_exception(self, error: Error) -> ParameterValidationError:
+ error_type, name, additional = error
+
+ if error_type == "missing required field":
+ return MissingRequiredField(error)
+ elif error_type == "unknown field":
+ return UnknownField(error)
+ elif error_type == "invalid type":
+ return InvalidType(error)
+ elif error_type == "invalid range":
+ return InvalidRange(error)
+ elif error_type == "invalid length":
+ return InvalidLength(error)
+ elif error_type == "unable to encode to json":
+ return JsonEncodingError(error)
+ elif error_type == "invalid type for document":
+ return InvalidDocumentType(error)
+ elif error_type == "more than one input":
+ return MoreThanOneInput(error)
+ elif error_type == "empty input":
+ return EmptyInput(error)
+
+ return ParameterValidationError(error)
+
+
+class ParamValidator(BotocoreParamValidator):
+ def validate(self, params: Dict[str, Any], shape: Shape):
+ """Validate parameters against a shape model.
+
+ This method will validate the parameters against a provided shape model.
+ All errors will be collected before returning to the caller. This means
+ that this method will not stop at the first error; it will return all
+ possible errors.
+
+ :param params: User provided dict of parameters
+ :param shape: A shape model describing the expected input.
+
+ :return: A list of errors.
+
+ """
+ errors = ValidationErrors(shape, params)
+ self._validate(params, shape, errors, name="")
+ return errors
+
+ @type_check(valid_types=(dict,))
+ def _validate_structure(self, params, shape, errors, name):
+ # our parser sets the value of required members to None if they are not in the incoming request. we correct
+ # this behavior here to get the correct error messages.
+ for required_member in shape.metadata.get("required", []):
+ if required_member in params and params[required_member] is None:
+ params.pop(required_member)
+
+ super(ParamValidator, self)._validate_structure(params, shape, errors, name)
+
+
+def validate_request(operation: OperationModel, request: ServiceRequest) -> ValidationErrors:
+ """
+ Validates the service request with the input shape of the given operation.
+
+ :param operation: the operation
+ :param request: the input shape of the operation being validated
+ :return: ValidationError object
+ """
+ return ParamValidator().validate(request, operation.input_shape)
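+
+
+# Usage sketch (operation and parameter names illustrative, not from a test):
+#
+#     from localstack.aws.spec import load_service
+#
+#     operation = load_service("sqs").operation_model("CreateQueue")
+#     errors = validate_request(operation, {"QueueName": None})
+#     errors.raise_first()  # would raise MissingRequiredField for "QueueName"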
diff --git a/localstack-core/localstack/aws/scaffold.py b/localstack-core/localstack/aws/scaffold.py
new file mode 100644
index 0000000000000..0f828f6156dde
--- /dev/null
+++ b/localstack-core/localstack/aws/scaffold.py
@@ -0,0 +1,560 @@
+import io
+import keyword
+import re
+from functools import cached_property
+from multiprocessing import Pool
+from pathlib import Path
+from typing import Dict, List, Optional, Set
+
+import click
+from botocore import xform_name
+from botocore.exceptions import UnknownServiceError
+from botocore.model import (
+ ListShape,
+ MapShape,
+ OperationModel,
+ ServiceModel,
+ Shape,
+ StringShape,
+ StructureShape,
+)
+from typing_extensions import OrderedDict
+
+from localstack.aws.spec import load_service
+from localstack.utils.common import camel_to_snake_case, snake_to_camel_case
+
+# Some minification packages might treat "type" as a keyword, and some specs define shapes named like the typing construct "Optional"
+KEYWORDS = list(keyword.kwlist) + ["type", "Optional", "Union"]
+is_keyword = KEYWORDS.__contains__
+
+
+def is_bad_param_name(name: str) -> bool:
+ if name == "context":
+ return True
+
+ if is_keyword(name):
+ return True
+
+ return False
+
+
+def to_valid_python_name(spec_name: str) -> str:
+ sanitized = re.sub(r"[^0-9a-zA-Z_]+", "_", spec_name)
+
+ if sanitized[0].isnumeric():
+ sanitized = "i_" + sanitized
+
+ if is_keyword(sanitized):
+ sanitized += "_"
+
+ if sanitized.startswith("__"):
+ sanitized = sanitized[1:]
+
+ return sanitized
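+
+
+# Illustrative mappings (inputs assumed, not taken from a spec):
+#   to_valid_python_name("S3.Object") -> "S3_Object"   (invalid characters replaced)
+#   to_valid_python_name("2DPoint")   -> "i_2DPoint"   (leading digit prefixed)
+#   to_valid_python_name("lambda")    -> "lambda_"     (python keyword suffixed)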
+
+
+def html_to_rst(html: str):
+ import pypandoc
+
+ doc = pypandoc.convert_text(html, "rst", format="html")
+ doc = doc.replace("\_", "_") # noqa: W605
+ doc = doc.replace("\|", "|") # noqa: W605
+ doc = doc.replace("\ ", " ") # noqa: W605
+ doc = doc.replace("\\", "\\\\") # noqa: W605
+ rst = doc.strip()
+ return rst
+
+
+class ShapeNode:
+ service: ServiceModel
+ shape: Shape
+
+ def __init__(self, service: ServiceModel, shape: Shape) -> None:
+ super().__init__()
+ self.service = service
+ self.shape = shape
+
+ @cached_property
+ def request_operation(self) -> Optional[OperationModel]:
+ for operation_name in self.service.operation_names:
+ operation = self.service.operation_model(operation_name)
+ if operation.input_shape is None:
+ continue
+
+ if to_valid_python_name(self.shape.name) == to_valid_python_name(
+ operation.input_shape.name
+ ):
+ return operation
+
+ return None
+
+ @cached_property
+ def response_operation(self) -> Optional[OperationModel]:
+ for operation_name in self.service.operation_names:
+ operation = self.service.operation_model(operation_name)
+ if operation.output_shape is None:
+ continue
+
+ if to_valid_python_name(self.shape.name) == to_valid_python_name(
+ operation.output_shape.name
+ ):
+ return operation
+
+ return None
+
+ @cached_property
+ def is_request(self):
+ return self.request_operation is not None
+
+ @cached_property
+ def is_response(self):
+ return self.response_operation is not None
+
+ @property
+ def name(self) -> str:
+ return to_valid_python_name(self.shape.name)
+
+ @cached_property
+ def is_exception(self):
+ metadata = self.shape.metadata
+ return metadata.get("error") or metadata.get("exception")
+
+ @property
+ def is_primitive(self):
+ return self.shape.type_name in ["integer", "boolean", "float", "double", "string"]
+
+ @property
+ def is_enum(self):
+ return isinstance(self.shape, StringShape) and self.shape.enum
+
+ @property
+ def dependencies(self) -> List[str]:
+ shape = self.shape
+
+ if isinstance(shape, StructureShape):
+ return [to_valid_python_name(v.name) for v in shape.members.values()]
+ if isinstance(shape, ListShape):
+ return [to_valid_python_name(shape.member.name)]
+ if isinstance(shape, MapShape):
+ return [to_valid_python_name(shape.key.name), to_valid_python_name(shape.value.name)]
+
+ return []
+
+ def _print_structure_declaration(self, output, doc=True, quote_types=False):
+ if self.is_exception:
+ self._print_as_class(output, "ServiceException", doc, quote_types)
+ return
+
+ if any(map(is_keyword, self.shape.members.keys())):
+ self._print_as_typed_dict(output, doc, quote_types)
+ return
+
+ if self.is_request:
+ base = "ServiceRequest"
+ else:
+ base = "TypedDict, total=False"
+
+ self._print_as_class(output, base, doc, quote_types)
+
+ def _print_as_class(self, output, base: str, doc=True, quote_types=False):
+ output.write(f"class {to_valid_python_name(self.shape.name)}({base}):\n")
+
+ q = '"' if quote_types else ""
+
+ if doc:
+ self.print_shape_doc(output, self.shape)
+
+ if self.is_exception:
+ error_spec = self.shape.metadata.get("error", {})
+ output.write(f' code: str = "{error_spec.get("code", self.shape.name)}"\n')
+ output.write(f" sender_fault: bool = {error_spec.get('senderFault', False)}\n")
+ output.write(f" status_code: int = {error_spec.get('httpStatusCode', 400)}\n")
+ elif not self.shape.members:
+ output.write(" pass\n")
+
+ # Avoid generating members for the common error members:
+ # - The message will always be the exception message (first argument of the exception class init)
+ # - The code is already set above
+ # - The type is the sender_fault which is already set above
+ remaining_members = {
+ k: v
+ for k, v in self.shape.members.items()
+ if not self.is_exception or k.lower() not in ["message", "code"]
+ }
+
+ # render any streaming payload first
+ if self.is_request and self.request_operation.has_streaming_input:
+ member: str = self.request_operation.input_shape.serialization.get("payload")
+ shape: Shape = self.request_operation.get_streaming_input()
+ if member in self.shape.required_members:
+ output.write(f" {member}: IO[{q}{to_valid_python_name(shape.name)}{q}]\n")
+ else:
+ output.write(
+ f" {member}: Optional[IO[{q}{to_valid_python_name(shape.name)}{q}]]\n"
+ )
+ del remaining_members[member]
+ # likewise, render any streaming payload of the response first
+ if self.is_response and self.response_operation.has_streaming_output:
+ member: str = self.response_operation.output_shape.serialization.get("payload")
+ shape: Shape = self.response_operation.get_streaming_output()
+ shape_name = to_valid_python_name(shape.name)
+ if member in self.shape.required_members:
+ output.write(
+ f" {member}: Union[{q}{shape_name}{q}, IO[{q}{shape_name}{q}], Iterable[{q}{shape_name}{q}]]\n"
+ )
+ else:
+ output.write(
+ f" {member}: Optional[Union[{q}{shape_name}{q}, IO[{q}{shape_name}{q}], Iterable[{q}{shape_name}{q}]]]\n"
+ )
+ del remaining_members[member]
+
+ for k, v in remaining_members.items():
+ if k in self.shape.required_members:
+ if v.serialization.get("eventstream"):
+ output.write(f" {k}: Iterator[{q}{to_valid_python_name(v.name)}{q}]\n")
+ else:
+ output.write(f" {k}: {q}{to_valid_python_name(v.name)}{q}\n")
+ else:
+ if v.serialization.get("eventstream"):
+ output.write(f" {k}: Iterator[{q}{to_valid_python_name(v.name)}{q}]\n")
+ else:
+ output.write(f" {k}: Optional[{q}{to_valid_python_name(v.name)}{q}]\n")
+
+ def _print_as_typed_dict(self, output, doc=True, quote_types=False):
+ name = to_valid_python_name(self.shape.name)
+ output.write('%s = TypedDict("%s", {\n' % (name, name))
+ for k, v in self.shape.members.items():
+ member_name = to_valid_python_name(v.name)
+ # check if the member name is the same as the type name (recursive types need to use forward references)
+ recursive_type = name == member_name
+ q = '"' if quote_types or recursive_type else ""
+ if k in self.shape.required_members:
+ if v.serialization.get("eventstream"):
+ output.write(f' "{k}": Iterator[{q}{member_name}{q}],\n')
+ else:
+ output.write(f' "{k}": {q}{member_name}{q},\n')
+ else:
+ if v.serialization.get("eventstream"):
+ output.write(f' "{k}": Iterator[{q}{member_name}{q}],\n')
+ else:
+ output.write(f' "{k}": Optional[{q}{member_name}{q}],\n')
+ output.write("}, total=False)")
+
+ def print_shape_doc(self, output, shape):
+ html = shape.documentation
+ rst = html_to_rst(html)
+ if rst:
+ output.write(' """')
+ output.write(f"{rst}\n")
+ output.write(' """\n')
+
+ def print_declaration(self, output, doc=True, quote_types=False):
+ shape = self.shape
+
+ q = '"' if quote_types else ""
+
+ if isinstance(shape, StructureShape):
+ self._print_structure_declaration(output, doc, quote_types)
+ elif isinstance(shape, ListShape):
+ output.write(
+ f"{to_valid_python_name(shape.name)} = List[{q}{to_valid_python_name(shape.member.name)}{q}]"
+ )
+ elif isinstance(shape, MapShape):
+ output.write(
+ f"{to_valid_python_name(shape.name)} = Dict[{q}{to_valid_python_name(shape.key.name)}{q}, {q}{to_valid_python_name(shape.value.name)}{q}]"
+ )
+ elif isinstance(shape, StringShape):
+ if shape.enum:
+ output.write(f"class {to_valid_python_name(shape.name)}(StrEnum):\n")
+ for value in shape.enum:
+ name = to_valid_python_name(value)
+ output.write(f' {name} = "{value}"\n')
+ else:
+ output.write(f"{to_valid_python_name(shape.name)} = str")
+ elif shape.type_name == "string":
+ output.write(f"{to_valid_python_name(shape.name)} = str")
+ elif shape.type_name == "integer":
+ output.write(f"{to_valid_python_name(shape.name)} = int")
+ elif shape.type_name == "long":
+ output.write(f"{to_valid_python_name(shape.name)} = int")
+ elif shape.type_name == "double":
+ output.write(f"{to_valid_python_name(shape.name)} = float")
+ elif shape.type_name == "float":
+ output.write(f"{to_valid_python_name(shape.name)} = float")
+ elif shape.type_name == "boolean":
+ output.write(f"{to_valid_python_name(shape.name)} = bool")
+ elif shape.type_name == "blob":
+ # blobs are often associated with streaming payloads, but we handle that on operation level,
+ # not on shape level
+ output.write(f"{to_valid_python_name(shape.name)} = bytes")
+ elif shape.type_name == "timestamp":
+ output.write(f"{to_valid_python_name(shape.name)} = datetime")
+ else:
+ output.write(
+ f"# unknown shape type for {to_valid_python_name(shape.name)}: {shape.type_name}"
+ )
+ # TODO: BoxedInteger?
+
+ output.write("\n")
+
+ def get_order(self):
+ """
+ Defines a basic order in which to sort the stack of shape nodes before printing.
+ First all non-enum primitives are printed, then enums, then exceptions, then all other types.
+ """
+ if self.is_primitive:
+ if self.is_enum:
+ return 1
+ else:
+ return 0
+
+ if self.is_exception:
+ return 2
+
+ return 3
+
+
+def generate_service_types(output, service: ServiceModel, doc=True):
+ output.write("from datetime import datetime\n")
+ output.write("from enum import StrEnum\n")
+ output.write(
+ "from typing import Dict, List, Optional, Iterator, Iterable, IO, Union, TypedDict\n"
+ )
+ output.write("\n")
+ output.write(
+ "from localstack.aws.api import handler, RequestContext, ServiceException, ServiceRequest"
+ )
+ output.write("\n")
+
+ # ==================================== print type declarations
+ nodes: Dict[str, ShapeNode] = {}
+
+ for shape_name in service.shape_names:
+ shape = service.shape_for(shape_name)
+ nodes[to_valid_python_name(shape_name)] = ShapeNode(service, shape)
+
+ # output.write("__all__ = [\n")
+ # for name in nodes.keys():
+ # output.write(f' "{name}",\n')
+ # output.write("]\n")
+
+ printed: Set[str] = set()
+ visited: Set[str] = set()
+ stack: List[str] = list(nodes.keys())
+
+ stack = sorted(stack, key=lambda name: nodes[name].get_order())
+ stack.reverse()
+
+ while stack:
+ name = stack.pop()
+ if name in printed:
+ continue
+ node = nodes[name]
+
+ dependencies = [dep for dep in node.dependencies if dep not in printed]
+
+ if not dependencies:
+ node.print_declaration(output, doc=doc)
+ printed.add(name)
+ elif name in visited:
+ # break out of circular dependencies
+ node.print_declaration(output, doc=doc, quote_types=True)
+ printed.add(name)
+ else:
+ stack.append(name)
+ stack.extend(dependencies)
+ visited.add(name)
+
+
+def generate_service_api(output, service: ServiceModel, doc=True):
+ service_name = service.service_name.replace("-", "_")
+ class_name = service_name + "_api"
+ class_name = snake_to_camel_case(class_name)
+
+ output.write(f"class {class_name}:\n")
+ output.write("\n")
+ output.write(f' service = "{service.service_name}"\n')
+ output.write(f' version = "{service.api_version}"\n')
+ for op_name in service.operation_names:
+ operation: OperationModel = service.operation_model(op_name)
+
+ fn_name = camel_to_snake_case(op_name)
+
+ if operation.output_shape:
+ output_shape = to_valid_python_name(operation.output_shape.name)
+ else:
+ output_shape = "None"
+
+ output.write("\n")
+ parameters = OrderedDict()
+ param_shapes = OrderedDict()
+
+ if input_shape := operation.input_shape:
+ members = list(input_shape.members)
+
+ streaming_payload_member = None
+ if operation.has_streaming_input:
+ streaming_payload_member = operation.input_shape.serialization.get("payload")
+
+ for m in input_shape.required_members:
+ members.remove(m)
+ m_shape = input_shape.members[m]
+ type_name = to_valid_python_name(m_shape.name)
+ if m == streaming_payload_member:
+ type_name = f"IO[{type_name}]"
+ parameters[xform_name(m)] = type_name
+ param_shapes[xform_name(m)] = m_shape
+
+ for m in members:
+ m_shape = input_shape.members[m]
+ param_shapes[xform_name(m)] = m_shape
+ type_name = to_valid_python_name(m_shape.name)
+ if m == streaming_payload_member:
+ type_name = f"IO[{type_name}]"
+ parameters[xform_name(m)] = f"{type_name} = None"
+
+ if any(map(is_bad_param_name, parameters.keys())):
+ # if we cannot render the parameter name, don't expand the parameters in the handler
+ param_list = f"request: {to_valid_python_name(input_shape.name)}" if input_shape else ""
+ output.write(f' @handler("{operation.name}", expand=False)\n')
+ else:
+ param_list = ", ".join([f"{k}: {v}" for k, v in parameters.items()])
+ output.write(f' @handler("{operation.name}")\n')
+
+ # add the **kwargs in the end
+ if param_list:
+ param_list += ", **kwargs"
+ else:
+ param_list = "**kwargs"
+
+ output.write(
+ f" def {fn_name}(self, context: RequestContext, {param_list}) -> {output_shape}:\n"
+ )
+
+ # convert html documentation to rst and print it into the method's docstring
+ if doc:
+ html = operation.documentation
+ rst = html_to_rst(html)
+ output.write(' """')
+ output.write(f"{rst}\n")
+ output.write("\n")
+
+ # parameters
+ for param_name, shape in param_shapes.items():
+ # FIXME: this doesn't work properly
+ rst = html_to_rst(shape.documentation)
+ rst = rst.strip().split(".")[0] + "."
+ output.write(f":param {param_name}: {rst}\n")
+
+ # return value
+ if operation.output_shape:
+ output.write(f":returns: {to_valid_python_name(operation.output_shape.name)}\n")
+
+ # errors
+ for error in operation.error_shapes:
+ output.write(f":raises {to_valid_python_name(error.name)}:\n")
+
+ output.write(' """\n')
+
+ output.write(" raise NotImplementedError\n")
+
+
+@click.group()
+def scaffold():
+ pass
+
+
+@scaffold.command(name="generate")
+@click.argument("service", type=str)
+@click.option("--doc/--no-doc", default=False, help="whether or not to generate docstrings")
+@click.option(
+ "--save/--print",
+ default=False,
+ help="whether or not to save the result into the api directory",
+)
+@click.option(
+ "--path",
+ default="./localstack-core/localstack/aws/api",
+ help="the path where the api should be saved",
+)
+def generate(service: str, doc: bool, save: bool, path: str):
+ """
+ Generate types and API stubs for a given AWS service.
+
+ SERVICE is the service to generate the stubs for (e.g., sqs or cloudformation)
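+
+ Example (illustrative): python -m localstack.aws.scaffold generate sqs --doc --print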
+ """
+ from click import ClickException
+
+ try:
+ code = generate_code(service, doc=doc)
+ except UnknownServiceError:
+ raise ClickException(f"unknown service {service}")
+
+ if not save:
+ # either just print the code to stdout
+ click.echo(code)
+ return
+
+ # or find the file path and write the code to that location
+ create_code_directory(service, code, path)
+ click.echo("done!")
+
+
+def generate_code(service_name: str, doc: bool = False) -> str:
+ model = load_service(service_name)
+ output = io.StringIO()
+ generate_service_types(output, model, doc=doc)
+ generate_service_api(output, model, doc=doc)
+ return output.getvalue()
+
+
+def create_code_directory(service_name: str, code: str, base_path: str):
+ service_name = service_name.replace("-", "_")
+ # handle service names which are reserved keywords in python (e.g. lambda)
+ if is_keyword(service_name):
+ service_name += "_"
+ path = Path(base_path, service_name)
+
+ if not path.exists():
+ click.echo(f"creating directory {path}")
+ path.mkdir()
+
+ file = path / "__init__.py"
+ click.echo(f"writing to file {file}")
+ file.write_text(code)
+
+
+@scaffold.command()
+@click.option("--doc/--no-doc", default=False, help="whether or not to generate docstrings")
+@click.option(
+ "--path",
+ default="./localstack-core/localstack/aws/api",
+ help="the path in which to upgrade ASF APIs",
+)
+def upgrade(path: str, doc: bool = False):
+ """
+ Execute the code generation for all existing APIs.
+ """
+ services = [
+ d.name.rstrip("_").replace("_", "-")
+ for d in Path(path).iterdir()
+ if d.is_dir() and not d.name.startswith("__")
+ ]
+
+ with Pool() as pool:
+ pool.starmap(_do_generate_code, [(service, path, doc) for service in services])
+
+ click.echo("done!")
+
+
+def _do_generate_code(service: str, path: str, doc: bool):
+ try:
+ code = generate_code(service, doc)
+ except UnknownServiceError:
+ click.echo(f"unknown service {service}! skipping...")
+ return
+ create_code_directory(service, code, base_path=path)
+
+
+if __name__ == "__main__":
+ scaffold()
diff --git a/localstack/services/__init__.py b/localstack-core/localstack/aws/serving/__init__.py
similarity index 100%
rename from localstack/services/__init__.py
rename to localstack-core/localstack/aws/serving/__init__.py
diff --git a/localstack-core/localstack/aws/serving/asgi.py b/localstack-core/localstack/aws/serving/asgi.py
new file mode 100644
index 0000000000000..3bbeefd49944f
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/asgi.py
@@ -0,0 +1,5 @@
+from rolo.gateway.asgi import AsgiGateway
+
+__all__ = [
+ "AsgiGateway",
+]
diff --git a/localstack-core/localstack/aws/serving/edge.py b/localstack-core/localstack/aws/serving/edge.py
new file mode 100644
index 0000000000000..0e204a4d96f88
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/edge.py
@@ -0,0 +1,119 @@
+import logging
+import threading
+from typing import List
+
+from rolo.gateway.wsgi import WsgiGateway
+
+from localstack import config
+from localstack.aws.app import LocalstackAwsGateway
+from localstack.config import HostAndPort
+from localstack.runtime import get_current_runtime
+from localstack.runtime.shutdown import ON_AFTER_SERVICE_SHUTDOWN_HANDLERS
+from localstack.utils.collections import ensure_list
+
+LOG = logging.getLogger(__name__)
+
+
+def serve_gateway(
+ listen: HostAndPort | List[HostAndPort], use_ssl: bool, asynchronous: bool = False
+):
+ """
+ Implementation of the edge.do_start_edge_proxy interface to start a server (hypercorn, werkzeug, or
+ twisted, depending on ``config.GATEWAY_SERVER``) serving the LocalstackAwsGateway.
+ """
+
+ gateway = get_current_runtime().components.gateway
+
+ listens = ensure_list(listen)
+
+ if config.GATEWAY_SERVER == "hypercorn":
+ return _serve_hypercorn(gateway, listens, use_ssl, asynchronous)
+ elif config.GATEWAY_SERVER == "werkzeug":
+ return _serve_werkzeug(gateway, listens, use_ssl, asynchronous)
+ elif config.GATEWAY_SERVER == "twisted":
+ return _serve_twisted(gateway, listens, use_ssl, asynchronous)
+ else:
+ raise ValueError(f"Unknown gateway server type {config.GATEWAY_SERVER}")
+
+
+def _serve_werkzeug(
+ gateway: LocalstackAwsGateway, listen: List[HostAndPort], use_ssl: bool, asynchronous: bool
+):
+ from werkzeug.serving import ThreadedWSGIServer
+
+ from .werkzeug import CustomWSGIRequestHandler
+
+ params = {
+ "app": WsgiGateway(gateway),
+ "handler": CustomWSGIRequestHandler,
+ }
+
+ if use_ssl:
+ from localstack.utils.ssl import create_ssl_cert, install_predefined_cert_if_available
+
+ install_predefined_cert_if_available()
+ serial_number = listen[0].port
+ _, cert_file_name, key_file_name = create_ssl_cert(serial_number=serial_number)
+ params["ssl_context"] = (cert_file_name, key_file_name)
+
+ threads = []
+ servers: List[ThreadedWSGIServer] = []
+
+ for host_port in listen:
+ kwargs = dict(params)
+ kwargs["host"] = host_port.host
+ kwargs["port"] = host_port.port
+ server = ThreadedWSGIServer(**kwargs)
+ servers.append(server)
+ threads.append(
+ threading.Thread(
+ target=server.serve_forever, name=f"werkzeug-server-{host_port.port}", daemon=True
+ )
+ )
+
+ def _shutdown_servers():
+ LOG.debug("[shutdown] Shutting down gateway servers")
+ for _srv in servers:
+ _srv.shutdown()
+
+ ON_AFTER_SERVICE_SHUTDOWN_HANDLERS.register(_shutdown_servers)
+
+ for thread in threads:
+ thread.start()
+
+ if not asynchronous:
+ for thread in threads:
+ return thread.join()
+
+ # FIXME: thread handling is a bit wonky
+ return threads[0]
+
+
+def _serve_hypercorn(
+ gateway: LocalstackAwsGateway, listen: List[HostAndPort], use_ssl: bool, asynchronous: bool
+):
+ from localstack.http.hypercorn import GatewayServer
+
+ # start serving gateway
+ server = GatewayServer(gateway, listen, use_ssl, config.GATEWAY_WORKER_COUNT)
+ server.start()
+
+ # with the current way the infrastructure is started, this is the easiest way to shut down the server correctly
+ # FIXME: but the infrastructure shutdown should be much cleaner, core components like the gateway should be handled
+ # explicitly by the thing starting the components, not implicitly by the components.
+ def _shutdown_gateway():
+ LOG.debug("[shutdown] Shutting down gateway server")
+ server.shutdown()
+
+ ON_AFTER_SERVICE_SHUTDOWN_HANDLERS.register(_shutdown_gateway)
+ if not asynchronous:
+ server.join()
+ return server._thread
+
+
+def _serve_twisted(
+ gateway: LocalstackAwsGateway, listen: List[HostAndPort], use_ssl: bool, asynchronous: bool
+):
+ from .twisted import serve_gateway
+
+ return serve_gateway(gateway, listen, use_ssl, asynchronous)
diff --git a/localstack-core/localstack/aws/serving/hypercorn.py b/localstack-core/localstack/aws/serving/hypercorn.py
new file mode 100644
index 0000000000000..450d2664badc9
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/hypercorn.py
@@ -0,0 +1,47 @@
+import asyncio
+from typing import Any, Optional, Tuple
+
+from hypercorn import Config
+from hypercorn.asyncio import serve as serve_hypercorn
+
+from localstack import constants
+
+from ..gateway import Gateway
+from .asgi import AsgiGateway
+
+
+def serve(
+ gateway: Gateway,
+ host: str = "localhost",
+ port: int = constants.DEFAULT_PORT_EDGE,
+ use_reloader: bool = True,
+ ssl_creds: Optional[Tuple[Any, Any]] = None,
+ **kwargs,
+) -> None:
+ """
+ Serve the given Gateway through a hypercorn server and block until it is completed.
+
+ :param gateway: the Gateway instance to serve
+ :param host: the host to expose the server on
+ :param port: the port to expose the server on
+ :param use_reloader: whether to use the reloader
+ :param ssl_creds: the ssl credentials (tuple of certfile and keyfile)
+ :param kwargs: any other parameters that can be passed to the hypercorn.Config object
+ """
+ config = Config()
+ config.h11_pass_raw_headers = True
+ config.bind = f"{host}:{port}"
+ config.use_reloader = use_reloader
+
+ if ssl_creds:
+ cert_file_name, key_file_name = ssl_creds
+ if cert_file_name:
+ kwargs["certfile"] = cert_file_name
+ if key_file_name:
+ kwargs["keyfile"] = key_file_name
+
+ for k, v in kwargs.items():
+ setattr(config, k, v)
+
+ loop = asyncio.new_event_loop()
+ loop.run_until_complete(serve_hypercorn(AsgiGateway(gateway, event_loop=loop), config))
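+
+
+# Usage sketch (gateway construction elided; host/port values illustrative):
+#
+#     serve(gateway, host="0.0.0.0", port=4566, use_reloader=False)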
diff --git a/localstack-core/localstack/aws/serving/twisted.py b/localstack-core/localstack/aws/serving/twisted.py
new file mode 100644
index 0000000000000..549150a73ae61
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/twisted.py
@@ -0,0 +1,173 @@
+"""
+Bindings to serve LocalStack using twisted.
+"""
+
+import logging
+import time
+from typing import List
+
+from rolo.gateway import Gateway
+from rolo.serving.twisted import TwistedGateway
+from twisted.internet import endpoints, interfaces, reactor, ssl
+from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
+from twisted.protocols.tls import BufferingTLSTransport, TLSMemoryBIOFactory
+from twisted.python.threadpool import ThreadPool
+
+from localstack import config
+from localstack.config import HostAndPort
+from localstack.runtime.shutdown import ON_AFTER_SERVICE_SHUTDOWN_HANDLERS
+from localstack.utils.patch import patch
+from localstack.utils.ssl import create_ssl_cert, install_predefined_cert_if_available
+from localstack.utils.threads import start_worker_thread
+
+LOG = logging.getLogger(__name__)
+
+
+class TLSMultiplexer(ProtocolWrapper):
+ """
+ Custom protocol to multiplex HTTPS and HTTP connections over the same port. This is the equivalent of
+ ``DuplexSocket``, but since twisted uses its own SSL layer and doesn't use `ssl.SSLSocket`, we need to implement
+ the multiplexing behavior in the Twisted layer.
+
+ The basic idea is to defer the ``makeConnection`` call until the first data are received, and then re-configure
+ the underlying ``wrappedProtocol`` if needed with a TLS wrapper.
+ """
+
+ tlsProtocol = BufferingTLSTransport
+
+ def __init__(
+ self,
+ factory: "WrappingFactory",
+ wrappedProtocol: interfaces.IProtocol,
+ ):
+ super().__init__(factory, wrappedProtocol)
+ self._isInitialized = False
+ self._isTLS = None
+ self._negotiatedProtocol = None
+
+ def makeConnection(self, transport):
+ self.connected = 1
+ self.transport = transport
+ self.factory.registerProtocol(self) # this is idempotent
+ # we defer the actual makeConnection call to the first invocation of dataReceived
+
+ def dataReceived(self, data: bytes) -> None:
+ if self._isInitialized:
+ super().dataReceived(data)
+ return
+
+ # once the first data has been received, we can check whether it's a TLS handshake, and then run the
+ # actual makeConnection procedure.
+ self._isInitialized = True
+ self._isTLS = data[0] == 22 # 0x16 is the marker byte identifying a TLS handshake
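+ # (plaintext HTTP, in contrast, starts with an ASCII method or preface such as b"GET " or b"PRI *",
+ # so peeking at the first byte is enough to tell the two protocols apart)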
+
+ if self._isTLS:
+ # wrap protocol again in tls protocol
+ self.wrappedProtocol = self.tlsProtocol(self.factory, self.wrappedProtocol)
+ else:
+ if data.startswith(b"PRI * HTTP/2"):
+ # TODO: can we do proper protocol negotiation like in ALPN?
+ # in the TLS case, this is determined by the ALPN procedure by OpenSSL.
+ self._negotiatedProtocol = b"h2"
+
+ # now that we've set the real wrapped protocol, run the make connection procedure
+ super().makeConnection(self.transport)
+ super().dataReceived(data)
+
+ @property
+ def negotiatedProtocol(self) -> str | None:
+ if self._negotiatedProtocol:
+ return self._negotiatedProtocol
+ return self.wrappedProtocol.negotiatedProtocol
+
+
+class TLSMultiplexerFactory(TLSMemoryBIOFactory):
+ protocol = TLSMultiplexer
+
+
+def stop_thread_pool(self: ThreadPool, stop, timeout: float = None):
+ """
+ Patch for a custom shutdown procedure for a ThreadPool that waits at most a given amount of time for all threads to stop.
+
+ :param self: the pool to shut down
+ :param stop: the original function
+ :param timeout: the maximum amount of time to wait
+ """
+ # copied from ThreadPool.stop()
+ if self.joined:
+ return
+ if not timeout:
+ stop()
+ return
+
+ self.joined = True
+ self.started = False
+ self._team.quit()
+
+ # our own joining logic with timeout
+ remaining = timeout
+ total_waited = 0
+
+ for thread in self.threads:
+ then = time.time()
+
+ # LOG.info("[shutdown] Joining thread %s", thread)
+ thread.join(remaining)
+
+ waited = time.time() - then
+ total_waited += waited
+ remaining -= waited
+
+ if thread.is_alive():
+ LOG.warning(
+ "[shutdown] Request thread %s still alive after %.2f seconds",
+ thread,
+ total_waited,
+ )
+
+ if remaining <= 0:
+ remaining = 0
+
+
+def serve_gateway(
+ gateway: Gateway, listen: List[HostAndPort], use_ssl: bool, asynchronous: bool = False
+):
+ """
+ Serve a Gateway instance using twisted.
+ """
+ # setup reactor
+ reactor.suggestThreadPoolSize(config.GATEWAY_WORKER_COUNT)
+ thread_pool = reactor.getThreadPool()
+ patch(thread_pool.stop)(stop_thread_pool)
+
+ def _shutdown_reactor():
+ LOG.debug("[shutdown] Shutting down twisted reactor serving the gateway")
+ thread_pool.stop(timeout=10)
+ reactor.stop()
+
+ ON_AFTER_SERVICE_SHUTDOWN_HANDLERS.register(_shutdown_reactor)
+
+ # setup twisted webserver Site
+ site = TwistedGateway(gateway)
+
+ # configure ssl
+ if use_ssl:
+ install_predefined_cert_if_available()
+ serial_number = listen[0].port
+ _, cert_file_name, key_file_name = create_ssl_cert(serial_number=serial_number)
+ context_factory = ssl.DefaultOpenSSLContextFactory(key_file_name, cert_file_name)
+ context_factory.getContext().use_certificate_chain_file(cert_file_name)
+ protocol_factory = TLSMultiplexerFactory(context_factory, False, site)
+ else:
+ protocol_factory = site
+
+ # add endpoint for each host/port combination
+ for host_and_port in listen:
+ # TODO: interface = host?
+ endpoint = endpoints.TCP4ServerEndpoint(reactor, host_and_port.port)
+ endpoint.listen(protocol_factory)
+
+ if asynchronous:
+ return start_worker_thread(reactor.run)
+ else:
+ return reactor.run()
diff --git a/localstack-core/localstack/aws/serving/werkzeug.py b/localstack-core/localstack/aws/serving/werkzeug.py
new file mode 100644
index 0000000000000..22e351adc4842
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/werkzeug.py
@@ -0,0 +1,58 @@
+import ssl
+from typing import TYPE_CHECKING, Any, Optional, Tuple
+
+from rolo.gateway import Gateway
+from rolo.gateway.wsgi import WsgiGateway
+from werkzeug import run_simple
+from werkzeug.serving import WSGIRequestHandler
+
+if TYPE_CHECKING:
+ from _typeshed.wsgi import WSGIEnvironment
+
+from localstack import constants
+
+
+def serve(
+ gateway: Gateway,
+ host: str = "localhost",
+ port: int = constants.DEFAULT_PORT_EDGE,
+ use_reloader: bool = True,
+ ssl_creds: Optional[Tuple[Any, Any]] = None,
+ **kwargs,
+) -> None:
+ """
+ Serve a Gateway as a WSGI application through werkzeug. This is mostly for development purposes.
+
+ :param gateway: the Gateway to serve
+ :param host: the host to expose the server to
+ :param port: the port to expose the server to
+ :param use_reloader: whether to autoreload the server on changes
+ :param ssl_creds: the ssl credentials (tuple of certfile and keyfile)
+ :param kwargs: any other arguments that can be passed to `werkzeug.run_simple`
+ """
+ kwargs["threaded"] = kwargs.get("threaded", True) # make sure requests don't block
+ kwargs["ssl_context"] = ssl_creds
+ kwargs.setdefault("request_handler", CustomWSGIRequestHandler)
+ run_simple(host, port, WsgiGateway(gateway), use_reloader=use_reloader, **kwargs)
+
+
+class CustomWSGIRequestHandler(WSGIRequestHandler):
+ def make_environ(self) -> "WSGIEnvironment":
+ environ = super().make_environ()
+
+ # restore RAW_URI from the requestline, which will be something like ``GET //foo/?foo=bar%20ed HTTP/1.1``
+ environ["RAW_URI"] = " ".join(self.requestline.split(" ")[1:-1])
+
+ # restore raw headers for rolo
+ environ["asgi.headers"] = [
+ (k.encode("latin-1"), v.encode("latin-1")) for k, v in self.headers.raw_items()
+ ]
+
+ # the default WSGIRequestHandler does not understand our DuplexSocket, so it will always set https, which we
+ # correct here
+ try:
+ is_ssl = isinstance(self.request, ssl.SSLSocket)
+ except AttributeError:
+ is_ssl = False
+ environ["wsgi.url_scheme"] = "https" if is_ssl else "http"
+
+ return environ
diff --git a/localstack-core/localstack/aws/serving/wsgi.py b/localstack-core/localstack/aws/serving/wsgi.py
new file mode 100644
index 0000000000000..8ae26b3d8c9df
--- /dev/null
+++ b/localstack-core/localstack/aws/serving/wsgi.py
@@ -0,0 +1,5 @@
+from rolo.gateway.wsgi import WsgiGateway
+
+__all__ = [
+ "WsgiGateway",
+]
diff --git a/localstack-core/localstack/aws/skeleton.py b/localstack-core/localstack/aws/skeleton.py
new file mode 100644
index 0000000000000..9d66fa4b375c1
--- /dev/null
+++ b/localstack-core/localstack/aws/skeleton.py
@@ -0,0 +1,228 @@
+import inspect
+import logging
+from typing import Any, Callable, Dict, NamedTuple, Optional, Union
+
+from botocore import xform_name
+from botocore.model import ServiceModel
+
+from localstack.aws.api import (
+ CommonServiceException,
+ RequestContext,
+ ServiceException,
+)
+from localstack.aws.api.core import ServiceRequest, ServiceRequestHandler, ServiceResponse
+from localstack.aws.protocol.parser import create_parser
+from localstack.aws.protocol.serializer import ResponseSerializer, create_serializer
+from localstack.aws.spec import load_service
+from localstack.http import Response
+from localstack.utils import analytics
+from localstack.utils.coverage_docs import get_coverage_link_for_service
+
+LOG = logging.getLogger(__name__)
+
+DispatchTable = Dict[str, ServiceRequestHandler]
+
+
+def create_skeleton(service: Union[str, ServiceModel], delegate: Any):
+ if isinstance(service, str):
+ service = load_service(service)
+
+ return Skeleton(service, create_dispatch_table(delegate))
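+
+
+# Usage sketch (provider class and context construction illustrative):
+#
+#     skeleton = create_skeleton("sqs", SqsProvider())
+#     response = skeleton.invoke(context)  # context: a RequestContext carrying the parsed HTTP request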
+
+
+class HandlerAttributes(NamedTuple):
+ """
+ Holder object of the attributes added to a function by the @handler decorator.
+ """
+
+ function_name: str
+ operation: str
+ pass_context: bool
+ expand_parameters: bool
+
+
+def create_dispatch_table(delegate: object) -> DispatchTable:
+ """
+ Creates a dispatch table for a given object. First, the entire class tree of the object is scanned to find any
+ functions that are decorated with @handler. It then resolves those functions on the delegate.
+ """
+ # scan class tree for @handler wrapped functions (reverse class tree so that inherited functions overwrite parent
+ # functions)
+ cls_tree = inspect.getmro(delegate.__class__)
+ handlers: Dict[str, HandlerAttributes] = {}
+ cls_tree = reversed(list(cls_tree))
+ for cls in cls_tree:
+ if cls == object:
+ continue
+
+ for name, fn in inspect.getmembers(cls, inspect.isfunction):
+ try:
+ # attributes come from operation_marker in @handler wrapper
+ handlers[fn.operation] = HandlerAttributes(
+ fn.__name__, fn.operation, fn.pass_context, fn.expand_parameters
+ )
+ except AttributeError:
+ pass
+
+ # create dispatch table from operation handlers by resolving bound functions on the delegate
+ dispatch_table: DispatchTable = {}
+ for handler in handlers.values():
+ # resolve the bound function of the delegate
+ bound_function = getattr(delegate, handler.function_name)
+ # create a dispatcher
+ dispatch_table[handler.operation] = ServiceRequestDispatcher(
+ bound_function,
+ operation=handler.operation,
+ pass_context=handler.pass_context,
+ expand_parameters=handler.expand_parameters,
+ )
+
+ return dispatch_table
+
+
+class ServiceRequestDispatcher:
+ fn: Callable
+ operation: str
+ expand_parameters: bool = True
+ pass_context: bool = True
+
+ def __init__(
+ self,
+ fn: Callable,
+ operation: str,
+ pass_context: bool = True,
+ expand_parameters: bool = True,
+ ):
+ self.fn = fn
+ self.operation = operation
+ self.pass_context = pass_context
+ self.expand_parameters = expand_parameters
+
+ def __call__(
+ self, context: RequestContext, request: ServiceRequest
+ ) -> Optional[ServiceResponse]:
+ args = []
+ kwargs = {}
+
+ if not self.expand_parameters:
+ if self.pass_context:
+ args.append(context)
+ args.append(request)
+ else:
+ if request is None:
+ kwargs = {}
+ else:
+ kwargs = {xform_name(k): v for k, v in request.items()}
+ kwargs["context"] = context
+
+ return self.fn(*args, **kwargs)
+
+
+class Skeleton:
+ service: ServiceModel
+ dispatch_table: DispatchTable
+
+ def __init__(self, service: ServiceModel, implementation: Union[Any, DispatchTable]):
+ self.service = service
+
+ if isinstance(implementation, dict):
+ self.dispatch_table = implementation
+ else:
+ self.dispatch_table = create_dispatch_table(implementation)
+
+ def invoke(self, context: RequestContext) -> Response:
+ serializer = create_serializer(context.service)
+
+ if context.operation and context.service_request:
+ # if the operation and parsed request are already set in the context, re-use them
+ operation, instance = context.operation, context.service_request
+ else:
+ # otherwise, parse the incoming HTTPRequest
+ operation, instance = create_parser(context.service).parse(context.request)
+ context.operation = operation
+
+ try:
+ # Find the operation's handler in the dispatch table
+ if operation.name not in self.dispatch_table:
+ LOG.warning(
+ "missing entry in dispatch table for %s.%s",
+ self.service.service_name,
+ operation.name,
+ )
+ raise NotImplementedError
+
+ return self.dispatch_request(serializer, context, instance)
+ except ServiceException as e:
+ return self.on_service_exception(serializer, context, e)
+ except NotImplementedError as e:
+ return self.on_not_implemented_error(serializer, context, e)
+
+ def dispatch_request(
+ self, serializer: ResponseSerializer, context: RequestContext, instance: ServiceRequest
+ ) -> Response:
+ operation = context.operation
+
+ handler = self.dispatch_table[operation.name]
+
+ # Call the appropriate handler
+ result = handler(context, instance) or {}
+
+ # if the service handler returned an HTTP response, forgo serialization and return it immediately
+ if isinstance(result, Response):
+ return result
+
+ context.service_response = result
+
+ # Serialize result dict to a Response and return it
+ return serializer.serialize_to_response(
+ result, operation, context.request.headers, context.request_id
+ )
+
+ def on_service_exception(
+ self, serializer: ResponseSerializer, context: RequestContext, exception: ServiceException
+ ) -> Response:
+ """
+ Called by invoke if the handler of the operation raised a ServiceException.
+
+ :param serializer: serializer which should be used to serialize the exception
+ :param context: the request context
+ :param exception: the exception that was raised
+ :return: a Response object
+ """
+ context.service_exception = exception
+
+ return serializer.serialize_error_to_response(
+ exception, context.operation, context.request.headers, context.request_id
+ )
+
+ def on_not_implemented_error(
+ self,
+ serializer: ResponseSerializer,
+ context: RequestContext,
+ exception: NotImplementedError,
+ ) -> Response:
+ """
+ Called by invoke if either the dispatch table did not contain an entry for the operation, or the service
+ provider raised a NotImplementedError.
+
+ :param serializer: the serializer which should be used to serialize the NotImplementedError
+ :param context: the request context
+ :param exception: the NotImplementedError that was raised
+ :return: a Response object
+ """
+ operation = context.operation
+
+ action_name = operation.name
+ service_name = operation.service_model.service_name
+ exception_message: str | None = exception.args[0] if exception.args else None
+ message = exception_message or get_coverage_link_for_service(service_name, action_name)
+ LOG.info(message)
+ error = CommonServiceException("InternalFailure", message, status_code=501)
+ # record event
+ analytics.log.event(
+ "services_notimplemented", payload={"s": service_name, "a": action_name}
+ )
+ context.service_exception = error
+
+ return serializer.serialize_error_to_response(
+ error, operation, context.request.headers, context.request_id
+ )
diff --git a/localstack-core/localstack/aws/spec-patches.json b/localstack-core/localstack/aws/spec-patches.json
new file mode 100644
index 0000000000000..dbe268b52c45b
--- /dev/null
+++ b/localstack-core/localstack/aws/spec-patches.json
@@ -0,0 +1,1344 @@
+{
+ "s3/2006-03-01/service-2": [
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchBucket/members/BucketName",
+ "value": {
+ "shape": "BucketName"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchBucket/error",
+ "value": {
+ "httpStatusCode": 404
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchLifecycleConfiguration",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The lifecycle configuration does not exist
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidBucketName",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The specified bucket is not valid.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketRegion",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketContentType",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HeadBucketOutput",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketRegion": {
+ "shape": "BucketRegion",
+ "location": "header",
+ "locationName": "x-amz-bucket-region"
+ },
+ "BucketContentType": {
+ "shape": "BucketContentType",
+ "location": "header",
+ "locationName": "content-type"
+ }
+ }
+ }
+ },
+ {
+ "op": "add",
+ "path": "/operations/HeadBucket/output",
+ "value": {
+ "shape": "HeadBucketOutput"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/operations/PutBucketPolicy/http/responseCode",
+ "value": 204
+ },
+ {
+ "op": "add",
+ "path": "/shapes/GetBucketLocationOutput/payload",
+ "value": "LocationConstraint"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketAlreadyOwnedByYou/members/BucketName",
+ "value": {
+ "shape": "BucketName"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketAlreadyOwnedByYou/error",
+ "value": {
+ "httpStatusCode": 409
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/GetObjectOutput/members/StatusCode",
+ "value": {
+ "shape": "GetObjectResponseStatusCode",
+ "location": "statusCode"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HeadObjectOutput/members/StatusCode",
+ "value": {
+ "shape": "GetObjectResponseStatusCode",
+ "location": "statusCode"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchKey/members/Key",
+ "value": {
+ "shape": "ObjectKey"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchKey/error",
+ "value": {
+ "httpStatusCode": 404
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchKey/members/DeleteMarker",
+ "value": {
+ "shape": "DeleteMarker",
+ "location": "header",
+ "locationName": "x-amz-delete-marker"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchKey/members/VersionId",
+ "value": {
+ "shape": "ObjectVersionId",
+ "location": "header",
+ "locationName": "x-amz-version-id"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchVersion",
+ "value": {
+ "type": "structure",
+ "members": {
+ "VersionId": {
+ "shape": "ObjectVersionId"
+ },
+ "Key": {
+ "shape": "ObjectKey"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/PreconditionFailed",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Condition": {
+ "shape": "IfCondition"
+ }
+ },
+ "error": {
+ "httpStatusCode": 412
+ },
+ "documentation": "At least one of the pre-conditions you specified did not hold
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/IfCondition",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidRange",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ActualObjectSize": {
+ "shape": "ObjectSize"
+ },
+ "RangeRequested": {
+ "shape": "ContentRange"
+ }
+ },
+ "error": {
+ "httpStatusCode": 416
+ },
+ "documentation": "The requested range is not satisfiable
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HeadObjectOutput/members/Expires",
+ "value": {
+ "shape": "Expires",
+ "documentation": "The date and time at which the object is no longer cacheable.
",
+ "location": "header",
+ "locationName": "expires"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/GetObjectOutput/members/Expires",
+ "value": {
+ "shape": "Expires",
+ "documentation": "The date and time at which the object is no longer cacheable.
",
+ "location": "header",
+ "locationName": "expires"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/RestoreObjectOutputStatusCode",
+ "value": {
+ "type": "integer"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/RestoreObjectOutput/members/StatusCode",
+ "value": {
+ "shape": "RestoreObjectOutputStatusCode",
+ "location": "statusCode"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidArgument",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ArgumentName": {
+ "shape": "ArgumentName"
+ },
+ "ArgumentValue": {
+ "shape": "ArgumentValue"
+ },
+ "HostId": {
+ "shape": "HostId"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "Invalid Argument
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ArgumentName",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ArgumentValue",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/SignatureDoesNotMatch",
+ "value": {
+ "type": "structure",
+ "members": {
+ "AWSAccessKeyId": {
+ "shape": "AWSAccessKeyId"
+ },
+ "CanonicalRequest": {
+ "shape": "CanonicalRequest"
+ },
+ "CanonicalRequestBytes": {
+ "shape": "CanonicalRequestBytes"
+ },
+ "HostId": {
+ "shape": "HostId"
+ },
+ "SignatureProvided": {
+ "shape": "SignatureProvided"
+ },
+ "StringToSign": {
+ "shape": "StringToSign"
+ },
+ "StringToSignBytes": {
+ "shape": "StringToSignBytes"
+ }
+ },
+ "error": {
+ "httpStatusCode": 403
+ },
+ "documentation": "The request signature we calculated does not match the signature you provided. Check your key and signing method.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/AccessDenied",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Expires": {
+ "shape": "Expires"
+ },
+ "ServerTime": {
+ "shape": "ServerTime"
+ },
+ "X_Amz_Expires": {
+ "shape": "X-Amz-Expires",
+ "locationName":"X-Amz-Expires"
+ },
+ "HostId": {
+ "shape": "HostId"
+ },
+ "HeadersNotSigned": {
+ "shape": "HeadersNotSigned"
+ }
+ },
+ "error": {
+ "httpStatusCode": 403
+ },
+ "documentation": "Request has expired
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/AWSAccessKeyId",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HostId",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HeadersNotSigned",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/SignatureProvided",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/StringToSign",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/StringToSignBytes",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/CanonicalRequest",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/CanonicalRequestBytes",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ServerTime",
+ "value": {
+ "type": "timestamp"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/X-Amz-Expires",
+ "value": {
+ "type": "integer"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/AuthorizationQueryParametersError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "HostId": {
+ "shape": "HostId"
+ }
+ },
+ "documentation": "Query-string authentication version 4 requires the X-Amz-Algorithm, X-Amz-Credential, X-Amz-Signature, X-Amz-Date, X-Amz-SignedHeaders, and X-Amz-Expires parameters.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/operations/PostObject",
+ "value": {
+ "name":"PostObject",
+ "http":{
+ "method":"POST",
+ "requestUri":"/{Bucket}"
+ },
+ "input":{"shape":"PostObjectRequest"},
+ "output":{"shape":"PostResponse"},
+ "documentationUrl":"http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPOST.html",
+ "documentation":"The POST operation adds an object to a specified bucket by using HTML forms. POST is an alternate form of PUT that enables browser-based uploads as a way of putting objects in buckets. Parameters that are passed to PUT through HTTP Headers are instead passed as form fields to POST in the multipart/form-data encoded message body. To add an object to a bucket, you must have WRITE access on the bucket. Amazon S3 never stores partial objects. If you receive a successful response, you can be confident that the entire object was stored.
"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/PostObjectRequest",
+ "value": {
+ "type":"structure",
+ "required":[
+ "Bucket"
+ ],
+ "members":{
+ "Body":{
+ "shape":"Body",
+ "documentation":"
Object data.
",
+ "streaming":true
+ },
+ "Bucket":{
+ "shape":"BucketName",
+ "documentation":"The bucket name to which the PUT action was initiated.
When using this action with an access point, you must direct requests to the access point hostname. The access point hostname takes the form AccessPointName -AccountId .s3-accesspoint.Region .amazonaws.com. When using this action with an access point through the Amazon Web Services SDKs, you provide the access point ARN in place of the bucket name. For more information about access point ARNs, see Using access points in the Amazon S3 User Guide .
When using this action with Amazon S3 on Outposts, you must direct requests to the S3 on Outposts hostname. The S3 on Outposts hostname takes the form AccessPointName -AccountId .outpostID .s3-outposts.Region .amazonaws.com
. When using this action with S3 on Outposts through the Amazon Web Services SDKs, you provide the Outposts bucket ARN in place of the bucket name. For more information about S3 on Outposts ARNs, see Using Amazon S3 on Outposts in the Amazon S3 User Guide .
",
+ "location":"uri",
+ "locationName":"Bucket"
+ }
+ },
+ "payload":"Body"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/PostResponse",
+ "value": {
+ "type":"structure",
+ "members":{
+ "StatusCode": {
+ "shape": "GetObjectResponseStatusCode",
+ "location": "statusCode"
+ },
+ "Location":{
+ "shape":"Location",
+ "documentation":"The URI that identifies the newly created object.
"
+ },
+ "LocationHeader":{
+ "shape":"Location",
+ "documentation":"The URI that identifies the newly created object.
",
+ "location": "header",
+ "locationName": "Location"
+ },
+ "Bucket":{
+ "shape":"BucketName",
+ "documentation":"The name of the bucket that contains the newly created object. Does not return the access point ARN or access point alias if used.
When using this action with an access point, you must direct requests to the access point hostname. The access point hostname takes the form AccessPointName -AccountId .s3-accesspoint.Region .amazonaws.com. When using this action with an access point through the Amazon Web Services SDKs, you provide the access point ARN in place of the bucket name. For more information about access point ARNs, see Using access points in the Amazon S3 User Guide .
When using this action with Amazon S3 on Outposts, you must direct requests to the S3 on Outposts hostname. The S3 on Outposts hostname takes the form AccessPointName -AccountId .outpostID .s3-outposts.Region .amazonaws.com
. When using this action with S3 on Outposts through the Amazon Web Services SDKs, you provide the Outposts bucket ARN in place of the bucket name. For more information about S3 on Outposts ARNs, see Using Amazon S3 on Outposts in the Amazon S3 User Guide .
"
+ },
+ "Key":{
+ "shape":"ObjectKey",
+ "documentation":"The object key of the newly created object.
"
+ },
+ "Expiration": {
+ "shape": "Expiration",
+ "documentation": "If the expiration is configured for the object (see PutBucketLifecycleConfiguration ), the response includes this header. It includes the expiry-date
and rule-id
key-value pairs that provide information about object expiration. The value of the rule-id
is URL-encoded.
",
+ "location": "header",
+ "locationName": "x-amz-expiration"
+ },
+ "ETag":{
+ "shape":"ETag",
+ "documentation":"Entity tag that identifies the newly created object's data. Objects with different object data will have different entity tags. The entity tag is an opaque string. The entity tag may or may not be an MD5 digest of the object data. If the entity tag is not an MD5 digest of the object data, it will contain one or more nonhexadecimal characters and/or will consist of less than 32 or more than 32 hexadecimal digits. For more information about how the entity tag is calculated, see Checking object integrity in the Amazon S3 User Guide .
"
+ },
+ "ETagHeader":{
+ "shape":"ETag",
+ "documentation":"Entity tag that identifies the newly created object's data. Objects with different object data will have different entity tags. The entity tag is an opaque string. The entity tag may or may not be an MD5 digest of the object data. If the entity tag is not an MD5 digest of the object data, it will contain one or more nonhexadecimal characters and/or will consist of less than 32 or more than 32 hexadecimal digits. For more information about how the entity tag is calculated, see Checking object integrity in the Amazon S3 User Guide .
",
+ "location": "header",
+ "locationName": "ETag"
+ },
+ "ChecksumCRC32": {
+ "shape": "ChecksumCRC32",
+ "documentation": "The base64-encoded, 32-bit CRC32 checksum of the object. This will only be present if it was uploaded with the object. With multipart uploads, this may not be a checksum value of the object. For more information about how checksums are calculated with multipart uploads, see Checking object integrity in the Amazon S3 User Guide .
",
+ "location": "header",
+ "locationName": "x-amz-checksum-crc32"
+ },
+ "ChecksumCRC32C": {
+ "shape": "ChecksumCRC32C",
+ "documentation": "The base64-encoded, 32-bit CRC32C checksum of the object. This will only be present if it was uploaded with the object. With multipart uploads, this may not be a checksum value of the object. For more information about how checksums are calculated with multipart uploads, see Checking object integrity in the Amazon S3 User Guide .
",
+ "location": "header",
+ "locationName": "x-amz-checksum-crc32c"
+ },
+ "ChecksumSHA1": {
+ "shape": "ChecksumSHA1",
+ "documentation": "The base64-encoded, 160-bit SHA-1 digest of the object. This will only be present if it was uploaded with the object. With multipart uploads, this may not be a checksum value of the object. For more information about how checksums are calculated with multipart uploads, see Checking object integrity in the Amazon S3 User Guide .
",
+ "location": "header",
+ "locationName": "x-amz-checksum-sha1"
+ },
+ "ChecksumSHA256": {
+ "shape": "ChecksumSHA256",
+ "documentation": "The base64-encoded, 256-bit SHA-256 digest of the object. This will only be present if it was uploaded with the object. With multipart uploads, this may not be a checksum value of the object. For more information about how checksums are calculated with multipart uploads, see Checking object integrity in the Amazon S3 User Guide .
",
+ "location": "header",
+ "locationName": "x-amz-checksum-sha256"
+ },
+ "ServerSideEncryption": {
+ "shape": "ServerSideEncryption",
+ "documentation": "If you specified server-side encryption either with an Amazon Web Services KMS key or Amazon S3-managed encryption key in your PUT request, the response includes this header. It confirms the encryption algorithm that Amazon S3 used to encrypt the object.
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption"
+ },
+ "VersionId": {
+ "shape": "ObjectVersionId",
+ "documentation": "Version of the object.
",
+ "location": "header",
+ "locationName": "x-amz-version-id"
+ },
+ "SSECustomerAlgorithm": {
+ "shape": "SSECustomerAlgorithm",
+ "documentation": "If server-side encryption with a customer-provided encryption key was requested, the response will include this header confirming the encryption algorithm used.
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption-customer-algorithm"
+ },
+ "SSECustomerKeyMD5": {
+ "shape": "SSECustomerKeyMD5",
+ "documentation": "If server-side encryption with a customer-provided encryption key was requested, the response will include this header to provide round-trip message integrity verification of the customer-provided encryption key.
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption-customer-key-MD5"
+ },
+ "SSEKMSKeyId": {
+ "shape": "SSEKMSKeyId",
+ "documentation": "If x-amz-server-side-encryption
is present and has the value of aws:kms
, this header specifies the ID of the Amazon Web Services Key Management Service (Amazon Web Services KMS) symmetric customer managed key that was used for the object.
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption-aws-kms-key-id"
+ },
+ "SSEKMSEncryptionContext": {
+ "shape": "SSEKMSEncryptionContext",
+ "documentation": "If present, specifies the Amazon Web Services KMS Encryption Context to use for object encryption. The value of this header is a base64-encoded UTF-8 string holding JSON with the encryption context key-value pairs.
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption-context"
+ },
+ "BucketKeyEnabled": {
+ "shape": "BucketKeyEnabled",
+ "documentation": "Indicates whether the uploaded object uses an S3 Bucket Key for server-side encryption with Amazon Web Services KMS (SSE-KMS).
",
+ "location": "header",
+ "locationName": "x-amz-server-side-encryption-bucket-key-enabled"
+ },
+ "RequestCharged": {
+ "shape": "RequestCharged",
+ "location": "header",
+ "locationName": "x-amz-request-charged"
+ }
+ }
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchWebsiteConfiguration",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The specified bucket does not have a website configuration
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchUpload/members/UploadId",
+ "value": {
+ "shape": "MultipartUploadId"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchUpload/error",
+ "value": {
+ "httpStatusCode": 404
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ReplicationConfigurationNotFoundError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The replication configuration was not found.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketCannedACL/enum/4",
+ "value": "log-delivery-write",
+ "documentation": "Not included in the specs, but valid value according to the docs: https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BadRequest",
+ "value": {
+ "type": "structure",
+ "members": {
+ "HostId": {
+ "shape": "HostId"
+ }
+ },
+ "documentation": "Insufficient information. Origin request header needed.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/AccessForbidden",
+ "value": {
+ "type": "structure",
+ "members": {
+ "HostId": {
+ "shape": "HostId"
+ },
+ "Method": {
+ "shape": "HttpMethod"
+ },
+ "ResourceType": {
+ "shape": "ResourceType"
+ }
+ },
+ "error": {
+ "httpStatusCode": 403
+ },
+ "documentation": "CORSResponse
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/HttpMethod",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ResourceType",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchCORSConfiguration",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The CORS configuration does not exist
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/MissingSecurityHeader",
+ "value": {
+ "type": "structure",
+ "members": {
+ "MissingHeaderName": {
+ "shape": "MissingHeaderName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "Your request was missing a required header
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/MissingHeaderName",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidPartOrder",
+ "value": {
+ "type": "structure",
+ "members": {
+ "UploadId": {
+ "shape": "MultipartUploadId"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The list of parts was not in ascending order. Parts must be ordered by part number.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidStorageClass",
+ "value": {
+ "type": "structure",
+ "members": {
+ "StorageClassRequested": {
+ "shape": "StorageClass"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The storage class you specified is not valid
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListObjectsOutput/members/BucketRegion",
+ "value": {
+ "shape": "BucketRegion",
+ "location": "header",
+ "locationName": "x-amz-bucket-region"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListObjectsV2Output/members/BucketRegion",
+ "value": {
+ "shape": "BucketRegion",
+ "location": "header",
+ "locationName": "x-amz-bucket-region"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ResourceType",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/MethodNotAllowed",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Method": {
+ "shape": "HttpMethod"
+ },
+ "ResourceType": {
+ "shape": "ResourceType"
+ },
+ "DeleteMarker": {
+ "shape": "DeleteMarker",
+ "location": "header",
+ "locationName": "x-amz-delete-marker"
+ },
+ "VersionId": {
+ "shape": "ObjectVersionId",
+ "location": "header",
+ "locationName": "x-amz-version-id"
+ },
+ "Allow": {
+ "shape": "HttpMethod",
+ "location": "header",
+ "locationName": "allow"
+ }
+ },
+ "error": {
+ "httpStatusCode": 405
+ },
+ "documentation": "The specified method is not allowed against this resource.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "remove",
+ "path": "/shapes/ListBucketsOutput/members/Buckets"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListBucketsOutput/members/Buckets",
+ "value": {
+ "shape":"Buckets",
+ "documentation":"The list of buckets owned by the requester.
"
+ }
+ },
+ {
+ "op": "remove",
+ "path": "/shapes/ListObjectsOutput/members/Contents"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListObjectsOutput/members/Contents",
+ "value": {
+ "shape":"ObjectList",
+ "documentation":"Metadata about each object returned.
"
+ }
+ },
+ {
+ "op": "remove",
+ "path": "/shapes/ListObjectsV2Output/members/Contents"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListObjectsV2Output/members/Contents",
+ "value": {
+ "shape":"ObjectList",
+ "documentation":"Metadata about each object returned.
"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/CrossLocationLoggingProhibitted",
+ "value": {
+ "type": "structure",
+ "members": {
+ "TargetBucketLocation": {
+ "shape": "BucketRegion"
+ },
+ "SourceBucketLocation": {
+ "shape": "BucketRegion"
+ }
+ },
+ "error": {
+ "httpStatusCode": 403
+ },
+ "documentation": "Cross S3 location logging not allowed.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidTargetBucketForLogging",
+ "value": {
+ "type": "structure",
+ "members": {
+ "TargetBucket": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The target bucket for logging does not exist
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/operations/PutBucketInventoryConfiguration/http/responseCode",
+ "value": 204
+ },
+ {
+ "op": "add",
+ "path": "/operations/PutBucketAnalyticsConfiguration/http/responseCode",
+ "value": 204
+ },
+ {
+ "op": "add",
+ "path": "/operations/PutBucketIntelligentTieringConfiguration/http/responseCode",
+ "value": 204
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BucketNotEmpty",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 409
+ },
+ "documentation": "The bucket you tried to delete is not empty
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/MinSizeAllowed",
+ "value": {
+ "type": "long"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ProposedSize",
+ "value": {
+ "type": "long"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/EntityTooSmall",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ETag": {
+ "shape": "ETag"
+ },
+ "MinSizeAllowed": {
+ "shape": "MinSizeAllowed"
+ },
+ "PartNumber": {
+ "shape": "PartNumber"
+ },
+ "ProposedSize": {
+ "shape": "ProposedSize"
+ }
+ },
+ "documentation": "Your proposed upload is smaller than the minimum allowed object size. Each part must be at least 5 MB in size, except the last part.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidPart",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ETag": {
+ "shape": "ETag"
+ },
+ "UploadId": {
+ "shape": "MultipartUploadId"
+ },
+ "PartNumber": {
+ "shape": "PartNumber"
+ }
+ },
+ "documentation": "One or more of the specified parts could not be found. The part might not have been uploaded, or the specified entity tag might not have matched the part's entity tag.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchTagSet",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "There is no tag set associated with the bucket.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/operations/PutBucketTagging/http/responseCode",
+ "value": 204
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidTag",
+ "value": {
+ "type": "structure",
+ "members": {
+ "TagKey": {
+ "shape": "ObjectKey"
+ },
+ "TagValue": {
+ "shape": "Value"
+ }
+ },
+ "documentation": "The tag provided was not a valid tag. This error can occur if the tag did not pass input validation.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ObjectLockConfigurationNotFoundError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "Object Lock configuration does not exist for this bucket
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidPartNumber",
+ "value": {
+ "type": "structure",
+ "members": {
+ "PartNumberRequested": {
+ "shape": "PartNumber"
+ },
+ "ActualPartCount": {
+ "shape": "PartNumber"
+ }
+ },
+ "error": {
+ "httpStatusCode": 416
+ },
+ "documentation": "The requested partnumber is not satisfiable
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/OwnershipControlsNotFoundError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The bucket ownership controls were not found
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchPublicAccessBlockConfiguration",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The public access block configuration was not found
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NoSuchBucketPolicy",
+ "value": {
+ "type": "structure",
+ "members": {
+ "BucketName": {
+ "shape": "BucketName"
+ }
+ },
+ "error": {
+ "httpStatusCode": 404
+ },
+ "documentation": "The bucket policy does not exist
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidObjectState/error",
+ "value": {
+ "httpStatusCode": 403
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidDigest",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Content_MD5": {
+ "shape": "ContentMD5",
+ "locationName":"Content-MD5"
+ }
+ },
+ "documentation": "The Content-MD5 you specified was invalid.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/KeyLength",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/KeyTooLongError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "MaxSizeAllowed": {
+ "shape": "KeyLength"
+ },
+ "Size": {
+ "shape": "KeyLength"
+ }
+ },
+ "documentation": "Your key is too long
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidLocationConstraint",
+ "value": {
+ "type": "structure",
+ "members": {
+ "LocationConstraint": {
+ "shape": "BucketRegion"
+ }
+ },
+ "documentation": "The specified location-constraint is not valid
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/EntityTooLarge",
+ "value": {
+ "type": "structure",
+ "members": {
+ "MaxSizeAllowed": {
+ "shape": "KeyLength"
+ },
+ "HostId": {
+ "shape": "HostId"
+ },
+ "ProposedSize": {
+ "shape": "ProposedSize"
+ }
+ },
+ "documentation": "Your proposed upload exceeds the maximum allowed size
",
+ "exception": true
+ }
+ },
+ {
+ "op": "remove",
+ "path": "/shapes/ListObjectVersionsOutput/members/Versions"
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ListObjectVersionsOutput/members/Versions",
+ "value": {
+ "shape":"ObjectVersionList",
+ "documentation":"Container for version information.
",
+ "locationName":"Version"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/InvalidEncryptionAlgorithmError",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ArgumentName": {
+ "shape": "ArgumentName"
+ },
+ "ArgumentValue": {
+ "shape": "ArgumentValue"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The Encryption request you specified is not valid.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/Header",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/additionalMessage",
+ "value": {
+ "type": "string"
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/NotImplemented",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Header": {
+ "shape": "Header"
+ },
+ "additionalMessage": {
+ "shape": "additionalMessage"
+ }
+ },
+ "error": {
+ "httpStatusCode": 501
+ },
+ "documentation": "A header you provided implies functionality that is not implemented.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/ConditionalRequestConflict",
+ "value": {
+ "type": "structure",
+ "members": {
+ "Condition": {
+ "shape": "IfCondition"
+ },
+ "Key": {
+ "shape": "ObjectKey"
+ }
+ },
+ "error": {
+ "httpStatusCode": 409
+ },
+ "documentation": "The conditional request cannot succeed due to a conflicting operation against this resource.
",
+ "exception": true
+ }
+ },
+ {
+ "op": "add",
+ "path": "/shapes/BadDigest",
+ "value": {
+ "type": "structure",
+ "members": {
+ "ExpectedDigest": {
+ "shape": "ContentMD5"
+ },
+ "CalculatedDigest": {
+ "shape": "ContentMD5"
+ }
+ },
+ "error": {
+ "httpStatusCode": 400
+ },
+ "documentation": "The Content-MD5 you specified did not match what we received.
",
+ "exception": true
+ }
+ }
+ ],
+ "apigatewayv2/2018-11-29/service-2": [
+ {
+ "op": "add",
+ "path": "/operations/UpdateDeployment/http/responseCode",
+ "value": 201
+ },
+ {
+ "op": "add",
+ "path": "/operations/UpdateApi/http/responseCode",
+ "value": 201
+ },
+ {
+ "op": "add",
+ "path": "/operations/UpdateRoute/http/responseCode",
+ "value": 201
+ },
+ {
+ "op": "add",
+ "path": "/operations/CreateApiMapping/http/responseCode",
+ "value": 200
+ }
+ ]
+}
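
The entries above are plain RFC 6902 JSON Patch operations that LocalStack applies to botocore's service definitions as they are loaded (see the PatchingLoader in spec.py below). As a rough sketch of the mechanics, assuming only the jsonpatch package and a toy spec fragment (the real documents are botocore's service-2.json files), an "add" operation extends the spec like this:

    import jsonpatch

    # toy spec fragment; real specs are far larger
    spec = {"shapes": {"BucketName": {"type": "string"}}, "operations": {}}

    # mirrors the NoSuchBucketPolicy entry above
    patch = [
        {
            "op": "add",
            "path": "/shapes/NoSuchBucketPolicy",
            "value": {
                "type": "structure",
                "members": {"BucketName": {"shape": "BucketName"}},
                "error": {"httpStatusCode": 404},
                "exception": True,
            },
        }
    ]

    patched = jsonpatch.apply_patch(spec, patch)
    assert "NoSuchBucketPolicy" in patched["shapes"]
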
diff --git a/localstack-core/localstack/aws/spec.py b/localstack-core/localstack/aws/spec.py
new file mode 100644
index 0000000000000..3c769f8d7f555
--- /dev/null
+++ b/localstack-core/localstack/aws/spec.py
@@ -0,0 +1,307 @@
+import dataclasses
+import json
+import logging
+import os
+from collections import defaultdict
+from functools import cached_property, lru_cache
+from typing import Dict, Generator, List, Literal, NamedTuple, Optional, Tuple
+
+import jsonpatch
+from botocore.exceptions import UnknownServiceError
+from botocore.loaders import Loader, instance_cache
+from botocore.model import OperationModel, ServiceModel
+
+LOG = logging.getLogger(__name__)
+
+ServiceName = str
+ProtocolName = Literal["query", "json", "rest-json", "rest-xml", "ec2"]
+
+
+class ServiceModelIdentifier(NamedTuple):
+ """
+ Identifies a specific service model.
+ If the protocol is not given, the default protocol of the service with the specific name is assumed.
+ Maybe also add versions here in the future (if we can support multiple different versions for one service).
+ """
+
+ name: ServiceName
+ protocol: Optional[ProtocolName] = None
+
+
+spec_patches_json = os.path.join(os.path.dirname(__file__), "spec-patches.json")
+
+
+def load_spec_patches() -> Dict[str, list]:
+ if not os.path.exists(spec_patches_json):
+ return {}
+ with open(spec_patches_json) as fd:
+ return json.load(fd)
+
+
+# Path for custom specs which are not (anymore) provided by botocore
+LOCALSTACK_BUILTIN_DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
+
+
+class LocalStackBuiltInDataLoaderMixin(Loader):
+ def __init__(self, *args, **kwargs):
+ # add the builtin data path to the extra_search_paths to ensure they are discovered by the loader
+ super().__init__(*args, extra_search_paths=[LOCALSTACK_BUILTIN_DATA_PATH], **kwargs)
+
+
+class PatchingLoader(Loader):
+ """
+ A custom botocore Loader that applies JSON patches from the given json patch file to the specs as they are loaded.
+ """
+
+ patches: Dict[str, list]
+
+ def __init__(self, patches: Dict[str, list], *args, **kwargs):
+ # store the patches; they are applied in load_data() as specs are loaded
+ super().__init__(*args, **kwargs)
+ self.patches = patches
+
+ @instance_cache
+ def load_data(self, name: str):
+ result = super(PatchingLoader, self).load_data(name)
+
+ if patches := self.patches.get(name):
+ return jsonpatch.apply_patch(result, patches)
+
+ return result
+
+
+class CustomLoader(PatchingLoader, LocalStackBuiltInDataLoaderMixin):
+ # Class mixing the different loader features (patching, localstack specific data)
+ pass
+
+
+loader = CustomLoader(load_spec_patches())
+
+
+class UnknownServiceProtocolError(UnknownServiceError):
+ """Raised when trying to load a service with an unknown protocol.
+
+ :ivar service_name: The name of the service.
+ :ivar protocol: The name of the unknown protocol.
+ """
+
+ fmt = "Unknown service protocol: '{service_name}-{protocol}'."
+
+
+def list_services() -> List[ServiceModel]:
+ return [load_service(service) for service in loader.list_available_services("service-2")]
+
+
+def load_service(
+ service: ServiceName, version: Optional[str] = None, protocol: Optional[ProtocolName] = None
+) -> ServiceModel:
+ """
+ Loads a service
+ :param service: to load, f.e. "sqs". For custom, internalized, service protocol specs (f.e. sqs-query) it's also
+ possible to directly define the protocol in the service name (f.e. use sqs-query)
+ :param version: of the service to load, f.e. "2012-11-05", by default the latest version will be used
+ :param protocol: specific protocol to load for the specific service, f.e. "json" for the "sqs" service
+ :return: Loaded service model of the service
+ :raises: UnknownServiceError if the service cannot be found
+ :raises: UnknownServiceProtocolError if the specific protocol of the service cannot be found
+ """
+ service_description = loader.load_service_model(service, "service-2", version)
+
+ # check if the protocol is defined, and if so, if the loaded service defines this protocol
+ if protocol is not None and protocol != service_description.get("metadata", {}).get("protocol"):
+ # if the protocol is defined, but not the one of the currently loaded service,
+ # check if we already loaded the custom spec based on the naming convention (<service>-<protocol>),
+ # f.e. "sqs-query"
+ if service.endswith(f"-{protocol}"):
+ # if so, we raise an exception
+ raise UnknownServiceProtocolError(service_name=service, protocol=protocol)
+ # otherwise we try to load it (recursively)
+ try:
+ return load_service(f"{service}-{protocol}", version, protocol=protocol)
+ except UnknownServiceError:
+ # raise an unknown protocol error in case the service also can't be loaded with the naming convention
+ raise UnknownServiceProtocolError(service_name=service, protocol=protocol)
+
+ # remove potential protocol names from the service name
+ # FIXME add more protocols here if we have to internalize more than just sqs-query
+ # TODO this should not contain specific internalized service names
+ service = {"sqs-query": "sqs"}.get(service, service)
+ return ServiceModel(service_description, service)
+
+
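
To illustrate the protocol fallback implemented in load_service above (an editorial sketch, not part of the patch; it assumes the internalized sqs-query spec is available on the builtin data path):

    from localstack.aws.spec import load_service

    # the default "sqs" spec uses the "json" protocol; asking for "query" falls back
    # to the internalized "sqs-query" spec and normalizes the name back to "sqs"
    model = load_service("sqs", protocol="query")
    assert model.service_name == "sqs"
    assert model.protocol == "query"
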
+def iterate_service_operations() -> Generator[Tuple[ServiceModel, OperationModel], None, None]:
+ """
+ Returns one record per operation in the AWS service spec, where the first item is the service model the operation
+ belongs to, and the second is the operation model.
+
+ :return: an iterable
+ """
+ for service in list_services():
+ for op_name in service.operation_names:
+ yield service, service.operation_model(op_name)
+
+
+@dataclasses.dataclass
+class ServiceCatalogIndex:
+ """
+ The ServiceCatalogIndex enables fast lookups for common operations to determine a service from service indicators.
+ """
+
+ service_names: List[ServiceName]
+ target_prefix_index: Dict[str, List[ServiceModelIdentifier]]
+ signing_name_index: Dict[str, List[ServiceModelIdentifier]]
+ operations_index: Dict[str, List[ServiceModelIdentifier]]
+ endpoint_prefix_index: Dict[str, List[ServiceModelIdentifier]]
+
+
+class LazyServiceCatalogIndex:
+ """
+ A ServiceCatalogIndex that builds indexes in-memory from the spec.
+ """
+
+ @cached_property
+ def service_names(self) -> List[ServiceName]:
+ return list(self._services.keys())
+
+ @cached_property
+ def target_prefix_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ result = defaultdict(list)
+ for service_models in self._services.values():
+ for service_model in service_models:
+ target_prefix = service_model.metadata.get("targetPrefix")
+ if target_prefix:
+ result[target_prefix].append(
+ ServiceModelIdentifier(service_model.service_name, service_model.protocol)
+ )
+ return dict(result)
+
+ @cached_property
+ def signing_name_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ result = defaultdict(list)
+ for service_models in self._services.values():
+ for service_model in service_models:
+ result[service_model.signing_name].append(
+ ServiceModelIdentifier(service_model.service_name, service_model.protocol)
+ )
+ return dict(result)
+
+ @cached_property
+ def operations_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ result = defaultdict(list)
+ for service_models in self._services.values():
+ for service_model in service_models:
+ operations = service_model.operation_names
+ if operations:
+ for operation in operations:
+ result[operation].append(
+ ServiceModelIdentifier(
+ service_model.service_name, service_model.protocol
+ )
+ )
+ return dict(result)
+
+ @cached_property
+ def endpoint_prefix_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ result = defaultdict(list)
+ for service_models in self._services.values():
+ for service_model in service_models:
+ result[service_model.endpoint_prefix].append(
+ ServiceModelIdentifier(service_model.service_name, service_model.protocol)
+ )
+ return dict(result)
+
+ @cached_property
+ def _services(self) -> Dict[ServiceName, List[ServiceModel]]:
+ services = defaultdict(list)
+ for service in list_services():
+ services[service.service_name].append(service)
+ return services
+
+
+class ServiceCatalog:
+ index: ServiceCatalogIndex
+
+ def __init__(self, index: ServiceCatalogIndex = None):
+ self.index = index or LazyServiceCatalogIndex()
+
+ @lru_cache(maxsize=512)
+ def get(
+ self, name: ServiceName, protocol: Optional[ProtocolName] = None
+ ) -> Optional[ServiceModel]:
+ return load_service(name, protocol=protocol)
+
+ @property
+ def service_names(self) -> List[ServiceName]:
+ return self.index.service_names
+
+ @property
+ def target_prefix_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ return self.index.target_prefix_index
+
+ @property
+ def signing_name_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ return self.index.signing_name_index
+
+ @property
+ def operations_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ return self.index.operations_index
+
+ @property
+ def endpoint_prefix_index(self) -> Dict[str, List[ServiceModelIdentifier]]:
+ return self.index.endpoint_prefix_index
+
+ def by_target_prefix(self, target_prefix: str) -> List[ServiceModelIdentifier]:
+ return self.target_prefix_index.get(target_prefix, [])
+
+ def by_signing_name(self, signing_name: str) -> List[ServiceModelIdentifier]:
+ return self.signing_name_index.get(signing_name, [])
+
+ def by_operation(self, operation_name: str) -> List[ServiceModelIdentifier]:
+ return self.operations_index.get(operation_name, [])
+
+
+def build_service_index_cache(file_path: str) -> ServiceCatalogIndex:
+ """
+ Creates a new ServiceCatalogIndex and stores it into the given file_path.
+
+ :param file_path: the path to pickle to
+ :return: the created ServiceCatalogIndex
+ """
+ return save_service_index_cache(LazyServiceCatalogIndex(), file_path)
+
+
+def load_service_index_cache(file: str) -> ServiceCatalogIndex:
+ """
+ Loads from the given file the pickled ServiceCatalogIndex.
+
+ :param file: the file to load from
+ :return: the loaded ServiceCatalogIndex
+ """
+ import pickle
+
+ with open(file, "rb") as fd:
+ return pickle.load(fd)
+
+
+def save_service_index_cache(index: LazyServiceCatalogIndex, file_path: str) -> ServiceCatalogIndex:
+ """
+ Creates a ``ServiceCatalogIndex`` from the given LazyServiceCatalogIndex, pickles its contents into the given file,
+ and then returns the newly created index.
+
+ :param index: the LazyServiceCatalogIndex to store the index from.
+ :param file_path: the path to pickle to
+ :return: the created ServiceCatalogIndex
+ """
+ import pickle
+
+ cache = ServiceCatalogIndex(
+ service_names=index.service_names,
+ endpoint_prefix_index=index.endpoint_prefix_index,
+ operations_index=index.operations_index,
+ signing_name_index=index.signing_name_index,
+ target_prefix_index=index.target_prefix_index,
+ )
+ with open(file_path, "wb") as fd:
+ pickle.dump(cache, fd)
+ return cache
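
Taken together, a typical lifecycle of the catalog and its pickled index looks roughly like this (an illustrative sketch; the cache path is made up, and "DescribeEndpoints" is just an example of an ambiguous operation name):

    from localstack.aws.spec import (
        ServiceCatalog,
        build_service_index_cache,
        load_service_index_cache,
    )

    cache_file = "/tmp/service-catalog.pickle"  # illustrative path

    # first run: scan all specs and persist the index
    build_service_index_cache(cache_file)

    # later runs: reuse the pre-computed index instead of re-scanning every spec
    catalog = ServiceCatalog(load_service_index_cache(cache_file))

    # lookups return lightweight ServiceModelIdentifier tuples ...
    candidates = catalog.by_operation("DescribeEndpoints")
    # ... while the full botocore model is loaded (and lru_cached) on demand
    s3_model = catalog.get("s3")
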
diff --git a/localstack-core/localstack/cli/__init__.py b/localstack-core/localstack/cli/__init__.py
new file mode 100644
index 0000000000000..fb0407e19e65e
--- /dev/null
+++ b/localstack-core/localstack/cli/__init__.py
@@ -0,0 +1,10 @@
+from .console import console
+from .plugin import LocalstackCli, LocalstackCliPlugin
+
+name = "cli"
+
+__all__ = [
+ "console",
+ "LocalstackCli",
+ "LocalstackCliPlugin",
+]
diff --git a/localstack-core/localstack/cli/console.py b/localstack-core/localstack/cli/console.py
new file mode 100644
index 0000000000000..24bda10813744
--- /dev/null
+++ b/localstack-core/localstack/cli/console.py
@@ -0,0 +1,11 @@
+from rich.console import Console
+
+BANNER = r"""
+     __                     _______ __             __
+    / /   ____  _________ _/ / ___// /_____ ______/ /__
+   / /   / __ \/ ___/ __ `/ /\__ \/ __/ __ `/ ___/ //_/
+  / /___/ /_/ / /__/ /_/ / /___/ / /_/ /_/ / /__/ ,<
+ /_____/\____/\___/\__,_/_//____/\__/\__,_/\___/_/|_|
+"""
+
+console = Console()
diff --git a/localstack-core/localstack/cli/exceptions.py b/localstack-core/localstack/cli/exceptions.py
new file mode 100644
index 0000000000000..cd65d2ee13d26
--- /dev/null
+++ b/localstack-core/localstack/cli/exceptions.py
@@ -0,0 +1,19 @@
+import typing as t
+from gettext import gettext
+
+import click
+from click import ClickException, echo
+from click._compat import get_text_stderr
+
+
+class CLIError(ClickException):
+ """A ClickException with a red error message"""
+
+ def format_message(self) -> str:
+ return click.style(f"❌ Error: {self.message}", fg="red")
+
+ def show(self, file: t.Optional[t.IO[t.Any]] = None) -> None:
+ if file is None:
+ file = get_text_stderr()
+
+ echo(gettext(self.format_message()), file=file)
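
A minimal sketch of how CLIError is meant to surface (illustrative; the command name and message are made up): because it subclasses ClickException, click's standalone mode catches it, calls show() to print the styled message to stderr, and exits with a non-zero code instead of a traceback.

    import click

    from localstack.cli.exceptions import CLIError


    @click.command()
    def fail():
        # any error raised as CLIError is rendered in red on stderr
        raise CLIError("something went wrong")


    if __name__ == "__main__":
        fail()
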
diff --git a/localstack-core/localstack/cli/localstack.py b/localstack-core/localstack/cli/localstack.py
new file mode 100644
index 0000000000000..9abbf4e53775a
--- /dev/null
+++ b/localstack-core/localstack/cli/localstack.py
@@ -0,0 +1,926 @@
+import json
+import logging
+import os
+import sys
+import traceback
+from typing import Dict, List, Optional, Tuple, TypedDict
+
+import click
+import requests
+
+from localstack import config
+from localstack.cli.exceptions import CLIError
+from localstack.constants import VERSION
+from localstack.utils.analytics.cli import publish_invocation
+from localstack.utils.bootstrap import get_container_default_logfile_location
+from localstack.utils.json import CustomEncoder
+
+from .console import BANNER, console
+from .plugin import LocalstackCli, load_cli_plugins
+
+
+class LocalStackCliGroup(click.Group):
+ """
+ A Click group used for the top-level ``localstack`` command group. It implements global exception handling
+ by:
+
+ - Ignoring click exceptions (already handled)
+ - Handling common exceptions (like DockerNotAvailable)
+ - Wrapping all unexpected exceptions in a ClickException (for a unified error message)
+
+ It also implements a custom help formatter to build more fine-grained groups.
+ """
+
+ # FIXME: find a way to communicate this from the actual command
+ advanced_commands = [
+ "aws",
+ "dns",
+ "extensions",
+ "license",
+ "login",
+ "logout",
+ "pod",
+ "state",
+ "ephemeral",
+ "replicator",
+ ]
+
+ def invoke(self, ctx: click.Context):
+ try:
+ return super(LocalStackCliGroup, self).invoke(ctx)
+ except click.exceptions.Exit:
+ # raise Exit exceptions unmodified (e.g., raised on --help)
+ raise
+ except click.ClickException:
+ # don't handle ClickExceptions, just reraise
+ if ctx and ctx.params.get("debug"):
+ click.echo(traceback.format_exc())
+ raise
+ except Exception as e:
+ if ctx and ctx.params.get("debug"):
+ click.echo(traceback.format_exc())
+ from localstack.utils.container_utils.container_client import (
+ ContainerException,
+ DockerNotAvailable,
+ )
+
+ if isinstance(e, DockerNotAvailable):
+ raise CLIError(
+ "Docker could not be found on the system.\n"
+ "Please make sure that you have a working docker environment on your machine."
+ )
+ elif isinstance(e, ContainerException):
+ raise CLIError(e.message)
+ else:
+ # If we have a generic exception, we wrap it in a ClickException
+ raise CLIError(str(e)) from e
+
+ def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
+ """Extra format methods for multi methods that adds all the commands after the options. It also
+ groups commands into command categories."""
+ categories = {"Commands": [], "Advanced": [], "Deprecated": []}
+
+ commands = []
+ for subcommand in self.list_commands(ctx):
+ cmd = self.get_command(ctx, subcommand)
+ # What is this, the tool lied about a command. Ignore it
+ if cmd is None:
+ continue
+ if cmd.hidden:
+ continue
+
+ commands.append((subcommand, cmd))
+
+ # allow for 3 times the default spacing
+ if len(commands):
+ limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands)
+
+ for subcommand, cmd in commands:
+ help = cmd.get_short_help_str(limit)
+ categories[self._get_category(cmd)].append((subcommand, help))
+
+ for category, rows in categories.items():
+ if rows:
+ with formatter.section(category):
+ formatter.write_dl(rows)
+
+ def _get_category(self, cmd) -> str:
+ if cmd.deprecated:
+ return "Deprecated"
+
+ if cmd.name in self.advanced_commands:
+ return "Advanced"
+
+ return "Commands"
+
+
+def create_with_plugins() -> LocalstackCli:
+ """
+ Creates a LocalstackCli instance with all cli plugins loaded.
+ :return: a LocalstackCli instance
+ """
+ cli = LocalstackCli()
+ cli.group = localstack
+ load_cli_plugins(cli)
+ return cli
+
+
+def _setup_cli_debug() -> None:
+ from localstack.logging.setup import setup_logging_for_cli
+
+ config.DEBUG = True
+ os.environ["DEBUG"] = "1"
+
+ setup_logging_for_cli(logging.DEBUG if config.DEBUG else logging.INFO)
+
+
+# Re-usable format option decorator which can be used across multiple commands
+_click_format_option = click.option(
+ "-f",
+ "--format",
+ "format_",
+ type=click.Choice(["table", "plain", "dict", "json"]),
+ default="table",
+ help="The formatting style for the command output.",
+)
+
+
+@click.group(
+ name="localstack",
+ help="The LocalStack Command Line Interface (CLI)",
+ cls=LocalStackCliGroup,
+ context_settings={
+ # add "-h" as a synonym for "--help"
+ # https://click.palletsprojects.com/en/8.1.x/documentation/#help-parameter-customization
+ "help_option_names": ["-h", "--help"],
+ # show default values for options by default - https://github.com/pallets/click/pull/1225
+ "show_default": True,
+ },
+)
+@click.version_option(
+ VERSION,
+ "--version",
+ "-v",
+ message="LocalStack CLI %(version)s",
+ help="Show the version of the LocalStack CLI and exit",
+)
+@click.option("-d", "--debug", is_flag=True, help="Enable CLI debugging mode")
+@click.option("-p", "--profile", type=str, help="Set the configuration profile")
+def localstack(debug, profile) -> None:
+ # --profile is read manually in localstack.cli.main because it needs to be read before localstack.config is read
+
+ if debug:
+ _setup_cli_debug()
+
+ from localstack.utils.files import cache_dir
+
+ # overwrite the config variable here to defer import of cache_dir
+ if not os.environ.get("LOCALSTACK_VOLUME_DIR", "").strip():
+ config.VOLUME_DIR = str(cache_dir() / "volume")
+
+ # FIXME: at some point we should remove the use of `config.dirs` for the CLI,
+ # see https://github.com/localstack/localstack/pull/7906
+ config.dirs.for_cli().mkdirs()
+
+
+@localstack.group(
+ name="config",
+ short_help="Manage your LocalStack config",
+)
+def localstack_config() -> None:
+ """
+ Inspect and validate your LocalStack configuration.
+ """
+ pass
+
+
+@localstack_config.command(name="show", short_help="Show your config")
+@_click_format_option
+@publish_invocation
+def cmd_config_show(format_: str) -> None:
+ """
+ Print the current LocalStack config values.
+
+ This command prints the LocalStack configuration values from your environment.
+ It analyzes the environment variables as well as the LocalStack CLI profile.
+ It does _not_ analyze a specific file (like a docker-compose-yml).
+ """
+ # TODO: parse values from potential docker-compose file?
+ assert config
+
+ try:
+ # only load the ext config if it's available
+ from localstack.pro.core import config as ext_config
+
+ assert ext_config
+ except ImportError:
+ # the ext package is not available
+ pass
+
+ if format_ == "table":
+ _print_config_table()
+ elif format_ == "plain":
+ _print_config_pairs()
+ elif format_ == "dict":
+ _print_config_dict()
+ elif format_ == "json":
+ _print_config_json()
+ else:
+ _print_config_pairs() # fall back to plain
+
+
+@localstack_config.command(name="validate", short_help="Validate your config")
+@click.option(
+ "-f",
+ "--file",
+ help="Path to compose file",
+ default="docker-compose.yml",
+ type=click.Path(exists=True, file_okay=True, readable=True),
+)
+@publish_invocation
+def cmd_config_validate(file: str) -> None:
+ """
+ Validate your LocalStack configuration (docker compose).
+
+ This command inspects the given docker-compose file (by default docker-compose.yml in the current working
+ directory) and validates if the configuration is valid.
+
+ \b
+ It will show an error and return a non-zero exit code if:
+ - The docker-compose file is syntactically incorrect.
+ - If the file contains common issues when configuring LocalStack.
+ """
+
+ from localstack.utils import bootstrap
+
+ if bootstrap.validate_localstack_config(file):
+ console.print("[green]:heavy_check_mark:[/green] config valid")
+ sys.exit(0)
+ else:
+ console.print("[red]:heavy_multiplication_x:[/red] validation error")
+ sys.exit(1)
+
+
+def _print_config_json() -> None:
+ import json
+
+ console.print(json.dumps(dict(config.collect_config_items()), cls=CustomEncoder))
+
+
+def _print_config_pairs() -> None:
+ for key, value in config.collect_config_items():
+ console.print(f"{key}={value}")
+
+
+def _print_config_dict() -> None:
+ console.print(dict(config.collect_config_items()))
+
+
+def _print_config_table() -> None:
+ from rich.table import Table
+
+ grid = Table(show_header=True)
+ grid.add_column("Key")
+ grid.add_column("Value")
+
+ for key, value in config.collect_config_items():
+ grid.add_row(key, str(value))
+
+ console.print(grid)
+
+
+@localstack.group(
+ name="status",
+ short_help="Query status info",
+ invoke_without_command=True,
+)
+@click.pass_context
+def localstack_status(ctx: click.Context) -> None:
+ """
+ Query status information about the currently running LocalStack instance.
+ """
+ if ctx.invoked_subcommand is None:
+ ctx.invoke(localstack_status.get_command(ctx, "docker"))
+
+
+@localstack_status.command(name="docker", short_help="Query LocalStack Docker status")
+@_click_format_option
+def cmd_status_docker(format_: str) -> None:
+ """
+ Query information about the currently running LocalStack Docker image, its container,
+ and the LocalStack runtime.
+ """
+ with console.status("Querying Docker status"):
+ _print_docker_status(format_)
+
+
+class DockerStatus(TypedDict, total=False):
+ running: bool
+ runtime_version: str
+ image_tag: str
+ image_id: str
+ image_created: str
+ container_name: Optional[str]
+ container_ip: Optional[str]
+
+
+def _print_docker_status(format_: str) -> None:
+ from localstack.utils import docker_utils
+ from localstack.utils.bootstrap import get_docker_image_details, get_server_version
+ from localstack.utils.container_networking import get_main_container_ip, get_main_container_name
+
+ img = get_docker_image_details()
+ cont_name = config.MAIN_CONTAINER_NAME
+ running = docker_utils.DOCKER_CLIENT.is_container_running(cont_name)
+ status = DockerStatus(
+ runtime_version=get_server_version(),
+ image_tag=img["tag"],
+ image_id=img["id"],
+ image_created=img["created"],
+ running=running,
+ )
+ if running:
+ status["container_name"] = get_main_container_name()
+ status["container_ip"] = get_main_container_ip()
+
+ if format_ == "dict":
+ console.print(status)
+ if format_ == "table":
+ _print_docker_status_table(status)
+ if format_ == "json":
+ console.print(json.dumps(status))
+ if format_ == "plain":
+ for key, value in status.items():
+ console.print(f"{key}={value}")
+
+
+def _print_docker_status_table(status: DockerStatus) -> None:
+ from rich.table import Table
+
+ grid = Table(show_header=False)
+ grid.add_column()
+ grid.add_column()
+
+ grid.add_row("Runtime version", f"[bold]{status['runtime_version']}[/bold]")
+ grid.add_row(
+ "Docker image",
+ f"tag: {status['image_tag']}, "
+ f"id: {status['image_id']}, "
+ f":calendar: {status['image_created']}",
+ )
+ cont_status = "[bold][red]:heavy_multiplication_x: stopped"
+ if status["running"]:
+ cont_status = (
+ f"[bold][green]:heavy_check_mark: running[/green][/bold] "
+ f'(name: "[italic]{status["container_name"]}[/italic]", IP: {status["container_ip"]})'
+ )
+ grid.add_row("Runtime status", cont_status)
+ console.print(grid)
+
+
+@localstack_status.command(name="services", short_help="Query LocalStack services status")
+@_click_format_option
+def cmd_status_services(format_: str) -> None:
+ """
+ Query information about the services of the currently running LocalStack instance.
+ """
+ url = config.external_service_url()
+
+ try:
+ health = requests.get(f"{url}/_localstack/health", timeout=2)
+ doc = health.json()
+ services = doc.get("services", {})
+ if format_ == "table":
+ _print_service_table(services)
+ if format_ == "plain":
+ for service, status in services.items():
+ console.print(f"{service}={status}")
+ if format_ == "dict":
+ console.print(services)
+ if format_ == "json":
+ console.print(json.dumps(services))
+ except requests.ConnectionError:
+ if config.DEBUG:
+ console.print_exception()
+ raise CLIError(f"could not connect to LocalStack health endpoint at {url}")
+
+
+def _print_service_table(services: Dict[str, str]) -> None:
+ from rich.table import Table
+
+ status_display = {
+ "running": "[green]:heavy_check_mark:[/green] running",
+ "starting": ":hourglass_flowing_sand: starting",
+ "available": "[grey]:heavy_check_mark:[/grey] available",
+ "error": "[red]:heavy_multiplication_x:[/red] error",
+ }
+
+ table = Table()
+ table.add_column("Service")
+ table.add_column("Status")
+
+ services = list(services.items())
+ services.sort(key=lambda item: item[0])
+
+ for service, status in services:
+ if status in status_display:
+ status = status_display[status]
+
+ table.add_row(service, status)
+
+ console.print(table)
+
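
For reference, a sketch of the health payload that cmd_status_services consumes (illustrative; the URL uses the default edge port and the service states are examples, while the real URL comes from config.external_service_url()):

    import requests

    resp = requests.get("http://localhost:4566/_localstack/health", timeout=2)
    services = resp.json().get("services", {})
    # e.g. {"s3": "available", "sqs": "running", ...}
    for name, state in sorted(services.items()):
        print(f"{name}={state}")
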
+
+@localstack.command(name="start", short_help="Start LocalStack")
+@click.option("--docker", is_flag=True, help="Start LocalStack in a docker container [default]")
+@click.option("--host", is_flag=True, help="Start LocalStack directly on the host")
+@click.option("--no-banner", is_flag=True, help="Disable LocalStack banner", default=False)
+@click.option(
+ "-d", "--detached", is_flag=True, help="Start LocalStack in the background", default=False
+)
+@click.option(
+ "--network",
+ type=str,
+ help="The container network the LocalStack container should be started in. By default, the default docker bridge network is used.",
+ required=False,
+)
+@click.option(
+ "--env",
+ "-e",
+ help="Additional environment variables that are passed to the LocalStack container",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--publish",
+ "-p",
+ help="Additional port mappings that are passed to the LocalStack container",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--volume",
+ "-v",
+ help="Additional volume mounts that are passed to the LocalStack container",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--host-dns",
+ help="Expose the LocalStack DNS server to the host using port bindings.",
+ required=False,
+ is_flag=True,
+ default=False,
+)
+@publish_invocation
+def cmd_start(
+ docker: bool,
+ host: bool,
+ no_banner: bool,
+ detached: bool,
+ network: str = None,
+ env: Tuple = (),
+ publish: Tuple = (),
+ volume: Tuple = (),
+ host_dns: bool = False,
+) -> None:
+ """
+ Start the LocalStack runtime.
+
+ This command starts the LocalStack runtime with your current configuration.
+ By default, it will start a new Docker container from the latest LocalStack(-Pro) Docker image
+ with best-practice volume mounts and port mappings.
+ """
+ if docker and host:
+ raise CLIError("Please specify either --docker or --host")
+ if host and detached:
+ raise CLIError("Cannot start detached in host mode")
+
+ if not no_banner:
+ print_banner()
+ print_version()
+ print_profile()
+ print_app()
+ console.line()
+
+ from localstack.utils import bootstrap
+
+ if not no_banner:
+ if host:
+ console.log("starting LocalStack in host mode :laptop_computer:")
+ else:
+ console.log("starting LocalStack in Docker mode :whale:")
+
+ if host:
+ # call hooks to prepare host
+ bootstrap.prepare_host(console)
+
+ # from here we abandon the regular CLI control path and start treating the process like a localstack
+ # runtime process
+ os.environ["LOCALSTACK_CLI"] = "0"
+ config.dirs = config.init_directories()
+
+ try:
+ bootstrap.start_infra_locally()
+ except ImportError:
+ if config.DEBUG:
+ console.print_exception()
+ raise CLIError(
+ "It appears you have a light install of localstack which only supports running in docker.\n"
+ "If you would like to use --host, please install localstack with Python using "
+ "`pip install localstack[runtime]` instead."
+ )
+ else:
+ # make sure to initialize the bootstrap environment and directories for the host (even if we're executing
+ # in Docker), to allow starting the container from within other containers (e.g., Github Codespaces).
+ config.OVERRIDE_IN_DOCKER = False
+ config.is_in_docker = False
+ config.dirs = config.init_directories()
+
+ # call hooks to prepare host (note that this call should stay below the config overrides above)
+ bootstrap.prepare_host(console)
+
+ # pass the parsed cli params to the start infra command
+ params = click.get_current_context().params
+
+ if network:
+ # reconciles the network config and makes sure that MAIN_DOCKER_NETWORK is set automatically if
+ # `--network` is set.
+ if config.MAIN_DOCKER_NETWORK:
+ if config.MAIN_DOCKER_NETWORK != network:
+ raise CLIError(
+ f"Values of MAIN_DOCKER_NETWORK={config.MAIN_DOCKER_NETWORK} and --network={network} "
+ f"do not match"
+ )
+ else:
+ config.MAIN_DOCKER_NETWORK = network
+ os.environ["MAIN_DOCKER_NETWORK"] = network
+
+ if detached:
+ bootstrap.start_infra_in_docker_detached(console, params)
+ else:
+ bootstrap.start_infra_in_docker(console, params)
+
+
+@localstack.command(name="stop", short_help="Stop LocalStack")
+@publish_invocation
+def cmd_stop() -> None:
+ """
+ Stops the current LocalStack runtime.
+
+ This command stops the currently running LocalStack docker container.
+ By default, this command looks for a container named `localstack-main` (which is the default
+ container name used by the `localstack start` command).
+ If your LocalStack container has a different name, set the config variable
+ `MAIN_CONTAINER_NAME`.
+ """
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ from ..utils.container_utils.container_client import NoSuchContainer
+
+ container_name = config.MAIN_CONTAINER_NAME
+
+ try:
+ DOCKER_CLIENT.stop_container(container_name)
+ console.print("container stopped: %s" % container_name)
+ except NoSuchContainer:
+ raise CLIError(
+ f'Expected a running LocalStack container named "{container_name}", but found none'
+ )
+
+
+@localstack.command(name="restart", short_help="Restart LocalStack")
+@publish_invocation
+def cmd_restart() -> None:
+ """
+ Restarts the current LocalStack runtime.
+ """
+ url = config.external_service_url()
+
+ try:
+ response = requests.post(
+ f"{url}/_localstack/health",
+ json={"action": "restart"},
+ )
+ response.raise_for_status()
+ console.print("LocalStack restarted within the container.")
+ except requests.ConnectionError:
+ if config.DEBUG:
+ console.print_exception()
+ raise CLIError("could not restart the LocalStack container")
+
+
+@localstack.command(
+ name="logs",
+ short_help="Show LocalStack logs",
+)
+@click.option(
+ "-f",
+ "--follow",
+ is_flag=True,
+ help="Block the terminal and follow the log output",
+ default=False,
+)
+@click.option(
+ "-n",
+ "--tail",
+ type=int,
+ help="Print only the last lines of the log output",
+ default=None,
+ metavar="N",
+)
+@publish_invocation
+def cmd_logs(follow: bool, tail: int) -> None:
+ """
+ Show the logs of the current LocalStack runtime.
+
+ This command shows the logs of the currently running LocalStack docker container.
+ By default, this command looks for a container named `localstack-main` (which is the default
+ container name used by the `localstack start` command).
+ If your LocalStack container has a different name, set the config variable
+ `MAIN_CONTAINER_NAME`.
+ """
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ container_name = config.MAIN_CONTAINER_NAME
+ logfile = get_container_default_logfile_location(container_name)
+
+ if not DOCKER_CLIENT.is_container_running(container_name):
+ console.print("localstack container not running")
+ if os.path.exists(logfile):
+ console.print("printing logs from previous run")
+ with open(logfile) as fd:
+ for line in fd:
+ click.echo(line, nl=False)
+ sys.exit(1)
+
+ if follow:
+ num_lines = 0
+ for line in DOCKER_CLIENT.stream_container_logs(container_name):
+ print(line.decode("utf-8").rstrip("\r\n"))
+ num_lines += 1
+ if tail is not None and num_lines >= tail:
+ break
+
+ else:
+ logs = DOCKER_CLIENT.get_container_logs(container_name)
+ if tail is not None:
+ logs = "\n".join(logs.split("\n")[-tail:])
+ print(logs)
+
+
+@localstack.command(name="wait", short_help="Wait for LocalStack")
+@click.option(
+ "-t",
+ "--timeout",
+ type=float,
+ help="Only wait for seconds before raising a timeout error",
+ default=None,
+ metavar="N",
+)
+@publish_invocation
+def cmd_wait(timeout: Optional[float] = None) -> None:
+ """
+ Wait for the LocalStack runtime to be up and running.
+
+ This command waits for a started LocalStack runtime to be up and running, ready to serve
+ requests.
+ By default, this command looks for a container named `localstack-main` (which is the default
+ container name used by the `localstack start` command).
+ If your LocalStack container has a different name, set the config variable
+ `MAIN_CONTAINER_NAME`.
+ """
+ from localstack.utils.bootstrap import wait_container_is_ready
+
+ if not wait_container_is_ready(timeout=timeout):
+ raise CLIError("timeout")
+
+
+@localstack.command(name="ssh", short_help="Obtain a shell in LocalStack")
+@publish_invocation
+def cmd_ssh() -> None:
+ """
+ Obtain a shell in the current LocalStack runtime.
+
+ This command starts a new interactive shell in the currently running LocalStack container.
+ By default, this command looks for a container named `localstack-main` (which is the default
+ container name used by the `localstack start` command).
+ If your LocalStack container has a different name, set the config variable
+ `MAIN_CONTAINER_NAME`.
+ """
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ if not DOCKER_CLIENT.is_container_running(config.MAIN_CONTAINER_NAME):
+ raise CLIError(
+ f'Expected a running LocalStack container named "{config.MAIN_CONTAINER_NAME}", but found none'
+ )
+ os.execlp("docker", "docker", "exec", "-it", config.MAIN_CONTAINER_NAME, "bash")
+
+
+@localstack.group(name="update", short_help="Update LocalStack")
+def localstack_update() -> None:
+ """
+ Update different LocalStack components.
+ """
+ pass
+
+
+@localstack_update.command(name="all", short_help="Update all LocalStack components")
+@click.pass_context
+@publish_invocation
+def cmd_update_all(ctx: click.Context) -> None:
+ """
+ Update all LocalStack components.
+
+ This is the same as executing `localstack update localstack-cli` and
+ `localstack update docker-images`.
+ Updating the LocalStack CLI is currently only supported if the CLI
+ is installed and run via Python / PIP. If you used a different installation method,
+ please follow the instructions on https://docs.localstack.cloud/.
+ """
+ ctx.invoke(localstack_update.get_command(ctx, "localstack-cli"))
+ ctx.invoke(localstack_update.get_command(ctx, "docker-images"))
+
+
+@localstack_update.command(name="localstack-cli", short_help="Update LocalStack CLI")
+@publish_invocation
+def cmd_update_localstack_cli() -> None:
+ """
+ Update the LocalStack CLI.
+
+ This command updates the LocalStack CLI. This is currently only supported if the CLI
+ is installed and run via Python / PIP. If you used a different installation method,
+ please follow the instructions on https://docs.localstack.cloud/.
+ """
+ if is_frozen_bundle():
+ # "update" can only be performed if running from source / in a non-frozen interpreter
+ raise CLIError(
+ "The LocalStack CLI can only update itself if installed via PIP. "
+ "Please follow the instructions on https://docs.localstack.cloud/ to update your CLI."
+ )
+
+ import subprocess
+ from subprocess import CalledProcessError
+
+ console.rule("Updating LocalStack CLI")
+ with console.status("Updating LocalStack CLI..."):
+ try:
+ subprocess.check_output(
+ [sys.executable, "-m", "pip", "install", "--upgrade", "localstack"]
+ )
+ console.print(":heavy_check_mark: LocalStack CLI updated")
+ except CalledProcessError:
+ console.print(":heavy_multiplication_x: LocalStack CLI update failed", style="bold red")
+
+
+@localstack_update.command(
+ name="docker-images", short_help="Update docker images LocalStack depends on"
+)
+@publish_invocation
+def cmd_update_docker_images() -> None:
+ """
+ Update all Docker images LocalStack depends on.
+
+ This command updates all LocalStack Docker images, as well as other Docker images
+ LocalStack depends on (and which have been used before / are present on the machine).
+ """
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ console.rule("Updating docker images")
+
+ all_images = DOCKER_CLIENT.get_docker_image_names(strip_latest=False)
+ image_prefixes = [
+ "localstack/",
+ "public.ecr.aws/lambda",
+ ]
+ localstack_images = [
+ image
+ for image in all_images
+ if any(
+ image.startswith(image_prefix) or image.startswith(f"docker.io/{image_prefix}")
+ for image_prefix in image_prefixes
+ )
+ and not image.endswith(":") # ignore dangling images
+ ]
+ update_images(localstack_images)
+
+
+def update_images(image_list: List[str]) -> None:
+ from rich.markup import escape
+ from rich.progress import MofNCompleteColumn, Progress
+
+ from localstack.utils.container_utils.container_client import ContainerException
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ updated_count = 0
+ failed_count = 0
+ progress = Progress(
+ *Progress.get_default_columns(), MofNCompleteColumn(), transient=True, console=console
+ )
+ with progress:
+ for image in progress.track(image_list, description="Processing image..."):
+ try:
+ updated = False
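+ # compare the image ID before and after the pull to detect whether a newer image was fetched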
+ hash_before_pull = DOCKER_CLIENT.inspect_image(image_name=image, pull=False)["Id"]
+ DOCKER_CLIENT.pull_image(image)
+ if (
+ hash_before_pull
+ != DOCKER_CLIENT.inspect_image(image_name=image, pull=False)["Id"]
+ ):
+ updated = True
+ updated_count += 1
+ console.print(
+ f":heavy_check_mark: Image {escape(image)} {'updated' if updated else 'up-to-date'}.",
+ style="bold" if updated else None,
+ highlight=False,
+ )
+ except ContainerException as e:
+ console.print(
+ f":heavy_multiplication_x: Image {escape(image)} pull failed: {e.message}",
+ style="bold red",
+ highlight=False,
+ )
+ failed_count += 1
+ console.rule()
+ console.print(
+ f"Images updated: {updated_count}, Images failed: {failed_count}, total images processed: {len(image_list)}."
+ )
+
+
+@localstack.command(name="completion", short_help="CLI shell completion")
+@click.pass_context
+@click.argument(
+ "shell", required=True, type=click.Choice(["bash", "zsh", "fish"], case_sensitive=False)
+)
+@publish_invocation
+def localstack_completion(ctx: click.Context, shell: str) -> None:
+ """
+ Print shell completion code for the specified shell (bash, zsh, or fish).
+ The shell code must be evaluated to enable the interactive shell completion of LocalStack CLI commands.
+ This is usually done by sourcing it from the .bash_profile.
+
+ \b
+ Examples:
+ # Bash
+ ## Bash completion on Linux depends on the 'bash-completion' package.
+ ## Write the LocalStack CLI completion code for bash to a file and source it from .bash_profile
+ localstack completion bash > ~/.localstack/completion.bash.inc
+ printf "
+ # LocalStack CLI bash completion
+ source '$HOME/.localstack/completion.bash.inc'
+ " >> $HOME/.bash_profile
+ source $HOME/.bash_profile
+ \b
+ # zsh
+ ## Set the LocalStack completion code for zsh to autoload on startup:
+ localstack completion zsh > "${fpath[1]}/_localstack"
+ \b
+ # fish
+ ## Set the LocalStack completion code for fish to autoload on startup:
+ localstack completion fish > ~/.config/fish/completions/localstack.fish
+ """
+
+ # lookup the completion, raise an error if the given completion is not found
+ import click.shell_completion
+
+ comp_cls = click.shell_completion.get_completion_class(shell)
+ if comp_cls is None:
+ raise CLIError("Completion for given shell could not be found.")
+
+ # Click's program name is the base path of sys.argv[0]
+ path = sys.argv[0]
+ prog_name = os.path.basename(path)
+
+ # create the completion variable according to the docs
+ # https://click.palletsprojects.com/en/8.1.x/shell-completion/#enabling-completion
+ complete_var = f"_{prog_name}_COMPLETE".replace("-", "_").upper()
+
+ # instantiate the completion class and print the completion source
+ comp = comp_cls(ctx.command, {}, prog_name, complete_var)
+ click.echo(comp.source())
+
+
+def print_version() -> None:
+ console.print(f"- [bold]LocalStack CLI:[/bold] [blue]{VERSION}[/blue]")
+
+
+def print_profile() -> None:
+ if config.LOADED_PROFILES:
+ console.print(f"- [bold]Profile:[/bold] [blue]{', '.join(config.LOADED_PROFILES)}[/blue]")
+
+
+def print_app() -> None:
+ console.print("- [bold]App:[/bold] https://app.localstack.cloud")
+
+
+def print_banner() -> None:
+ print(BANNER)
+
+
+def is_frozen_bundle() -> bool:
+ """
+ :return: true if we are currently running in a frozen bundle / a pyinstaller binary.
+ """
+ # check if we are in a PyInstaller binary
+ # https://pyinstaller.org/en/stable/runtime-information.html
+ return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
diff --git a/localstack-core/localstack/cli/lpm.py b/localstack-core/localstack/cli/lpm.py
new file mode 100644
index 0000000000000..ad4a6f5489d5c
--- /dev/null
+++ b/localstack-core/localstack/cli/lpm.py
@@ -0,0 +1,139 @@
+import itertools
+import logging
+from multiprocessing.pool import ThreadPool
+from typing import List, Optional
+
+import click
+from rich.console import Console
+
+from localstack import config
+from localstack.cli.exceptions import CLIError
+from localstack.packages import InstallTarget, Package
+from localstack.packages.api import NoSuchPackageException, PackagesPluginManager
+from localstack.utils.bootstrap import setup_logging
+
+LOG = logging.getLogger(__name__)
+
+console = Console()
+
+
+@click.group()
+def cli():
+ """
+ The LocalStack Package Manager (lpm) CLI is a set of commands to install third-party packages used by localstack
+ service providers.
+
+ Here are some handy commands:
+
+ List all packages
+
+ python -m localstack.cli.lpm list
+
+ Install DynamoDB Local:
+
+ python -m localstack.cli.lpm install dynamodb-local
+
+ Install all community packages, four in parallel:
+
+ python -m localstack.cli.lpm list | grep "/community" | cut -d'/' -f1 | xargs python -m localstack.cli.lpm install --parallel 4
+ """
+ setup_logging()
+
+
+def _do_install_package(package: Package, version: str = None, target: InstallTarget = None):
+ console.print(f"installing... [bold]{package}[/bold]")
+ try:
+ package.install(version=version, target=target)
+ console.print(f"[green]installed[/green] [bold]{package}[/bold]")
+ except Exception as e:
+ console.print(f"[red]error[/red] installing {package}: {e}")
+ raise e
+
+
+@cli.command()
+@click.argument("package", nargs=-1, required=True)
+@click.option(
+ "--parallel",
+ type=int,
+ default=1,
+ required=False,
+ help="how many installers to run in parallel processes",
+)
+@click.option(
+ "--version",
+ type=str,
+ default=None,
+ required=False,
+ help="version to install of a package",
+)
+@click.option(
+ "--target",
+ type=click.Choice([target.name.lower() for target in InstallTarget]),
+ default=None,
+ required=False,
+ help="target of the installation",
+)
+def install(
+ package: List[str],
+ parallel: Optional[int] = 1,
+ version: Optional[str] = None,
+ target: Optional[str] = None,
+):
+ """Install one or more packages."""
+ try:
+ if target:
+ target = InstallTarget[str.upper(target)]
+ else:
+ # LPM is meant to be used at build time, so the default target is static_libs
+ target = InstallTarget.STATIC_LIBS
+
+ # collect installers and install in parallel:
+ console.print(f"resolving packages: {package}")
+ package_manager = PackagesPluginManager()
+ package_manager.load_all()
+ package_instances = package_manager.get_packages(package, version)
+
+ if parallel > 1:
+ console.print(f"install {parallel} packages in parallel:")
+
+ config.dirs.mkdirs()
+
+ with ThreadPool(processes=parallel) as pool:
+ pool.starmap(
+ _do_install_package,
+ zip(package_instances, itertools.repeat(version), itertools.repeat(target)),
+ )
+ except NoSuchPackageException as e:
+ LOG.debug(str(e), exc_info=e)
+ raise CLIError(str(e))
+ except Exception as e:
+ LOG.debug("one or more package installations failed.", exc_info=e)
+ raise CLIError("one or more package installations failed.")
+
+
+@cli.command(name="list")
+@click.option(
+ "-v",
+ "--verbose",
+ is_flag=True,
+ default=False,
+ required=False,
+ help="Verbose output (show additional info on packages)",
+)
+def list_packages(verbose: bool):
+ """List available packages of all repositories"""
+ package_manager = PackagesPluginManager()
+ package_manager.load_all()
+ packages = package_manager.get_all_packages()
+ for package_name, package_scope, package_instance in packages:
+ console.print(f"[green]{package_name}[/green]/{package_scope}")
+ if verbose:
+ for version in package_instance.get_versions():
+ if version == package_instance.default_version:
+ console.print(f" - [bold]{version} (default)[/bold]", highlight=False)
+ else:
+ console.print(f" - {version}", highlight=False)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/localstack-core/localstack/cli/main.py b/localstack-core/localstack/cli/main.py
new file mode 100644
index 0000000000000..d9162bb098a4d
--- /dev/null
+++ b/localstack-core/localstack/cli/main.py
@@ -0,0 +1,21 @@
+import os
+
+
+def main():
+ # indicate to the environment we are starting from the CLI
+ os.environ["LOCALSTACK_CLI"] = "1"
+
+ # config profiles are the first thing that need to be loaded (especially before localstack.config!)
+ from .profiles import set_profile_from_sys_argv
+
+ set_profile_from_sys_argv()
+
+ # initialize CLI plugins
+ from .localstack import create_with_plugins
+
+ cli = create_with_plugins()
+ cli()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/localstack-core/localstack/cli/plugin.py b/localstack-core/localstack/cli/plugin.py
new file mode 100644
index 0000000000000..f9af88474a6d5
--- /dev/null
+++ b/localstack-core/localstack/cli/plugin.py
@@ -0,0 +1,39 @@
+import abc
+import logging
+import os
+
+import click
+from plux import Plugin, PluginManager
+
+LOG = logging.getLogger(__name__)
+
+
+class LocalstackCli:
+ group: click.Group
+
+ def __call__(self, *args, **kwargs):
+ self.group(*args, **kwargs)
+
+
+class LocalstackCliPlugin(Plugin):
+ namespace = "localstack.plugins.cli"
+
+ def load(self, cli) -> None:
+ self.attach(cli)
+
+ @abc.abstractmethod
+ def attach(self, cli: LocalstackCli) -> None:
+ """
+ Attach commands to the `localstack` CLI.
+
+ :param cli: the cli object
+ """
+
+
+def load_cli_plugins(cli):
+ if os.environ.get("DEBUG_PLUGINS", "0").lower() in ("true", "1"):
+ # importing localstack.config is still quite expensive...
+ logging.basicConfig(level=logging.DEBUG)
+
+ loader = PluginManager("localstack.plugins.cli", load_args=(cli,))
+ loader.load_all()
diff --git a/localstack-core/localstack/cli/plugins.py b/localstack-core/localstack/cli/plugins.py
new file mode 100644
index 0000000000000..c63588161d304
--- /dev/null
+++ b/localstack-core/localstack/cli/plugins.py
@@ -0,0 +1,134 @@
+import os
+import time
+
+import click
+from plux import PluginManager
+from plux.build.setuptools import find_plugins
+from plux.core.entrypoint import spec_to_entry_point
+from rich import print as rprint
+from rich.console import Console
+from rich.table import Table
+from rich.tree import Tree
+
+from localstack.cli.exceptions import CLIError
+
+console = Console()
+
+
+@click.group()
+def cli():
+ """
+ The plugins CLI is a set of commands to help troubleshoot LocalStack's plugin mechanism.
+ """
+ pass
+
+
+@cli.command()
+@click.option("--where", type=str, default=os.path.abspath(os.curdir))
+@click.option("--exclude", multiple=True, default=())
+@click.option("--include", multiple=True, default=("*",))
+@click.option("--output", type=str, default="tree")
+def find(where, exclude, include, output):
+ """
+ Find plugins by scanning the given path for PluginSpecs.
+ It starts from the current directory if --where is not specified.
+ This is what a setup.py method would run as a build step, i.e., discovering entry points.
+ """
+ with console.status(f"Scanning path {where}"):
+ plugins = find_plugins(where, exclude, include)
+
+ if output == "tree":
+ tree = Tree("Entrypoints")
+ for namespace, entry_points in plugins.items():
+ node = tree.add(f"[bold]{namespace}")
+
+ t = Table()
+ t.add_column("Name")
+ t.add_column("Location")
+
+ for ep in entry_points:
+ key, value = ep.split("=")
+ t.add_row(key, value)
+
+ node.add(t)
+
+ rprint(tree)
+ elif output == "dict":
+ rprint(dict(plugins))
+ else:
+ raise CLIError("unknown output format %s" % output)
+
+
+@cli.command("list")
+@click.option("--namespace", type=str, required=True)
+def cmd_list(namespace):
+ """
+ List all available plugins using a PluginManager from available entrypoints.
+ """
+ manager = PluginManager(namespace)
+
+ t = Table()
+ t.add_column("Name")
+ t.add_column("Factory")
+
+ for spec in manager.list_plugin_specs():
+ ep = spec_to_entry_point(spec)
+ t.add_row(spec.name, ep.value)
+
+ rprint(t)
+
+
+@cli.command()
+@click.option("--namespace", type=str, required=True)
+@click.option("--name", type=str, required=True)
+def load(namespace, name):
+ """
+ Attempts to load a plugin using a PluginManager.
+ """
+ manager = PluginManager(namespace)
+
+ with console.status(f"Loading {namespace}:{name}"):
+ then = time.time()
+ plugin = manager.load(name)
+ took = time.time() - then
+
+ rprint(
+ f":tada: successfully loaded [bold][green]{namespace}[/green][/bold]:[bold][cyan]{name}[/cyan][/bold] ({type(plugin)}"
+ )
+ rprint(f":stopwatch: loading took {took:.4f} s")
+
+
+@cli.command()
+@click.option("--namespace", type=str)
+def cache(namespace):
+ """
+ Outputs the stevedore entrypoints cache from which plugins are loaded.
+ """
+ from stevedore._cache import _c
+
+ data = _c._get_data_for_path(None)
+
+ tree = Tree("Entrypoints")
+ for group, entry_points in data.get("groups").items():
+ if namespace and group != namespace:
+ continue
+ node = tree.add(f"[bold]{group}")
+
+ t = Table()
+ t.add_column("Name")
+ t.add_column("Value")
+
+ for key, value, _ in entry_points:
+ t.add_row(key, value)
+
+ node.add(t)
+
+ if namespace:
+ rprint(t)
+ return
+
+ rprint(tree)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/localstack-core/localstack/cli/profiles.py b/localstack-core/localstack/cli/profiles.py
new file mode 100644
index 0000000000000..1625b802f73a4
--- /dev/null
+++ b/localstack-core/localstack/cli/profiles.py
@@ -0,0 +1,40 @@
+import os
+import sys
+from typing import Optional
+
+# important: this needs to be free of localstack imports
+
+
+def set_profile_from_sys_argv():
+ """
+ Reads the --profile flag from sys.argv and then sets the 'CONFIG_PROFILE' environment variable accordingly. This is later
+ picked up by ``localstack.config``.
+ """
+ profile = parse_profile_argument(sys.argv)
+ if profile:
+ os.environ["CONFIG_PROFILE"] = profile.strip()
+
+
+def parse_profile_argument(args) -> Optional[str]:
+ """
+ Lightweight arg parsing to find ``--profile <value>`` or ``--profile=<value>`` and return the
+ profile value from the given arguments.
+
+ :param args: list of CLI arguments
+ :returns: the value of ``--profile``.
+ """
+ for i, current_arg in enumerate(args):
+ if current_arg.startswith("--profile="):
+ # if using the "=" notation, we remove the "--profile=" prefix to get the value
+ return current_arg[10:]
+ elif current_arg.startswith("-p="):
+ # if using the "=" notation, we remove the "-p=" prefix to get the value
+ return current_arg[3:]
+ if current_arg in ["--profile", "-p"]:
+ # otherwise use the next arg in the args list as value
+ try:
+ return args[i + 1]
+ except IndexError:
+ return None
+
+ return None
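+
+
+# Illustrative usage (doctest-style sketch; example values only):
+#
+#   >>> parse_profile_argument(["localstack", "--profile", "dev", "start"])
+#   'dev'
+#   >>> parse_profile_argument(["localstack", "-p=ci", "start"])
+#   'ci'
+#   >>> parse_profile_argument(["localstack", "start"]) is None
+#   True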
diff --git a/localstack-core/localstack/config.py b/localstack-core/localstack/config.py
new file mode 100644
index 0000000000000..a063abb1213c9
--- /dev/null
+++ b/localstack-core/localstack/config.py
@@ -0,0 +1,1612 @@
+import ipaddress
+import logging
+import os
+import platform
+import re
+import socket
+import subprocess
+import tempfile
+import time
+import warnings
+from collections import defaultdict
+from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
+
+from localstack import constants
+from localstack.constants import (
+ DEFAULT_BUCKET_MARKER_LOCAL,
+ DEFAULT_DEVELOP_PORT,
+ DEFAULT_VOLUME_DIR,
+ ENV_INTERNAL_TEST_COLLECT_METRIC,
+ ENV_INTERNAL_TEST_RUN,
+ FALSE_STRINGS,
+ LOCALHOST,
+ LOCALHOST_IP,
+ LOCALSTACK_ROOT_FOLDER,
+ LOG_LEVELS,
+ TRACE_LOG_LEVELS,
+ TRUE_STRINGS,
+)
+
+T = TypeVar("T", str, int)
+
+# keep track of start time, for performance debugging
+load_start_time = time.time()
+
+
+class Directories:
+ """
+ Holds different directories available to localstack. Some directories are shared between the host and the
+ localstack container, some live only on the host and others in the container.
+
+ Attributes:
+ static_libs: container only; binaries and libraries statically packaged with the image
+ var_libs: shared; binaries and libraries+data computed at runtime: lazy-loaded binaries, ssl cert, ...
+ cache: shared; ephemeral data that has to persist across localstack runs and reboots
+ tmp: container only; ephemeral data that has to persist across localstack runs but not reboots
+ mounted_tmp: shared; same as above, but shared for persistence across different containers, tests, ...
+ functions: shared; volume to communicate between host<->lambda containers
+ data: shared; holds localstack state, pods, ...
+ config: host only; pre-defined configuration values, cached credentials, machine id, ...
+ init: shared; user-defined provisioning scripts executed in the container when it starts
+ logs: shared; log files produced by localstack
+ """
+
+ static_libs: str
+ var_libs: str
+ cache: str
+ tmp: str
+ mounted_tmp: str
+ functions: str
+ data: str
+ config: str
+ init: str
+ logs: str
+
+ def __init__(
+ self,
+ static_libs: str,
+ var_libs: str,
+ cache: str,
+ tmp: str,
+ mounted_tmp: str,
+ functions: str,
+ data: str,
+ config: str,
+ init: str,
+ logs: str,
+ ) -> None:
+ super().__init__()
+ self.static_libs = static_libs
+ self.var_libs = var_libs
+ self.cache = cache
+ self.tmp = tmp
+ self.mounted_tmp = mounted_tmp
+ self.functions = functions
+ self.data = data
+ self.config = config
+ self.init = init
+ self.logs = logs
+
+ @staticmethod
+ def defaults() -> "Directories":
+ """Returns Localstack directory paths based on the localstack filesystem hierarchy."""
+ return Directories(
+ static_libs="/usr/lib/localstack",
+ var_libs=f"{DEFAULT_VOLUME_DIR}/lib",
+ cache=f"{DEFAULT_VOLUME_DIR}/cache",
+ tmp=os.path.join(tempfile.gettempdir(), "localstack"),
+ mounted_tmp=f"{DEFAULT_VOLUME_DIR}/tmp",
+ functions=f"{DEFAULT_VOLUME_DIR}/tmp", # FIXME: remove - this was misconceived
+ data=f"{DEFAULT_VOLUME_DIR}/state",
+ logs=f"{DEFAULT_VOLUME_DIR}/logs",
+ config="/etc/localstack/conf.d", # for future use
+ init="/etc/localstack/init",
+ )
+
+ @staticmethod
+ def for_container() -> "Directories":
+ """
+ Returns Localstack directory paths as they are defined within the container. Everything shared and writable
+ lives in /var/lib/localstack or {tempfile.gettempdir()}/localstack.
+
+ :returns: Directories object
+ """
+ defaults = Directories.defaults()
+
+ return Directories(
+ static_libs=defaults.static_libs,
+ var_libs=defaults.var_libs,
+ cache=defaults.cache,
+ tmp=defaults.tmp,
+ mounted_tmp=defaults.mounted_tmp,
+ functions=defaults.functions,
+ data=defaults.data if PERSISTENCE else os.path.join(defaults.tmp, "state"),
+ config=defaults.config,
+ logs=defaults.logs,
+ init=defaults.init,
+ )
+
+ @staticmethod
+ def for_host() -> "Directories":
+ """Return directories used for running localstack in host mode. Note that these are *not* the directories
+ that are mounted into the container when the user starts localstack."""
+ root = os.environ.get("FILESYSTEM_ROOT") or os.path.join(
+ LOCALSTACK_ROOT_FOLDER, ".filesystem"
+ )
+ root = os.path.abspath(root)
+
+ defaults = Directories.for_container()
+
+ tmp = os.path.join(root, defaults.tmp.lstrip("/"))
+ data = os.path.join(root, defaults.data.lstrip("/"))
+
+ return Directories(
+ static_libs=os.path.join(root, defaults.static_libs.lstrip("/")),
+ var_libs=os.path.join(root, defaults.var_libs.lstrip("/")),
+ cache=os.path.join(root, defaults.cache.lstrip("/")),
+ tmp=tmp,
+ mounted_tmp=os.path.join(root, defaults.mounted_tmp.lstrip("/")),
+ functions=os.path.join(root, defaults.functions.lstrip("/")),
+ data=data if PERSISTENCE else os.path.join(tmp, "state"),
+ config=os.path.join(root, defaults.config.lstrip("/")),
+ init=os.path.join(root, defaults.init.lstrip("/")),
+ logs=os.path.join(root, defaults.logs.lstrip("/")),
+ )
+
+ @staticmethod
+ def for_cli() -> "Directories":
+ """Returns directories used for when running localstack CLI commands from the host system. Unlike
+ ``for_container``, these needs to be cross-platform. Ideally, this should not be needed at all,
+ because the localstack runtime and CLI do not share any control paths. There are a handful of
+ situations where directories or files may be created lazily for CLI commands. Some paths are
+ intentionally set to None to provoke errors if these paths are used from the CLI - which they
+ shouldn't. This is a symptom of not having a clear separation between CLI/runtime code, which will
+ be a future project."""
+ import tempfile
+
+ from localstack.utils import files
+
+ tmp_dir = os.path.join(tempfile.gettempdir(), "localstack-cli")
+ cache_dir = (files.get_user_cache_dir()).absolute() / "localstack-cli"
+
+ return Directories(
+ static_libs=None,
+ var_libs=None,
+ cache=str(cache_dir), # used by analytics metadata
+ tmp=tmp_dir,
+ mounted_tmp=tmp_dir,
+ functions=None,
+ data=os.path.join(tmp_dir, "state"), # used by localstack-pro config TODO: remove
+ logs=os.path.join(tmp_dir, "logs"), # used for container logs
+ config=None, # in the context of the CLI, config.CONFIG_DIR should be used
+ init=None,
+ )
+
+ def mkdirs(self):
+ for folder in [
+ self.static_libs,
+ self.var_libs,
+ self.cache,
+ self.tmp,
+ self.mounted_tmp,
+ self.functions,
+ self.data,
+ self.config,
+ self.init,
+ self.logs,
+ ]:
+ if folder and not os.path.exists(folder):
+ try:
+ os.makedirs(folder)
+ except Exception:
+ # this can happen due to a race condition when starting
+ # multiple processes in parallel. Should be safe to ignore
+ pass
+
+ def __str__(self):
+ return str(self.__dict__)
+
+
+def eval_log_type(env_var_name: str) -> Union[str, bool]:
+ """Get the log type from environment variable"""
+ ls_log = os.environ.get(env_var_name, "").lower().strip()
+ return ls_log if ls_log in LOG_LEVELS else False
+
+
+def parse_boolean_env(env_var_name: str) -> Optional[bool]:
+ """Parse the value of the given env variable and return True/False, or None if it is not a boolean value."""
+ value = os.environ.get(env_var_name, "").lower().strip()
+ if value in TRUE_STRINGS:
+ return True
+ if value in FALSE_STRINGS:
+ return False
+ return None
+
+
+def is_env_true(env_var_name: str) -> bool:
+ """Whether the given environment variable has a truthy value."""
+ return os.environ.get(env_var_name, "").lower().strip() in TRUE_STRINGS
+
+
+def is_env_not_false(env_var_name: str) -> bool:
+ """Whether the given environment variable is empty or has a truthy value."""
+ return os.environ.get(env_var_name, "").lower().strip() not in FALSE_STRINGS
+
+
+def load_environment(profiles: str = None, env=os.environ) -> List[str]:
+ """Loads the environment variables from ~/.localstack/{profile}.env, for each profile listed in the profiles.
+ :param env: environment to load profile to. Defaults to `os.environ`
+ :param profiles: a comma separated list of profiles to load (defaults to "default")
+ :returns str: the list of the actually loaded profiles (might be the fallback)
+ """
+ if not profiles:
+ profiles = "default"
+
+ profiles = profiles.split(",")
+ environment = {}
+ import dotenv
+
+ for profile in profiles:
+ profile = profile.strip()
+ path = os.path.join(CONFIG_DIR, f"{profile}.env")
+ if not os.path.exists(path):
+ continue
+ environment.update(dotenv.dotenv_values(path))
+
+ for k, v in environment.items():
+ # we do not want to override the environment
+ if k not in env and v is not None:
+ env[k] = v
+
+ return profiles
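+
+
+# Illustrative usage (a sketch; assumes a profile file ~/.localstack/dev.env containing FOO=bar):
+#
+#   >>> env = {}
+#   >>> load_environment("dev", env=env)
+#   ['dev']
+#   >>> env
+#   {'FOO': 'bar'}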
+
+
+def is_persistence_enabled() -> bool:
+ return PERSISTENCE and dirs.data
+
+
+def is_linux() -> bool:
+ return platform.system() == "Linux"
+
+
+def is_macos() -> bool:
+ return platform.system() == "Darwin"
+
+
+def is_windows() -> bool:
+ return platform.system().lower() == "windows"
+
+
+def ping(host):
+ """Returns True if the host responds to a ping request"""
+ is_in_windows = is_windows()
+ ping_opts = "-n 1 -w 2000" if is_in_windows else "-c 1 -W 2"
+ args = "ping %s %s" % (ping_opts, host)
+ return (
+ subprocess.call(
+ args, shell=not is_in_windows, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ == 0
+ )
+
+
+def in_docker():
+ """
+ Returns True if running in a docker container, else False
+ Ref. https://docs.docker.com/config/containers/runmetrics/#control-groups
+ """
+ if OVERRIDE_IN_DOCKER is not None:
+ return OVERRIDE_IN_DOCKER
+
+ # check some marker files that we create in our Dockerfiles
+ for path in [
+ "/usr/lib/localstack/.community-version",
+ "/usr/lib/localstack/.pro-version",
+ "/tmp/localstack/.marker",
+ ]:
+ if os.path.isfile(path):
+ return True
+
+ # details: https://github.com/localstack/localstack/pull/4352
+ if os.path.exists("/.dockerenv"):
+ return True
+ if os.path.exists("/run/.containerenv"):
+ return True
+
+ if not os.path.exists("/proc/1/cgroup"):
+ return False
+ try:
+ if any(
+ [
+ os.path.exists("/sys/fs/cgroup/memory/docker/"),
+ any(
+ "docker-" in file_names
+ for file_names in os.listdir("/sys/fs/cgroup/memory/system.slice")
+ ),
+ os.path.exists("/sys/fs/cgroup/docker/"),
+ any(
+ "docker-" in file_names
+ for file_names in os.listdir("/sys/fs/cgroup/system.slice/")
+ ),
+ ]
+ ):
+ return False
+ except Exception:
+ pass
+ with open("/proc/1/cgroup", "rt") as ifh:
+ content = ifh.read()
+ if "docker" in content or "buildkit" in content:
+ return True
+ os_hostname = socket.gethostname()
+ if os_hostname and os_hostname in content:
+ return True
+
+ # containerd does not set any specific file or config, but it does use
+ # io.containerd.snapshotter.v1.overlayfs as the overlay filesystem for `/`.
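+ # e.g., an illustrative overlay mount entry (fields abbreviated):
+ #   overlay / overlay rw,relatime,lowerdir=/var/lib/containerd/io.containerd.snapshotter.v1.overlayfs/... 0 0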
+ try:
+ with open("/proc/mounts", "rt") as infile:
+ for line in infile:
+ line = line.strip()
+
+ if not line:
+ continue
+
+ # skip comments
+ if line[0] == "#":
+ continue
+
+ # format (man 5 fstab)
+ # <spec> <mount point> <type> <options> <dump> <pass>
+ parts = line.split()
+ if len(parts) < 4:
+ # badly formatted line
+ continue
+
+ mount_point = parts[1]
+ options = parts[3]
+
+ # only consider the root filesystem
+ if mount_point != "/":
+ continue
+
+ if "io.containerd" in options:
+ return True
+
+ except FileNotFoundError:
+ pass
+
+ return False
+
+
+# whether the `in_docker` check should always return True or False
+OVERRIDE_IN_DOCKER = parse_boolean_env("OVERRIDE_IN_DOCKER")
+
+is_in_docker = in_docker()
+is_in_linux = is_linux()
+is_in_macos = is_macos()
+default_ip = "0.0.0.0" if is_in_docker else "127.0.0.1"
+
+# CLI specific: the configuration profile to load
+CONFIG_PROFILE = os.environ.get("CONFIG_PROFILE", "").strip()
+
+# CLI specific: host configuration directory
+CONFIG_DIR = os.environ.get("CONFIG_DIR", os.path.expanduser("~/.localstack"))
+
+# keep this on top to populate the environment
+try:
+ # CLI specific: the actually loaded configuration profile
+ LOADED_PROFILES = load_environment(CONFIG_PROFILE)
+except ImportError:
+ # dotenv may not be available in lambdas or other environments where config is loaded
+ LOADED_PROFILES = None
+
+# loaded components name - default: all components are loaded and the first one is chosen
+RUNTIME_COMPONENTS = os.environ.get("RUNTIME_COMPONENTS", "").strip()
+
+# directory for persisting data (TODO: deprecated, simply use PERSISTENCE=1)
+DATA_DIR = os.environ.get("DATA_DIR", "").strip()
+
+# whether localstack should persist service state across localstack runs
+PERSISTENCE = is_env_true("PERSISTENCE")
+
+# the strategy for loading snapshots from disk when `PERSISTENCE=1` is used (on_startup, on_request, manual)
+SNAPSHOT_LOAD_STRATEGY = os.environ.get("SNAPSHOT_LOAD_STRATEGY", "").upper()
+
+# the strategy saving snapshots to disk when `PERSISTENCE=1` is used (on_shutdown, on_request, scheduled, manual)
+SNAPSHOT_SAVE_STRATEGY = os.environ.get("SNAPSHOT_SAVE_STRATEGY", "").upper()
+
+# the flush interval (in seconds) for persistence when the snapshot save strategy is set to "scheduled"
+SNAPSHOT_FLUSH_INTERVAL = int(os.environ.get("SNAPSHOT_FLUSH_INTERVAL") or 15)
+
+# whether to clear config.dirs.tmp on startup and shutdown
+CLEAR_TMP_FOLDER = is_env_not_false("CLEAR_TMP_FOLDER")
+
+# folder for temporary files and data
+TMP_FOLDER = os.path.join(tempfile.gettempdir(), "localstack")
+
+# this is exclusively for the CLI to configure the container mount into /var/lib/localstack
+VOLUME_DIR = os.environ.get("LOCALSTACK_VOLUME_DIR", "").strip() or TMP_FOLDER
+
+# fix for Mac OS, to be able to mount /var/folders in Docker
+if TMP_FOLDER.startswith("/var/folders/") and os.path.exists("/private%s" % TMP_FOLDER):
+ TMP_FOLDER = "/private%s" % TMP_FOLDER
+
+# whether to enable verbose debug logging ("LOG" is used when using the CLI with LOCALSTACK_LOG instead of LS_LOG)
+LS_LOG = eval_log_type("LS_LOG") or eval_log_type("LOG")
+DEBUG = is_env_true("DEBUG") or LS_LOG in TRACE_LOG_LEVELS
+
+# PUBLIC PREVIEW: 0 (default), 1 (preview)
+# When enabled it triggers specialised workflows for the debugging.
+LAMBDA_DEBUG_MODE = is_env_true("LAMBDA_DEBUG_MODE")
+
+# path to the lambda debug mode configuration file.
+LAMBDA_DEBUG_MODE_CONFIG_PATH = os.environ.get("LAMBDA_DEBUG_MODE_CONFIG_PATH")
+
+# whether to enable debugpy
+DEVELOP = is_env_true("DEVELOP")
+
+# PORT FOR DEBUGGER
+DEVELOP_PORT = int(os.environ.get("DEVELOP_PORT", "").strip() or DEFAULT_DEVELOP_PORT)
+
+# whether to make debugpy wait for a debugger client
+WAIT_FOR_DEBUGGER = is_env_true("WAIT_FOR_DEBUGGER")
+
+# whether to assume http or https for `get_protocol`
+USE_SSL = is_env_true("USE_SSL")
+
+# Whether to report internal failures as 500 or 501 errors.
+FAIL_FAST = is_env_true("FAIL_FAST")
+
+# whether to run in TF compatibility mode for TF integration tests
+# (e.g., returning verbatim ports for ELB resources, rather than edge port 4566, etc.)
+TF_COMPAT_MODE = is_env_true("TF_COMPAT_MODE")
+
+# default encoding used to convert strings to byte arrays (mainly for Python 3 compatibility)
+DEFAULT_ENCODING = "utf-8"
+
+# path to local Docker UNIX domain socket
+DOCKER_SOCK = os.environ.get("DOCKER_SOCK", "").strip() or "/var/run/docker.sock"
+
+# additional flags to pass to "docker run" when starting the stack in Docker
+DOCKER_FLAGS = os.environ.get("DOCKER_FLAGS", "").strip()
+
+# command used to run Docker containers (e.g., set to "sudo docker" to run as sudo)
+DOCKER_CMD = os.environ.get("DOCKER_CMD", "").strip() or "docker"
+
+# use the command line docker client instead of the new sdk version, might get removed in the future
+LEGACY_DOCKER_CLIENT = is_env_true("LEGACY_DOCKER_CLIENT")
+
+# Docker image to use when starting up containers for port checks
+PORTS_CHECK_DOCKER_IMAGE = os.environ.get("PORTS_CHECK_DOCKER_IMAGE", "").strip()
+
+
+def is_trace_logging_enabled():
+ if LS_LOG:
+ log_level = str(LS_LOG).upper()
+ return log_level.lower() in TRACE_LOG_LEVELS
+ return False
+
+
+# set log levels immediately, but will be overwritten later by setup_logging
+if DEBUG:
+ logging.getLogger("").setLevel(logging.DEBUG)
+ logging.getLogger("localstack").setLevel(logging.DEBUG)
+
+LOG = logging.getLogger(__name__)
+if is_trace_logging_enabled():
+ load_end_time = time.time()
+ LOG.debug(
+ "Initializing the configuration took %s ms", int((load_end_time - load_start_time) * 1000)
+ )
+
+
+def is_ipv6_address(host: str) -> bool:
+ """
+ Returns True if the given host is an IPv6 address.
+ """
+
+ if not host:
+ return False
+
+ try:
+ ipaddress.IPv6Address(host)
+ return True
+ except ipaddress.AddressValueError:
+ return False
+
+
+class HostAndPort:
+ """
+ Definition of an address for a server to listen to.
+
+ Includes a `parse` method to convert from `str`, allowing for default fallbacks, as well as
+ some helper methods to help tests - particularly testing for equality and a hash function
+ so that `HostAndPort` instances can be used as keys to dictionaries.
+ """
+
+ host: str
+ port: int
+
+ def __init__(self, host: str, port: int):
+ self.host = host
+ self.port = port
+
+ @classmethod
+ def parse(
+ cls,
+ input: str,
+ default_host: str,
+ default_port: int,
+ ) -> "HostAndPort":
+ """
+ Parse a `HostAndPort` from strings like:
+ - 0.0.0.0:4566 -> host=0.0.0.0, port=4566
+ - 0.0.0.0 -> host=0.0.0.0, port=`default_port`
+ - :4566 -> host=`default_host`, port=4566
+ - [::]:4566 -> host=[::], port=4566
+ - [::1] -> host=[::1], port=`default_port`
+ """
+ host, port = default_host, default_port
+
+ # recognize IPv6 addresses (+ port)
+ if input.startswith("["):
+ ipv6_pattern = re.compile(r"^\[(?P<host>[^]]+)\](:(?P<port>\d+))?$")
+ match = ipv6_pattern.match(input)
+
+ if match:
+ host = match.group("host")
+ if not is_ipv6_address(host):
+ raise ValueError(
+ f"input looks like an IPv6 address (is enclosed in square brackets), but is not valid: {host}"
+ )
+ port_s = match.group("port")
+ if port_s:
+ port = cls._validate_port(port_s)
+ else:
+ raise ValueError(
+ f'input looks like an IPv6 address, but is invalid. Should be formatted "[ip]:port": {input}'
+ )
+
+ # recognize IPv4 address + port
+ elif ":" in input:
+ hostname, port_s = input.split(":", 1)
+ if hostname.strip():
+ host = hostname.strip()
+ port = cls._validate_port(port_s)
+ else:
+ if input.strip():
+ host = input.strip()
+
+ # validation
+ if port < 0 or port >= 2**16:
+ raise ValueError("port out of range")
+
+ return cls(host=host, port=port)
+
+ @classmethod
+ def _validate_port(cls, port_s: str) -> int:
+ try:
+ port = int(port_s)
+ except ValueError as e:
+ raise ValueError(f"specified port {port_s} not a number") from e
+
+ return port
+
+ def _get_unprivileged_port_range_start(self) -> int:
+ try:
+ with open(
+ "/proc/sys/net/ipv4/ip_unprivileged_port_start", "rt"
+ ) as unprivileged_port_start:
+ port = unprivileged_port_start.read()
+ return int(port.strip())
+ except Exception:
+ return 1024
+
+ def is_unprivileged(self) -> bool:
+ return self.port >= self._get_unprivileged_port_range_start()
+
+ def host_and_port(self):
+ formatted_host = f"[{self.host}]" if is_ipv6_address(self.host) else self.host
+ return f"{formatted_host}:{self.port}" if self.port is not None else formatted_host
+
+ def __hash__(self) -> int:
+ return hash((self.host, self.port))
+
+ # easier tests
+ def __eq__(self, other: "str | HostAndPort") -> bool:
+ if isinstance(other, self.__class__):
+ return self.host == other.host and self.port == other.port
+ elif isinstance(other, str):
+ return str(self) == other
+ else:
+ raise TypeError(f"cannot compare {self.__class__} to {other.__class__}")
+
+ def __str__(self) -> str:
+ return self.host_and_port()
+
+ def __repr__(self) -> str:
+ return f"HostAndPort(host={self.host}, port={self.port})"
+
+
+class UniqueHostAndPortList(List[HostAndPort]):
+ """
+ Container type that ensures that ports added to the list are unique based
+ on these rules:
+ - :: "trumps" any other binding on the same port, including both IPv6 and IPv4
+ addresses. All other bindings for this port are removed, since :: already
+ covers all interfaces. For example, adding 127.0.0.1:4566, [::1]:4566,
+ and [::]:4566 would result in only [::]:4566 being preserved.
+ - 0.0.0.0 "trumps" any other binding on IPv4 addresses only. IPv6 addresses
+ are not removed.
+ - Identical hosts and ports are de-duped
+ """
+
+ def __init__(self, iterable: Union[List[HostAndPort], None] = None):
+ super().__init__(iterable or [])
+ self._ensure_unique()
+
+ def _ensure_unique(self):
+ """
+ Ensure that all bindings on the same port are de-duped.
+ """
+ if len(self) <= 1:
+ return
+
+ unique: List[HostAndPort] = list()
+
+ # Build a dictionary of hosts by port
+ hosts_by_port: Dict[int, List[str]] = defaultdict(list)
+ for item in self:
+ hosts_by_port[item.port].append(item.host)
+
+ # For any given port, dedupe the hosts
+ for port, hosts in hosts_by_port.items():
+ deduped_hosts = set(hosts)
+
+ # IPv6 all interfaces: this is the most general binding.
+ # Any others should be removed.
+ if "::" in deduped_hosts:
+ unique.append(HostAndPort(host="::", port=port))
+ continue
+ # IPv4 all interfaces: this is the next most general binding.
+ # Any others should be removed.
+ if "0.0.0.0" in deduped_hosts:
+ unique.append(HostAndPort(host="0.0.0.0", port=port))
+ continue
+
+ # All other bindings just need to be unique
+ unique.extend([HostAndPort(host=host, port=port) for host in deduped_hosts])
+
+ self.clear()
+ self.extend(unique)
+
+ def append(self, value: HostAndPort):
+ super().append(value)
+ self._ensure_unique()
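+
+
+# Illustrative behavior (doctest-style sketch; example values only):
+#
+#   >>> listen = UniqueHostAndPortList(
+#   ...     [
+#   ...         HostAndPort("127.0.0.1", 4566),
+#   ...         HostAndPort("0.0.0.0", 4566),
+#   ...         HostAndPort("127.0.0.1", 443),
+#   ...     ]
+#   ... )
+#   >>> [str(v) for v in listen]
+#   ['0.0.0.0:4566', '127.0.0.1:443']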
+
+
+def populate_edge_configuration(
+ environment: Mapping[str, str],
+) -> Tuple[HostAndPort, UniqueHostAndPortList]:
+ """Populate the LocalStack edge configuration from environment variables."""
+ localstack_host_raw = environment.get("LOCALSTACK_HOST")
+ gateway_listen_raw = environment.get("GATEWAY_LISTEN")
+
+ # parse gateway listen from multiple components
+ if gateway_listen_raw is not None:
+ gateway_listen = []
+ for address in gateway_listen_raw.split(","):
+ gateway_listen.append(
+ HostAndPort.parse(
+ address.strip(),
+ default_host=default_ip,
+ default_port=constants.DEFAULT_PORT_EDGE,
+ )
+ )
+ else:
+ # use default if gateway listen is not defined
+ gateway_listen = [HostAndPort(host=default_ip, port=constants.DEFAULT_PORT_EDGE)]
+
+ # the actual value of the LOCALSTACK_HOST port now depends on what gateway listen actually listens to.
+ if localstack_host_raw is None:
+ localstack_host = HostAndPort(
+ host=constants.LOCALHOST_HOSTNAME, port=gateway_listen[0].port
+ )
+ else:
+ localstack_host = HostAndPort.parse(
+ localstack_host_raw,
+ default_host=constants.LOCALHOST_HOSTNAME,
+ default_port=gateway_listen[0].port,
+ )
+
+ assert gateway_listen is not None
+ assert localstack_host is not None
+
+ return (
+ localstack_host,
+ UniqueHostAndPortList(gateway_listen),
+ )
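+
+
+# Illustrative parsing behavior (doctest-style sketch; example values only):
+#
+#   >>> host, listen = populate_edge_configuration(
+#   ...     {"GATEWAY_LISTEN": "127.0.0.1:4566,0.0.0.0:443", "LOCALSTACK_HOST": "example.com"}
+#   ... )
+#   >>> str(host)
+#   'example.com:4566'
+#   >>> [str(v) for v in listen]
+#   ['127.0.0.1:4566', '0.0.0.0:443']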
+
+
+# How to access LocalStack
+(
+ # -- Cosmetic
+ LOCALSTACK_HOST,
+ # -- Edge configuration
+ # Main configuration of the listen address of the hypercorn proxy. Of the form
+ # <ip_address>:<port>(,<ip_address>:<port>)*
+ GATEWAY_LISTEN,
+) = populate_edge_configuration(os.environ)
+
+GATEWAY_WORKER_COUNT = int(os.environ.get("GATEWAY_WORKER_COUNT") or 1000)
+
+# the gateway server that should be used (supported: hypercorn, twisted; dev: werkzeug)
+GATEWAY_SERVER = os.environ.get("GATEWAY_SERVER", "").strip() or "twisted"
+
+# IP of the docker bridge used to enable access between containers
+DOCKER_BRIDGE_IP = os.environ.get("DOCKER_BRIDGE_IP", "").strip()
+
+# Default timeout for Docker API calls sent by the Docker SDK client, in seconds.
+DOCKER_SDK_DEFAULT_TIMEOUT_SECONDS = int(os.environ.get("DOCKER_SDK_DEFAULT_TIMEOUT_SECONDS") or 60)
+
+# Default number of retries to connect to the Docker API by the Docker SDK client.
+DOCKER_SDK_DEFAULT_RETRIES = int(os.environ.get("DOCKER_SDK_DEFAULT_RETRIES") or 0)
+
+# whether to enable API-based updates of configuration variables at runtime
+ENABLE_CONFIG_UPDATES = is_env_true("ENABLE_CONFIG_UPDATES")
+
+# CORS settings
+DISABLE_CORS_HEADERS = is_env_true("DISABLE_CORS_HEADERS")
+DISABLE_CORS_CHECKS = is_env_true("DISABLE_CORS_CHECKS")
+DISABLE_CUSTOM_CORS_S3 = is_env_true("DISABLE_CUSTOM_CORS_S3")
+DISABLE_CUSTOM_CORS_APIGATEWAY = is_env_true("DISABLE_CUSTOM_CORS_APIGATEWAY")
+EXTRA_CORS_ALLOWED_HEADERS = os.environ.get("EXTRA_CORS_ALLOWED_HEADERS", "").strip()
+EXTRA_CORS_EXPOSE_HEADERS = os.environ.get("EXTRA_CORS_EXPOSE_HEADERS", "").strip()
+EXTRA_CORS_ALLOWED_ORIGINS = os.environ.get("EXTRA_CORS_ALLOWED_ORIGINS", "").strip()
+DISABLE_PREFLIGHT_PROCESSING = is_env_true("DISABLE_PREFLIGHT_PROCESSING")
+
+# whether to disable publishing events to the API
+DISABLE_EVENTS = is_env_true("DISABLE_EVENTS")
+DEBUG_ANALYTICS = is_env_true("DEBUG_ANALYTICS")
+
+# whether to log fine-grained debugging information for the handler chain
+DEBUG_HANDLER_CHAIN = is_env_true("DEBUG_HANDLER_CHAIN")
+
+# whether to eagerly start services
+EAGER_SERVICE_LOADING = is_env_true("EAGER_SERVICE_LOADING")
+
+# whether to selectively load services in SERVICES
+STRICT_SERVICE_LOADING = is_env_not_false("STRICT_SERVICE_LOADING")
+
+# Whether to skip downloading additional infrastructure components (e.g., custom Elasticsearch versions)
+SKIP_INFRA_DOWNLOADS = os.environ.get("SKIP_INFRA_DOWNLOADS", "").strip()
+
+# Whether to skip downloading our signed SSL cert.
+SKIP_SSL_CERT_DOWNLOAD = is_env_true("SKIP_SSL_CERT_DOWNLOAD")
+
+# Absolute path to a custom certificate (pem file)
+CUSTOM_SSL_CERT_PATH = os.environ.get("CUSTOM_SSL_CERT_PATH", "").strip()
+
+# Whether to delete the cached signed SSL certificate at startup
+REMOVE_SSL_CERT = is_env_true("REMOVE_SSL_CERT")
+
+# Allow non-standard AWS regions
+ALLOW_NONSTANDARD_REGIONS = is_env_true("ALLOW_NONSTANDARD_REGIONS")
+if ALLOW_NONSTANDARD_REGIONS:
+ os.environ["MOTO_ALLOW_NONEXISTENT_REGION"] = "true"
+
+# name of the main Docker container
+MAIN_CONTAINER_NAME = os.environ.get("MAIN_CONTAINER_NAME", "").strip() or "localstack-main"
+
+# the latest commit id of the repository when the docker image was created
+LOCALSTACK_BUILD_GIT_HASH = os.environ.get("LOCALSTACK_BUILD_GIT_HASH", "").strip() or None
+
+# the date on which the docker image was created
+LOCALSTACK_BUILD_DATE = os.environ.get("LOCALSTACK_BUILD_DATE", "").strip() or None
+
+# Equivalent to HTTP_PROXY, but only applicable for external connections
+OUTBOUND_HTTP_PROXY = os.environ.get("OUTBOUND_HTTP_PROXY", "")
+
+# Equivalent to HTTPS_PROXY, but only applicable for external connections
+OUTBOUND_HTTPS_PROXY = os.environ.get("OUTBOUND_HTTPS_PROXY", "")
+
+# Feature flag to enable validation of internal endpoint responses in the handler chain. For test use only.
+OPENAPI_VALIDATE_RESPONSE = is_env_true("OPENAPI_VALIDATE_RESPONSE")
+# Flag to enable the validation of the requests made to the LocalStack internal endpoints. Active by default.
+OPENAPI_VALIDATE_REQUEST = is_env_true("OPENAPI_VALIDATE_REQUEST")
+
+# whether to skip waiting for the infrastructure to shut down, or exit immediately
+FORCE_SHUTDOWN = is_env_not_false("FORCE_SHUTDOWN")
+
+# set variables no_proxy, i.e., run internal service calls directly
+no_proxy = ",".join([constants.LOCALHOST_HOSTNAME, LOCALHOST, LOCALHOST_IP, "[::1]"])
+if os.environ.get("no_proxy"):
+ os.environ["no_proxy"] += "," + no_proxy
+elif os.environ.get("NO_PROXY"):
+ os.environ["NO_PROXY"] += "," + no_proxy
+else:
+ os.environ["no_proxy"] = no_proxy
+
+# additional CLI commands, can be set by plugins
+CLI_COMMANDS = {}
+
+# determine IP of Docker bridge
+if not DOCKER_BRIDGE_IP:
+ DOCKER_BRIDGE_IP = "172.17.0.1"
+ if is_in_docker:
+ candidates = (DOCKER_BRIDGE_IP, "172.18.0.1")
+ for ip in candidates:
+ # TODO: remove from here - should not perform I/O operations in top-level config.py
+ if ping(ip):
+ DOCKER_BRIDGE_IP = ip
+ break
+
+# AWS account used to store internal resources such as Lambda archives or internal SQS queues.
+# It should not be modified by the user, or visible to them, except through a presigned url with the
+# get-function call.
+INTERNAL_RESOURCE_ACCOUNT = os.environ.get("INTERNAL_RESOURCE_ACCOUNT") or "949334387222"
+
+# TODO: remove with 4.1.0
+# Determine which implementation to use for the event rule / event filtering engine used by multiple services:
+# EventBridge, EventBridge Pipes, Lambda Event Source Mapping
+# Options: python (default) | java (deprecated since 4.0.3)
+EVENT_RULE_ENGINE = os.environ.get("EVENT_RULE_ENGINE", "python").strip()
+
+# -----
+# SERVICE-SPECIFIC CONFIGS BELOW
+# -----
+
+# port ranges for external service instances (f.e. elasticsearch clusters, opensearch clusters,...)
+EXTERNAL_SERVICE_PORTS_START = int(
+ os.environ.get("EXTERNAL_SERVICE_PORTS_START")
+ or os.environ.get("SERVICE_INSTANCES_PORTS_START")
+ or 4510
+)
+EXTERNAL_SERVICE_PORTS_END = int(
+ os.environ.get("EXTERNAL_SERVICE_PORTS_END")
+ or os.environ.get("SERVICE_INSTANCES_PORTS_END")
+ or (EXTERNAL_SERVICE_PORTS_START + 50)
+)
+
+# The default container runtime to use
+CONTAINER_RUNTIME = os.environ.get("CONTAINER_RUNTIME", "").strip() or "docker"
+
+# PUBLIC v1: -Xmx512M (example) Currently not supported in new provider but possible via custom entrypoint.
+# Allow passing custom JVM options to Java Lambdas executed in Docker.
+LAMBDA_JAVA_OPTS = os.environ.get("LAMBDA_JAVA_OPTS", "").strip()
+
+# limit after which kinesis-mock will start throwing exceptions
+KINESIS_SHARD_LIMIT = os.environ.get("KINESIS_SHARD_LIMIT", "").strip() or "100"
+KINESIS_PERSISTENCE = is_env_not_false("KINESIS_PERSISTENCE")
+
+# limit after which kinesis-mock will start throwing exceptions
+KINESIS_ON_DEMAND_STREAM_COUNT_LIMIT = (
+ os.environ.get("KINESIS_ON_DEMAND_STREAM_COUNT_LIMIT", "").strip() or "10"
+)
+
+# delay in kinesis-mock response when making changes to streams
+KINESIS_LATENCY = os.environ.get("KINESIS_LATENCY", "").strip() or "500"
+
+# Delay between data persistence (in seconds)
+KINESIS_MOCK_PERSIST_INTERVAL = os.environ.get("KINESIS_MOCK_PERSIST_INTERVAL", "").strip() or "5s"
+
+# Kinesis mock log level override when inconsistent with LS_LOG (e.g., when LS_LOG=debug)
+KINESIS_MOCK_LOG_LEVEL = os.environ.get("KINESIS_MOCK_LOG_LEVEL", "").strip()
+
+# randomly inject faults to Kinesis
+KINESIS_ERROR_PROBABILITY = float(os.environ.get("KINESIS_ERROR_PROBABILITY", "").strip() or 0.0)
+
+# randomly inject faults to DynamoDB
+DYNAMODB_ERROR_PROBABILITY = float(os.environ.get("DYNAMODB_ERROR_PROBABILITY", "").strip() or 0.0)
+DYNAMODB_READ_ERROR_PROBABILITY = float(
+ os.environ.get("DYNAMODB_READ_ERROR_PROBABILITY", "").strip() or 0.0
+)
+DYNAMODB_WRITE_ERROR_PROBABILITY = float(
+ os.environ.get("DYNAMODB_WRITE_ERROR_PROBABILITY", "").strip() or 0.0
+)
+
+# JAVA EE heap size for dynamodb
+DYNAMODB_HEAP_SIZE = os.environ.get("DYNAMODB_HEAP_SIZE", "").strip() or "256m"
+
+# single DB instance across multiple credentials and regions
+DYNAMODB_SHARE_DB = int(os.environ.get("DYNAMODB_SHARE_DB") or 0)
+
+# the port on which to expose dynamodblocal
+DYNAMODB_LOCAL_PORT = int(os.environ.get("DYNAMODB_LOCAL_PORT") or 0)
+
+# Enables the automatic removal of stale KV pairs based on TTL
+DYNAMODB_REMOVE_EXPIRED_ITEMS = is_env_true("DYNAMODB_REMOVE_EXPIRED_ITEMS")
+
+# Used to toggle PurgeInProgress exceptions when calling purge within 60 seconds
+SQS_DELAY_PURGE_RETRY = is_env_true("SQS_DELAY_PURGE_RETRY")
+
+# Used to toggle QueueDeletedRecently errors when re-creating a queue within 60 seconds of deleting it
+SQS_DELAY_RECENTLY_DELETED = is_env_true("SQS_DELAY_RECENTLY_DELETED")
+
+# Used to toggle MessageRetentionPeriod functionality in SQS queues
+SQS_ENABLE_MESSAGE_RETENTION_PERIOD = is_env_true("SQS_ENABLE_MESSAGE_RETENTION_PERIOD")
+
+# Strategy used when creating SQS queue URLs. Can be "off", "standard" (default), "domain", "path", or "dynamic"
+SQS_ENDPOINT_STRATEGY = os.environ.get("SQS_ENDPOINT_STRATEGY", "") or "standard"
+
+# Disable the check for MaxNumberOfMessage in SQS ReceiveMessage
+SQS_DISABLE_MAX_NUMBER_OF_MESSAGE_LIMIT = is_env_true("SQS_DISABLE_MAX_NUMBER_OF_MESSAGE_LIMIT")
+
+# Disable cloudwatch metrics for SQS
+SQS_DISABLE_CLOUDWATCH_METRICS = is_env_true("SQS_DISABLE_CLOUDWATCH_METRICS")
+
+# Interval for reporting "approximate" metrics to cloudwatch, default is 60 seconds
+SQS_CLOUDWATCH_METRICS_REPORT_INTERVAL = int(
+ os.environ.get("SQS_CLOUDWATCH_METRICS_REPORT_INTERVAL") or 60
+)
+
+# PUBLIC: Endpoint host under which LocalStack APIs are accessible from Lambda Docker containers.
+HOSTNAME_FROM_LAMBDA = os.environ.get("HOSTNAME_FROM_LAMBDA", "").strip()
+
+# PUBLIC: hot-reload (default v2), __local__ (default v1)
+# Magic S3 bucket name for Hot Reloading. The S3Key points to the source code on the local file system.
+BUCKET_MARKER_LOCAL = (
+ os.environ.get("BUCKET_MARKER_LOCAL", "").strip() or DEFAULT_BUCKET_MARKER_LOCAL
+)
+
+# PUBLIC: Opt-out to inject the environment variable AWS_ENDPOINT_URL for automatic configuration of AWS SDKs:
+# https://docs.aws.amazon.com/sdkref/latest/guide/feature-ss-endpoints.html
+LAMBDA_DISABLE_AWS_ENDPOINT_URL = is_env_true("LAMBDA_DISABLE_AWS_ENDPOINT_URL")
+
+# PUBLIC: bridge (Docker default)
+# Docker network driver for the Lambda and ECS containers. https://docs.docker.com/network/
+LAMBDA_DOCKER_NETWORK = os.environ.get("LAMBDA_DOCKER_NETWORK", "").strip()
+
+# PUBLIC v1: LocalStack DNS (default)
+# Custom DNS server for the container running your lambda function.
+LAMBDA_DOCKER_DNS = os.environ.get("LAMBDA_DOCKER_DNS", "").strip()
+
+# PUBLIC: -e KEY=VALUE -v host:container
+# Additional flags passed to Docker run|create commands.
+LAMBDA_DOCKER_FLAGS = os.environ.get("LAMBDA_DOCKER_FLAGS", "").strip()
+
+# PUBLIC: 0 (default)
+# Enable this flag to run cross-platform compatible lambda functions natively (i.e., Docker selects architecture) and
+# ignore the AWS architectures (i.e., x86_64, arm64) configured for the lambda function.
+LAMBDA_IGNORE_ARCHITECTURE = is_env_true("LAMBDA_IGNORE_ARCHITECTURE")
+
+# TODO: test and add to docs
+# EXPERIMENTAL: 0 (default)
+# Prebuild images before execution? Reduces cold start time, at the tradeoff of an increased time until the lambda becomes ACTIVE
+LAMBDA_PREBUILD_IMAGES = is_env_true("LAMBDA_PREBUILD_IMAGES")
+
+# PUBLIC: docker (default), kubernetes (pro)
+# Where Lambdas will be executed.
+LAMBDA_RUNTIME_EXECUTOR = os.environ.get("LAMBDA_RUNTIME_EXECUTOR", CONTAINER_RUNTIME).strip()
+
+# PUBLIC: 20 (default)
+# How many seconds Lambda will wait for the runtime environment to start up.
+LAMBDA_RUNTIME_ENVIRONMENT_TIMEOUT = int(os.environ.get("LAMBDA_RUNTIME_ENVIRONMENT_TIMEOUT") or 20)
+
+# PUBLIC: base images for Lambda (default) https://docs.aws.amazon.com/lambda/latest/dg/runtimes-images.html
+# localstack/services/lambda_/invocation/lambda_models.py:IMAGE_MAPPING
+# Customize the Docker image of Lambda runtimes, either by:
+# a) pattern with <runtime> placeholder, e.g. custom-repo/lambda-<runtime>:2022
+# b) JSON dict mapping the <runtime> to an image, e.g. {"python3.9": "custom-repo/lambda-python3.9"}
+LAMBDA_RUNTIME_IMAGE_MAPPING = os.environ.get("LAMBDA_RUNTIME_IMAGE_MAPPING", "").strip()
+
+# PUBLIC: 0 (default)
+# Whether to disable usage of deprecated runtimes
+LAMBDA_RUNTIME_VALIDATION = int(os.environ.get("LAMBDA_RUNTIME_VALIDATION") or 0)
+
+# PUBLIC: 1 (default)
+# Whether to remove any Lambda Docker containers.
+LAMBDA_REMOVE_CONTAINERS = (
+ os.environ.get("LAMBDA_REMOVE_CONTAINERS", "").lower().strip() not in FALSE_STRINGS
+)
+
+# PUBLIC: 600000 (default 10min)
+# Time in milliseconds until lambda shuts down the execution environment after the last invocation has been processed.
+# Set to 0 to immediately shut down the execution environment after an invocation.
+LAMBDA_KEEPALIVE_MS = int(os.environ.get("LAMBDA_KEEPALIVE_MS", 600_000))
+
+# PUBLIC: 1000 (default)
+# The maximum number of events that functions can process simultaneously in the current Region.
+# See AWS service quotas: https://docs.aws.amazon.com/general/latest/gr/lambda-service.html
+# Concurrency limits. Like on AWS these apply per account and region.
+LAMBDA_LIMITS_CONCURRENT_EXECUTIONS = int(
+ os.environ.get("LAMBDA_LIMITS_CONCURRENT_EXECUTIONS", 1_000)
+)
+# SEMI-PUBLIC: not actively communicated
+# per account/region: at least this much concurrency must remain unreserved.
+LAMBDA_LIMITS_MINIMUM_UNRESERVED_CONCURRENCY = int(
+ os.environ.get("LAMBDA_LIMITS_MINIMUM_UNRESERVED_CONCURRENCY", 100)
+)
+# SEMI-PUBLIC: not actively communicated
+LAMBDA_LIMITS_TOTAL_CODE_SIZE = int(os.environ.get("LAMBDA_LIMITS_TOTAL_CODE_SIZE", 80_530_636_800))
+# PUBLIC: documented after AWS changed validation around 2023-11
+LAMBDA_LIMITS_CODE_SIZE_ZIPPED = int(os.environ.get("LAMBDA_LIMITS_CODE_SIZE_ZIPPED", 52_428_800))
+# SEMI-PUBLIC: not actively communicated
+LAMBDA_LIMITS_CODE_SIZE_UNZIPPED = int(
+ os.environ.get("LAMBDA_LIMITS_CODE_SIZE_UNZIPPED", 262_144_000)
+)
+# PUBLIC: documented upon customer request
+LAMBDA_LIMITS_CREATE_FUNCTION_REQUEST_SIZE = int(
+ os.environ.get("LAMBDA_LIMITS_CREATE_FUNCTION_REQUEST_SIZE", 70_167_211)
+)
+# SEMI-PUBLIC: not actively communicated
+LAMBDA_LIMITS_MAX_FUNCTION_ENVVAR_SIZE_BYTES = int(
+ os.environ.get("LAMBDA_LIMITS_MAX_FUNCTION_ENVVAR_SIZE_BYTES", 4 * 1024)
+)
+# SEMI-PUBLIC: not actively communicated
+LAMBDA_LIMITS_MAX_FUNCTION_PAYLOAD_SIZE_BYTES = int(
+ os.environ.get(
+ "LAMBDA_LIMITS_MAX_FUNCTION_PAYLOAD_SIZE_BYTES", 6 * 1024 * 1024 + 100
+ ) # the 100 comes from the init defaults
+)
+
+# DEV: 0 (default unless in host mode on macOS) For LS developers only. Only applies to Docker mode.
+# Whether to explicitly expose a free TCP port in lambda containers when invoking functions in host mode for
+# systems that cannot reach the container via its IPv4. For example, macOS cannot reach Docker containers:
+# https://docs.docker.com/desktop/networking/#i-cannot-ping-my-containers
+LAMBDA_DEV_PORT_EXPOSE = (
+ # Enable this dev flag by default on macOS in host mode (i.e., non-Docker environment)
+ is_env_not_false("LAMBDA_DEV_PORT_EXPOSE")
+ if not is_in_docker and is_in_macos
+ else is_env_true("LAMBDA_DEV_PORT_EXPOSE")
+)
+
+# DEV: only applies to the new lambda provider. All LAMBDA_INIT_* configurations are for LS developers only.
+# There are NO stability guarantees, and they may break at any time.
+
+# DEV: Release version of https://github.com/localstack/lambda-runtime-init overriding the current default
+LAMBDA_INIT_RELEASE_VERSION = os.environ.get("LAMBDA_INIT_RELEASE_VERSION")
+# DEV: 0 (default) Enable for mounting of RIE init binary and delve debugger
+LAMBDA_INIT_DEBUG = is_env_true("LAMBDA_INIT_DEBUG")
+# DEV: path to RIE init binary (e.g., var/rapid/init)
+LAMBDA_INIT_BIN_PATH = os.environ.get("LAMBDA_INIT_BIN_PATH")
+# DEV: path to entrypoint script (e.g., var/rapid/entrypoint.sh)
+LAMBDA_INIT_BOOTSTRAP_PATH = os.environ.get("LAMBDA_INIT_BOOTSTRAP_PATH")
+# DEV: path to delve debugger (e.g., var/rapid/dlv)
+LAMBDA_INIT_DELVE_PATH = os.environ.get("LAMBDA_INIT_DELVE_PATH")
+# DEV: Go Delve debug port
+LAMBDA_INIT_DELVE_PORT = int(os.environ.get("LAMBDA_INIT_DELVE_PORT") or 40000)
+# DEV: Time to wait after every invoke as a workaround to fix a race condition in persistence tests
+LAMBDA_INIT_POST_INVOKE_WAIT_MS = os.environ.get("LAMBDA_INIT_POST_INVOKE_WAIT_MS")
+# DEV: sbx_user1051 (default when not provided) Alternative system user or empty string to skip dropping privileges.
+LAMBDA_INIT_USER = os.environ.get("LAMBDA_INIT_USER")
+
+# Default port for the Step Functions backend
+LOCAL_PORT_STEPFUNCTIONS = int(os.environ.get("LOCAL_PORT_STEPFUNCTIONS") or 8083)
+# Stepfunctions lambda endpoint override
+STEPFUNCTIONS_LAMBDA_ENDPOINT = os.environ.get("STEPFUNCTIONS_LAMBDA_ENDPOINT", "").strip()
+
+# path prefix for windows volume mounting
+WINDOWS_DOCKER_MOUNT_PREFIX = os.environ.get("WINDOWS_DOCKER_MOUNT_PREFIX", "/host_mnt")
+
+# whether to skip S3 presign URL signature validation (TODO: currently enabled, until all issues are resolved)
+S3_SKIP_SIGNATURE_VALIDATION = is_env_not_false("S3_SKIP_SIGNATURE_VALIDATION")
+# whether to skip S3 validation of provided KMS key
+S3_SKIP_KMS_KEY_VALIDATION = is_env_not_false("S3_SKIP_KMS_KEY_VALIDATION")
+
+# PUBLIC: 2000 (default)
+# Allows increasing the default char limit for truncation of lambda log lines when printed in the console.
+# This does not affect the logs processing in CloudWatch.
+LAMBDA_TRUNCATE_STDOUT = int(os.getenv("LAMBDA_TRUNCATE_STDOUT") or 2000)
+
+# INTERNAL: 60 (default matching AWS) only applies to new lambda provider
+# Base delay in seconds for async retries. Further retries use: NUM_ATTEMPTS * LAMBDA_RETRY_BASE_DELAY_SECONDS
+# 300 (5min) is the maximum because NUM_ATTEMPTS can be at most 3 and SQS has a message timer limit of 15 min.
+# For example:
+# 1x LAMBDA_RETRY_BASE_DELAY_SECONDS: delay between initial invocation and first retry
+# 2x LAMBDA_RETRY_BASE_DELAY_SECONDS: delay between the first retry and the second retry
+# 3x LAMBDA_RETRY_BASE_DELAY_SECONDS: delay between the second retry and the third retry
+LAMBDA_RETRY_BASE_DELAY_SECONDS = int(os.getenv("LAMBDA_RETRY_BASE_DELAY") or 60)
+
+# PUBLIC: 0 (default)
+# Set to 1 to create lambda functions synchronously (not recommended).
+# Whether Lambda.CreateFunction will block until the function is in a terminal state (Active or Failed).
+# This technically breaks behavior parity but is provided as a simplification over the default AWS behavior and
+# to match the behavior of the old lambda provider.
+LAMBDA_SYNCHRONOUS_CREATE = is_env_true("LAMBDA_SYNCHRONOUS_CREATE")
+
+# URL to a custom OpenSearch/Elasticsearch backend cluster. If this is set to a valid URL, then localstack will not
+# create OpenSearch/Elasticsearch cluster instances, but instead forward all domains to the given backend.
+OPENSEARCH_CUSTOM_BACKEND = os.environ.get("OPENSEARCH_CUSTOM_BACKEND", "").strip()
+
+# Strategy used when creating OpenSearch/Elasticsearch domain endpoints routed through the edge proxy
+# valid values: domain | path | port (off)
+OPENSEARCH_ENDPOINT_STRATEGY = (
+ os.environ.get("OPENSEARCH_ENDPOINT_STRATEGY", "").strip() or "domain"
+)
+if OPENSEARCH_ENDPOINT_STRATEGY == "off":
+ OPENSEARCH_ENDPOINT_STRATEGY = "port"
+
+# Whether to start one cluster per domain (default), or multiplex OpenSearch domains onto a single cluster
+OPENSEARCH_MULTI_CLUSTER = is_env_not_false("OPENSEARCH_MULTI_CLUSTER")
+
+# Whether to really publish to GCM while using SNS Platform Application (needs credentials)
+LEGACY_SNS_GCM_PUBLISHING = is_env_true("LEGACY_SNS_GCM_PUBLISHING")
+
+SNS_SES_SENDER_ADDRESS = os.environ.get("SNS_SES_SENDER_ADDRESS", "").strip()
+
+SNS_CERT_URL_HOST = os.environ.get("SNS_CERT_URL_HOST", "").strip()
+
+# Whether the Next Gen APIGW invocation logic is enabled (on by default)
+APIGW_NEXT_GEN_PROVIDER = os.environ.get("PROVIDER_OVERRIDE_APIGATEWAY", "") in ("next_gen", "")
+
+# Whether the DynamoDBStreams native provider is enabled
+DDB_STREAMS_PROVIDER_V2 = os.environ.get("PROVIDER_OVERRIDE_DYNAMODBSTREAMS", "") == "v2"
+_override_dynamodb_v2 = os.environ.get("PROVIDER_OVERRIDE_DYNAMODB", "")
+if DDB_STREAMS_PROVIDER_V2:
+ # in order to not have conflicts between the 2 implementations, as they are tightly coupled, we need to set DDB
+ # to be v2 as well
+ if not _override_dynamodb_v2:
+ os.environ["PROVIDER_OVERRIDE_DYNAMODB"] = "v2"
+elif _override_dynamodb_v2 == "v2":
+ os.environ["PROVIDER_OVERRIDE_DYNAMODBSTREAMS"] = "v2"
+ DDB_STREAMS_PROVIDER_V2 = True
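+# Illustrative effect of the coupling above: setting only PROVIDER_OVERRIDE_DYNAMODBSTREAMS=v2 also forces
+# PROVIDER_OVERRIDE_DYNAMODB=v2, and setting only PROVIDER_OVERRIDE_DYNAMODB=v2 likewise enables the v2
+# DynamoDBStreams provider.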
+
+# TODO remove fallback to LAMBDA_DOCKER_NETWORK with next minor version
+MAIN_DOCKER_NETWORK = os.environ.get("MAIN_DOCKER_NETWORK", "") or LAMBDA_DOCKER_NETWORK
+
+# Whether to return and parse access key ids starting with an "A", like on AWS
+PARITY_AWS_ACCESS_KEY_ID = is_env_true("PARITY_AWS_ACCESS_KEY_ID")
+
+# Show exceptions for CloudFormation deploy errors
+CFN_VERBOSE_ERRORS = is_env_true("CFN_VERBOSE_ERRORS")
+
+# The CFN_STRING_REPLACEMENT_DENY_LIST env variable is a comma-separated list of strings that must not be
+# replaced in CloudFormation templates (e.g. AWS URLs that LocalStack usually edits to point to itself when found
+# in a CFN template). If the env variable is set, it is parsed into a list of strings.
+CFN_STRING_REPLACEMENT_DENY_LIST = [
+ x for x in os.environ.get("CFN_STRING_REPLACEMENT_DENY_LIST", "").split(",") if x
+]
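+# Illustrative example (hypothetical values): CFN_STRING_REPLACEMENT_DENY_LIST="foo.amazonaws.com,bar.example.com"
+# is parsed into ["foo.amazonaws.com", "bar.example.com"]; an unset or empty variable yields [].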
+
+# Set the timeout to deploy each individual CloudFormation resource
+CFN_PER_RESOURCE_TIMEOUT = int(os.environ.get("CFN_PER_RESOURCE_TIMEOUT") or 300)
+
+# How LocalStack reacts when encountering unsupported resource types.
+# By default, unsupported resource types are ignored.
+# EXPERIMENTAL
+CFN_IGNORE_UNSUPPORTED_RESOURCE_TYPES = is_env_not_false("CFN_IGNORE_UNSUPPORTED_RESOURCE_TYPES")
+
+# bind address of local DNS server
+DNS_ADDRESS = os.environ.get("DNS_ADDRESS") or "0.0.0.0"
+# port of the local DNS server
+DNS_PORT = int(os.environ.get("DNS_PORT", "53"))
+
+# Comma-separated list of regex patterns for DNS names that should be resolved upstream.
+# Any DNS name matching one of these patterns resolves to the real (upstream) DNS entry,
+# rather than the local one.
+DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM = (
+ os.environ.get("DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM") or ""
+).strip()
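+# Illustrative example (hypothetical patterns): DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM=".*\.example\.com,api\.acme\.io"
+# sends queries for names matching either pattern to the upstream resolver instead of answering them locally.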
+DNS_LOCAL_NAME_PATTERNS = (os.environ.get("DNS_LOCAL_NAME_PATTERNS") or "").strip() # deprecated
+
+# IP address that AWS endpoints should resolve to in our local DNS server. By default,
+# hostnames resolve to 127.0.0.1, which makes it possible to use the LocalStack APIs transparently
+# from the host machine. If your code is running in Docker, this should be configured
+# to resolve to the Docker bridge network address, e.g., DNS_RESOLVE_IP=172.17.0.1
+DNS_RESOLVE_IP = os.environ.get("DNS_RESOLVE_IP") or LOCALHOST_IP
+
+# fallback DNS server to send upstream requests to
+DNS_SERVER = os.environ.get("DNS_SERVER")
+DNS_VERIFICATION_DOMAIN = os.environ.get("DNS_VERIFICATION_DOMAIN") or "localstack.cloud"
+
+
+def use_custom_dns():
+ return str(DNS_ADDRESS) not in FALSE_STRINGS
+
+
+# s3 virtual host name
+S3_VIRTUAL_HOSTNAME = "s3.%s" % LOCALSTACK_HOST.host
+S3_STATIC_WEBSITE_HOSTNAME = "s3-website.%s" % LOCALSTACK_HOST.host
+
+BOTO_WAITER_DELAY = int(os.environ.get("BOTO_WAITER_DELAY") or "1")
+BOTO_WAITER_MAX_ATTEMPTS = int(os.environ.get("BOTO_WAITER_MAX_ATTEMPTS") or "120")
+DISABLE_CUSTOM_BOTO_WAITER_CONFIG = is_env_true("DISABLE_CUSTOM_BOTO_WAITER_CONFIG")
+
+# defaults to false
+# if `DISABLE_BOTO_RETRIES=1` is set, all our created boto clients will have retries disabled
+DISABLE_BOTO_RETRIES = is_env_true("DISABLE_BOTO_RETRIES")
+
+DISTRIBUTED_MODE = is_env_true("DISTRIBUTED_MODE")
+
+# This flag enables `connect_to` to be in-memory only and not do networking calls
+IN_MEMORY_CLIENT = is_env_true("IN_MEMORY_CLIENT")
+
+# List of environment variable names used for configuration that are passed from the host into the LocalStack container.
+# => Synchronize this list with the above and the configuration docs:
+# https://docs.localstack.cloud/references/configuration/
+# => Sort this list alphabetically
+# => Add deprecated environment variables to deprecations.py and add a comment in this list
+# => Move removed legacy variables to the section grouped by release (still relevant for deprecation warnings)
+# => Do *not* include any internal developer configurations that apply to host-mode only in this list.
+CONFIG_ENV_VARS = [
+ "ALLOW_NONSTANDARD_REGIONS",
+ "BOTO_WAITER_DELAY",
+ "BOTO_WAITER_MAX_ATTEMPTS",
+ "BUCKET_MARKER_LOCAL",
+ "CFN_IGNORE_UNSUPPORTED_RESOURCE_TYPES",
+ "CFN_PER_RESOURCE_TIMEOUT",
+ "CFN_STRING_REPLACEMENT_DENY_LIST",
+ "CFN_VERBOSE_ERRORS",
+ "CI",
+ "CONTAINER_RUNTIME",
+ "CUSTOM_SSL_CERT_PATH",
+ "DEBUG",
+ "DEBUG_HANDLER_CHAIN",
+ "DEVELOP",
+ "DEVELOP_PORT",
+ "DISABLE_BOTO_RETRIES",
+ "DISABLE_CORS_CHECKS",
+ "DISABLE_CORS_HEADERS",
+ "DISABLE_CUSTOM_BOTO_WAITER_CONFIG",
+ "DISABLE_CUSTOM_CORS_APIGATEWAY",
+ "DISABLE_CUSTOM_CORS_S3",
+ "DISABLE_EVENTS",
+ "DISTRIBUTED_MODE",
+ "DNS_ADDRESS",
+ "DNS_PORT",
+ "DNS_LOCAL_NAME_PATTERNS",
+ "DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM",
+ "DNS_RESOLVE_IP",
+ "DNS_SERVER",
+ "DNS_VERIFICATION_DOMAIN",
+ "DOCKER_BRIDGE_IP",
+ "DOCKER_SDK_DEFAULT_TIMEOUT_SECONDS",
+ "DYNAMODB_ERROR_PROBABILITY",
+ "DYNAMODB_HEAP_SIZE",
+ "DYNAMODB_IN_MEMORY",
+ "DYNAMODB_LOCAL_PORT",
+ "DYNAMODB_SHARE_DB",
+ "DYNAMODB_READ_ERROR_PROBABILITY",
+ "DYNAMODB_REMOVE_EXPIRED_ITEMS",
+ "DYNAMODB_WRITE_ERROR_PROBABILITY",
+ "EAGER_SERVICE_LOADING",
+ "ENABLE_CONFIG_UPDATES",
+ "EVENT_RULE_ENGINE",
+ "EXTRA_CORS_ALLOWED_HEADERS",
+ "EXTRA_CORS_ALLOWED_ORIGINS",
+ "EXTRA_CORS_EXPOSE_HEADERS",
+ "GATEWAY_LISTEN",
+ "GATEWAY_SERVER",
+ "GATEWAY_WORKER_THREAD_COUNT",
+ "HOSTNAME",
+ "HOSTNAME_FROM_LAMBDA",
+ "IN_MEMORY_CLIENT",
+ "KINESIS_ERROR_PROBABILITY",
+ "KINESIS_MOCK_PERSIST_INTERVAL",
+ "KINESIS_MOCK_LOG_LEVEL",
+ "KINESIS_ON_DEMAND_STREAM_COUNT_LIMIT",
+ "KINESIS_PERSISTENCE",
+ "LAMBDA_DEBUG_MODE",
+ "LAMBDA_DEBUG_MODE_CONFIG",
+ "LAMBDA_DISABLE_AWS_ENDPOINT_URL",
+ "LAMBDA_DOCKER_DNS",
+ "LAMBDA_DOCKER_FLAGS",
+ "LAMBDA_DOCKER_NETWORK",
+ "LAMBDA_EVENTS_INTERNAL_SQS",
+ "LAMBDA_EVENT_SOURCE_MAPPING",
+ "LAMBDA_IGNORE_ARCHITECTURE",
+ "LAMBDA_INIT_DEBUG",
+ "LAMBDA_INIT_BIN_PATH",
+ "LAMBDA_INIT_BOOTSTRAP_PATH",
+ "LAMBDA_INIT_DELVE_PATH",
+ "LAMBDA_INIT_DELVE_PORT",
+ "LAMBDA_INIT_POST_INVOKE_WAIT_MS",
+ "LAMBDA_INIT_USER",
+ "LAMBDA_INIT_RELEASE_VERSION",
+ "LAMBDA_KEEPALIVE_MS",
+ "LAMBDA_LIMITS_CONCURRENT_EXECUTIONS",
+ "LAMBDA_LIMITS_MINIMUM_UNRESERVED_CONCURRENCY",
+ "LAMBDA_LIMITS_TOTAL_CODE_SIZE",
+ "LAMBDA_LIMITS_CODE_SIZE_ZIPPED",
+ "LAMBDA_LIMITS_CODE_SIZE_UNZIPPED",
+ "LAMBDA_LIMITS_CREATE_FUNCTION_REQUEST_SIZE",
+ "LAMBDA_LIMITS_MAX_FUNCTION_ENVVAR_SIZE_BYTES",
+ "LAMBDA_LIMITS_MAX_FUNCTION_PAYLOAD_SIZE_BYTES",
+ "LAMBDA_PREBUILD_IMAGES",
+ "LAMBDA_RUNTIME_IMAGE_MAPPING",
+ "LAMBDA_REMOVE_CONTAINERS",
+ "LAMBDA_RETRY_BASE_DELAY_SECONDS",
+ "LAMBDA_RUNTIME_EXECUTOR",
+ "LAMBDA_RUNTIME_ENVIRONMENT_TIMEOUT",
+ "LAMBDA_RUNTIME_VALIDATION",
+ "LAMBDA_SYNCHRONOUS_CREATE",
+ "LAMBDA_SQS_EVENT_SOURCE_MAPPING_INTERVAL",
+ "LAMBDA_TRUNCATE_STDOUT",
+ "LEGACY_DOCKER_CLIENT",
+ "LEGACY_SNS_GCM_PUBLISHING",
+ "LOCALSTACK_API_KEY",
+ "LOCALSTACK_AUTH_TOKEN",
+ "LOCALSTACK_HOST",
+ "LOG_LICENSE_ISSUES",
+ "LS_LOG",
+ "MAIN_CONTAINER_NAME",
+ "MAIN_DOCKER_NETWORK",
+ "OPENAPI_VALIDATE_REQUEST",
+ "OPENAPI_VALIDATE_RESPONSE",
+ "OPENSEARCH_ENDPOINT_STRATEGY",
+ "OUTBOUND_HTTP_PROXY",
+ "OUTBOUND_HTTPS_PROXY",
+ "PARITY_AWS_ACCESS_KEY_ID",
+ "PERSISTENCE",
+ "PORTS_CHECK_DOCKER_IMAGE",
+ "REQUESTS_CA_BUNDLE",
+ "REMOVE_SSL_CERT",
+ "S3_SKIP_SIGNATURE_VALIDATION",
+ "S3_SKIP_KMS_KEY_VALIDATION",
+ "SERVICES",
+ "SKIP_INFRA_DOWNLOADS",
+ "SKIP_SSL_CERT_DOWNLOAD",
+ "SNAPSHOT_LOAD_STRATEGY",
+ "SNAPSHOT_SAVE_STRATEGY",
+ "SNAPSHOT_FLUSH_INTERVAL",
+ "SNS_SES_SENDER_ADDRESS",
+ "SQS_DELAY_PURGE_RETRY",
+ "SQS_DELAY_RECENTLY_DELETED",
+ "SQS_ENABLE_MESSAGE_RETENTION_PERIOD",
+ "SQS_ENDPOINT_STRATEGY",
+ "SQS_DISABLE_CLOUDWATCH_METRICS",
+ "SQS_CLOUDWATCH_METRICS_REPORT_INTERVAL",
+ "STEPFUNCTIONS_LAMBDA_ENDPOINT",
+ "STRICT_SERVICE_LOADING",
+ "TF_COMPAT_MODE",
+ "USE_SSL",
+ "WAIT_FOR_DEBUGGER",
+ "WINDOWS_DOCKER_MOUNT_PREFIX",
+ # Removed legacy variables in 2.0.0
+ # DATA_DIR => do *not* include in this list, as it is treated separately. # deprecated since 1.0.0
+ "LEGACY_DIRECTORIES", # deprecated since 1.0.0
+ "SYNCHRONOUS_API_GATEWAY_EVENTS", # deprecated since 1.3.0
+ "SYNCHRONOUS_DYNAMODB_EVENTS", # deprecated since 1.3.0
+ "SYNCHRONOUS_SNS_EVENTS", # deprecated since 1.3.0
+ "SYNCHRONOUS_SQS_EVENTS", # deprecated since 1.3.0
+ # Removed legacy variables in 3.0.0
+ "DEFAULT_REGION", # deprecated since 0.12.7
+ "EDGE_BIND_HOST", # deprecated since 2.0.0
+ "EDGE_FORWARD_URL", # deprecated since 1.4.0
+ "EDGE_PORT", # deprecated since 2.0.0
+ "EDGE_PORT_HTTP", # deprecated since 2.0.0
+ "ES_CUSTOM_BACKEND", # deprecated since 0.14.0
+ "ES_ENDPOINT_STRATEGY", # deprecated since 0.14.0
+ "ES_MULTI_CLUSTER", # deprecated since 0.14.0
+ "HOSTNAME_EXTERNAL", # deprecated since 2.0.0
+ "KINESIS_INITIALIZE_STREAMS", # deprecated since 1.4.0
+ "KINESIS_PROVIDER", # deprecated since 1.3.0
+ "KMS_PROVIDER", # deprecated since 1.4.0
+ "LAMBDA_XRAY_INIT", # deprecated since 2.0.0
+ "LAMBDA_CODE_EXTRACT_TIME", # deprecated since 2.0.0
+ "LAMBDA_CONTAINER_REGISTRY", # deprecated since 2.0.0
+ "LAMBDA_EXECUTOR", # deprecated since 2.0.0
+ "LAMBDA_FALLBACK_URL", # deprecated since 2.0.0
+ "LAMBDA_FORWARD_URL", # deprecated since 2.0.0
+ "LAMBDA_JAVA_OPTS", # currently only supported in old Lambda provider but not officially deprecated
+ "LAMBDA_REMOTE_DOCKER", # deprecated since 2.0.0
+ "LAMBDA_STAY_OPEN_MODE", # deprecated since 2.0.0
+ "LEGACY_EDGE_PROXY", # deprecated since 1.0.0
+ "LOCALSTACK_HOSTNAME", # deprecated since 2.0.0
+ "SQS_PORT_EXTERNAL", # deprecated only in docs since 2022-07-13
+ "SYNCHRONOUS_KINESIS_EVENTS", # deprecated since 1.3.0
+ "USE_SINGLE_REGION", # deprecated since 0.12.7
+ "MOCK_UNIMPLEMENTED", # deprecated since 1.3.0
+]
+
+
+def is_local_test_mode() -> bool:
+ """Returns True if we are running in the context of our local integration tests."""
+ return is_env_true(ENV_INTERNAL_TEST_RUN)
+
+
+def is_collect_metrics_mode() -> bool:
+ """Returns True if metric collection is enabled."""
+ return is_env_true(ENV_INTERNAL_TEST_COLLECT_METRIC)
+
+
+def collect_config_items() -> List[Tuple[str, Any]]:
+ """Returns a list of key-value tuples of LocalStack configuration values."""
+ none = object() # sentinel object
+
+ # collect which keys to print
+ keys = []
+ keys.extend(CONFIG_ENV_VARS)
+ keys.append("DATA_DIR")
+ keys.sort()
+
+ values = globals()
+
+ result = []
+ for k in keys:
+ v = values.get(k, none)
+ if v is none:
+ continue
+ result.append((k, v))
+ result.sort()
+ return result
+
+
+def populate_config_env_var_names():
+ global CONFIG_ENV_VARS
+
+ CONFIG_ENV_VARS += [
+ key
+ for key in [key.upper() for key in os.environ]
+ if (key.startswith("LOCALSTACK_") or key.startswith("PROVIDER_OVERRIDE_"))
+ # explicitly exclude LOCALSTACK_CLI (it's prefixed with "LOCALSTACK_",
+ # but is only used in the CLI and should not be forwarded to the container)
+ and key != "LOCALSTACK_CLI"
+ ]
+
+ # create variable aliases prefixed with LOCALSTACK_ (except LOCALSTACK_HOST)
+ CONFIG_ENV_VARS += [
+ "LOCALSTACK_" + v for v in CONFIG_ENV_VARS if not v.startswith("LOCALSTACK_")
+ ]
+
+ CONFIG_ENV_VARS = list(set(CONFIG_ENV_VARS))
+
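+# Illustrative sketch of the effect (hypothetical variables): with LOCALSTACK_FOO=1 and PROVIDER_OVERRIDE_SQS=v2 set,
+# both names are appended to CONFIG_ENV_VARS, and LOCALSTACK_-prefixed aliases (e.g. LOCALSTACK_DEBUG) are generated
+# for all entries that do not already carry the prefix; duplicates are removed afterwards.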
+
+# populate env var names to be passed to the container
+populate_config_env_var_names()
+
+
+# helpers to build urls
+def get_protocol() -> str:
+ return "https" if USE_SSL else "http"
+
+
+def external_service_url(
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ protocol: Optional[str] = None,
+ subdomains: Optional[str] = None,
+) -> str:
+ """Returns a service URL (e.g., SQS queue URL) to an external client (e.g., boto3) potentially running on another
+ machine than LocalStack. The configurations LOCALSTACK_HOST and USE_SSL can customize these returned URLs.
+ The optional parameters can be used to customize the defaults.
+ Examples with default configuration:
+ * external_service_url() == http://localhost.localstack.cloud:4566
+ * external_service_url(subdomains="s3") == http://s3.localhost.localstack.cloud:4566
+ """
+ protocol = protocol or get_protocol()
+ subdomains = f"{subdomains}." if subdomains else ""
+ host = host or LOCALSTACK_HOST.host
+ port = port or LOCALSTACK_HOST.port
+ return f"{protocol}://{subdomains}{host}:{port}"
+
+
+def internal_service_url(
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ protocol: Optional[str] = None,
+ subdomains: Optional[str] = None,
+) -> str:
+ """Returns a service URL for internal use within LocalStack (i.e., same host).
+ The configuration USE_SSL can customize these returned URLs but LOCALSTACK_HOST has no effect.
+ The optional parameters can be used to customize the defaults.
+ Examples with default configuration:
+ * internal_service_url() == http://localhost:4566
+ * internal_service_url(port=8080) == http://localhost:8080
+ """
+ protocol = protocol or get_protocol()
+ subdomains = f"{subdomains}." if subdomains else ""
+ host = host or LOCALHOST
+ port = port or GATEWAY_LISTEN[0].port
+ return f"{protocol}://{subdomains}{host}:{port}"
+
+
+# DEPRECATED: old helpers for building URLs
+
+
+def service_url(service_key, host=None, port=None):
+ """@deprecated: Use `internal_service_url()` instead. We assume that most usages are internal
+ but really need to check and update each usage accordingly.
+ """
+ warnings.warn(
+ """@deprecated: Use `internal_service_url()` instead. We assume that most usages are
+ internal but really need to check and update each usage accordingly.""",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return internal_service_url(host=host, port=port)
+
+
+def service_port(service_key: str, external: bool = False) -> int:
+ """@deprecated: Use `localstack_host().port` for external and `GATEWAY_LISTEN[0].port` for
+ internal use."""
+ warnings.warn(
+ "Deprecated: use `localstack_host().port` for external and `GATEWAY_LISTEN[0].port` for "
+ "internal use.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if external:
+ return LOCALSTACK_HOST.port
+ return GATEWAY_LISTEN[0].port
+
+
+def get_edge_port_http():
+ """@deprecated: Use `localstack_host().port` for external and `GATEWAY_LISTEN[0].port` for
+ internal use. This function is not needed anymore because we don't separate between HTTP
+ and HTTPS ports anymore since LocalStack listens on both ports."""
+ warnings.warn(
+ """@deprecated: Use `localstack_host().port` for external and `GATEWAY_LISTEN[0].port`
+ for internal use. This function is also not needed anymore because we don't separate
+ between HTTP and HTTPS ports anymore since LocalStack listens on both.""",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return GATEWAY_LISTEN[0].port
+
+
+def get_edge_url(localstack_hostname=None, protocol=None):
+ """@deprecated: Use `internal_service_url()` instead.
+ We assume that most usages are internal but really need to check and update each usage accordingly.
+ """
+ warnings.warn(
+ """@deprecated: Use `internal_service_url()` instead.
+ We assume that most usages are internal but really need to check and update each usage accordingly.
+ """,
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return internal_service_url(host=localstack_hostname, protocol=protocol)
+
+
+class ServiceProviderConfig(Mapping[str, str]):
+ _provider_config: Dict[str, str]
+ default_value: str
+ override_prefix: str = "PROVIDER_OVERRIDE_"
+
+ def __init__(self, default_value: str):
+ self._provider_config = {}
+ self.default_value = default_value
+
+ def load_from_environment(self, env: Mapping[str, str] = None):
+ if env is None:
+ env = os.environ
+ for key, value in env.items():
+ if key.startswith(self.override_prefix) and value:
+ self.set_provider(key[len(self.override_prefix) :].lower().replace("_", "-"), value)
+
+ def get_provider(self, service: str) -> str:
+ return self._provider_config.get(service, self.default_value)
+
+ def set_provider_if_not_exists(self, service: str, provider: str) -> None:
+ if service not in self._provider_config:
+ self._provider_config[service] = provider
+
+ def set_provider(self, service: str, provider: str):
+ self._provider_config[service] = provider
+
+ def bulk_set_provider_if_not_exists(self, services: List[str], provider: str):
+ for service in services:
+ self.set_provider_if_not_exists(service, provider)
+
+ def __getitem__(self, item):
+ return self.get_provider(item)
+
+ def __setitem__(self, key, value):
+ self.set_provider(key, value)
+
+ def __len__(self):
+ return len(self._provider_config)
+
+ def __iter__(self):
+ return self._provider_config.__iter__()
+
+
+SERVICE_PROVIDER_CONFIG = ServiceProviderConfig("default")
+
+SERVICE_PROVIDER_CONFIG.load_from_environment()
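+# Illustrative sketch (hypothetical variable): PROVIDER_OVERRIDE_SOME_SERVICE=v2 is picked up by
+# load_from_environment() as provider "v2" for service "some-service" (underscores become dashes);
+# services without an override fall back to the default provider "default".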
+
+
+def init_directories() -> Directories:
+ if is_in_docker:
+ return Directories.for_container()
+ else:
+ if is_env_true("LOCALSTACK_CLI"):
+ return Directories.for_cli()
+
+ return Directories.for_host()
+
+
+# initialize directories
+dirs: Directories
+dirs = init_directories()
diff --git a/localstack-core/localstack/constants.py b/localstack-core/localstack/constants.py
new file mode 100644
index 0000000000000..80b61954935bf
--- /dev/null
+++ b/localstack-core/localstack/constants.py
@@ -0,0 +1,183 @@
+import os
+
+from localstack.version import __version__
+
+VERSION = __version__
+
+# HTTP headers used to forward proxy request URLs
+HEADER_LOCALSTACK_EDGE_URL = "x-localstack-edge"
+HEADER_LOCALSTACK_REQUEST_URL = "x-localstack-request-url"
+# custom LocalStack authorization header, only used in ext
+HEADER_LOCALSTACK_AUTHORIZATION = "x-localstack-authorization"
+HEADER_LOCALSTACK_TARGET = "x-localstack-target"
+HEADER_AMZN_ERROR_TYPE = "X-Amzn-Errortype"
+
+# backend service ports, for services that are behind a proxy (counting down from 4566)
+DEFAULT_PORT_EDGE = 4566
+
+# host name for localhost
+LOCALHOST = "localhost"
+LOCALHOST_IP = "127.0.0.1"
+LOCALHOST_HOSTNAME = "localhost.localstack.cloud"
+
+# User-agent string used in outgoing HTTP requests made by LocalStack
+USER_AGENT_STRING = f"localstack/{VERSION}"
+
+# version of the Maven dependency with Java utility code
+LOCALSTACK_MAVEN_VERSION = "0.2.21"
+MAVEN_REPO_URL = "https://repo1.maven.org/maven2"
+
+# URL of localstack's artifacts repository on GitHub
+ARTIFACTS_REPO = "https://github.com/localstack/localstack-artifacts"
+
+# Artifacts endpoint
+ASSETS_ENDPOINT = "https://assets.localstack.cloud"
+
+# host to bind to when starting the services
+BIND_HOST = "0.0.0.0"
+
+# root code folder
+MODULE_MAIN_PATH = os.path.dirname(os.path.realpath(__file__))
+# TODO rename to "ROOT_FOLDER"!
+LOCALSTACK_ROOT_FOLDER = os.path.realpath(os.path.join(MODULE_MAIN_PATH, ".."))
+
+# virtualenv folder
+LOCALSTACK_VENV_FOLDER = os.environ.get("VIRTUAL_ENV")
+if not LOCALSTACK_VENV_FOLDER:
+ # fallback to the previous logic
+ LOCALSTACK_VENV_FOLDER = os.path.join(LOCALSTACK_ROOT_FOLDER, ".venv")
+ if not os.path.isdir(LOCALSTACK_VENV_FOLDER):
+ # assuming this package lives here: /lib/pythonX.X/site-packages/localstack/
+ LOCALSTACK_VENV_FOLDER = os.path.realpath(
+ os.path.join(LOCALSTACK_ROOT_FOLDER, "..", "..", "..")
+ )
+
+# default volume directory containing shared data
+DEFAULT_VOLUME_DIR = "/var/lib/localstack"
+
+# API Gateway path to indicate a user request sent to the gateway
+PATH_USER_REQUEST = "_user_request_"
+
+# name of LocalStack Docker image
+DOCKER_IMAGE_NAME = "localstack/localstack"
+DOCKER_IMAGE_NAME_PRO = "localstack/localstack-pro"
+DOCKER_IMAGE_NAME_FULL = "localstack/localstack-full"
+
+# backdoor API path used to retrieve or update config variables
+CONFIG_UPDATE_PATH = "/?_config_"
+
+# API path for localstack internal resources
+INTERNAL_RESOURCE_PATH = "/_localstack"
+
+# environment variable name to tag local test runs
+ENV_INTERNAL_TEST_RUN = "LOCALSTACK_INTERNAL_TEST_RUN"
+
+# environment variable name to tag collect metrics during a test run
+ENV_INTERNAL_TEST_COLLECT_METRIC = "LOCALSTACK_INTERNAL_TEST_COLLECT_METRIC"
+
+# environment variable that flags whether pro was activated. do not use it for security purposes!
+ENV_PRO_ACTIVATED = "PRO_ACTIVATED"
+
+# content types / encodings
+HEADER_CONTENT_TYPE = "Content-Type"
+TEXT_XML = "text/xml"
+APPLICATION_AMZ_JSON_1_0 = "application/x-amz-json-1.0"
+APPLICATION_AMZ_JSON_1_1 = "application/x-amz-json-1.1"
+APPLICATION_AMZ_CBOR_1_1 = "application/x-amz-cbor-1.1"
+APPLICATION_CBOR = "application/cbor"
+APPLICATION_JSON = "application/json"
+APPLICATION_XML = "application/xml"
+APPLICATION_OCTET_STREAM = "application/octet-stream"
+APPLICATION_X_WWW_FORM_URLENCODED = "application/x-www-form-urlencoded"
+HEADER_ACCEPT_ENCODING = "Accept-Encoding"
+
+# strings to indicate truthy/falsy values
+TRUE_STRINGS = ("1", "true", "True")
+FALSE_STRINGS = ("0", "false", "False")
+# strings with valid log levels for LS_LOG
+LOG_LEVELS = ("trace-internal", "trace", "debug", "info", "warn", "error", "warning")
+
+# the version of elasticsearch that is pre-seeded into the base image (sync with Dockerfile.base)
+ELASTICSEARCH_DEFAULT_VERSION = "Elasticsearch_7.10"
+# See https://docs.aws.amazon.com/ja_jp/elasticsearch-service/latest/developerguide/aes-supported-plugins.html
+ELASTICSEARCH_PLUGIN_LIST = [
+ "analysis-icu",
+ "ingest-attachment",
+ "analysis-kuromoji",
+ "mapper-murmur3",
+ "mapper-size",
+ "analysis-phonetic",
+ "analysis-smartcn",
+ "analysis-stempel",
+ "analysis-ukrainian",
+]
+# Default ES modules to exclude (saves approx. 66 MB in the final image)
+ELASTICSEARCH_DELETE_MODULES = ["ingest-geoip"]
+
+# the version of opensearch which is used by default
+OPENSEARCH_DEFAULT_VERSION = "OpenSearch_2.11"
+
+# See https://docs.aws.amazon.com/opensearch-service/latest/developerguide/supported-plugins.html
+OPENSEARCH_PLUGIN_LIST = [
+ "ingest-attachment",
+ "analysis-kuromoji",
+]
+
+# API endpoint for analytics events
+API_ENDPOINT = os.environ.get("API_ENDPOINT") or "https://api.localstack.cloud/v1"
+# new analytics API endpoint
+ANALYTICS_API = os.environ.get("ANALYTICS_API") or "https://analytics.localstack.cloud/v1"
+
+# environment variable to indicate this process should run the localstack infrastructure
+LOCALSTACK_INFRA_PROCESS = "LOCALSTACK_INFRA_PROCESS"
+
+# AWS region us-east-1
+AWS_REGION_US_EAST_1 = "us-east-1"
+
+# environment variable to override max pool connections
+try:
+ MAX_POOL_CONNECTIONS = int(os.environ["MAX_POOL_CONNECTIONS"])
+except Exception:
+ MAX_POOL_CONNECTIONS = 150
+
+# Fallback Account ID if not available in the client request
+DEFAULT_AWS_ACCOUNT_ID = "000000000000"
+
+# Credentials used for internal calls
+INTERNAL_AWS_ACCESS_KEY_ID = "__internal_call__"
+INTERNAL_AWS_SECRET_ACCESS_KEY = "__internal_call__"
+
+# trace log levels (excluding/including internal API calls), configurable via $LS_LOG
+LS_LOG_TRACE = "trace"
+LS_LOG_TRACE_INTERNAL = "trace-internal"
+TRACE_LOG_LEVELS = [LS_LOG_TRACE, LS_LOG_TRACE_INTERNAL]
+
+# list of official docker images
+OFFICIAL_IMAGES = [
+ "localstack/localstack",
+ "localstack/localstack-pro",
+]
+
+# port for debugpy
+DEFAULT_DEVELOP_PORT = 5678
+
+# Default bucket name of the s3 bucket used for local lambda development
+# This name should be accepted by all IaC tools, so should respect s3 bucket naming conventions
+DEFAULT_BUCKET_MARKER_LOCAL = "hot-reload"
+LEGACY_DEFAULT_BUCKET_MARKER_LOCAL = "__local__"
+
+# user that starts the opensearch process if the current user is root
+OS_USER_OPENSEARCH = "localstack"
+
+# output string that indicates that the stack is ready
+READY_MARKER_OUTPUT = "Ready."
+
+# Regex for `Credential` field in the Authorization header in AWS signature version v4
+# The format is as follows:
+# Credential=<access-key-id>/<date>/<region>/<service>/aws4_request
+# e.g.
+# Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request
+AUTH_CREDENTIAL_REGEX = r"Credential=(?P<access_key_id>[a-zA-Z0-9-_.]{1,})/(?P<date>\d{8})/(?P<region_name>[a-z0-9-]{1,})/(?P<service_name>[a-z0-9]{1,})/"
+
+# Custom resource tag to override the generated resource ID.
+TAG_KEY_CUSTOM_ID = "_custom_id_"
diff --git a/localstack-core/localstack/deprecations.py b/localstack-core/localstack/deprecations.py
new file mode 100644
index 0000000000000..1ece1f5ccfec3
--- /dev/null
+++ b/localstack-core/localstack/deprecations.py
@@ -0,0 +1,396 @@
+# A simple module to track deprecations over time / versions, and some simple functions guiding the affected users.
+import logging
+import os
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from localstack.utils.analytics import log
+
+LOG = logging.getLogger(__name__)
+
+
+@dataclass
+class EnvVarDeprecation:
+ """
+ Simple class defining a deprecation of an environment variable config.
+ It helps keep track of deprecations over time.
+ """
+
+ env_var: str
+ deprecation_version: str
+ deprecation_path: str = None
+
+ @property
+ def is_affected(self) -> bool:
+ """
+ Checks whether an environment is affected.
+ :return: true if the environment is affected / is using a deprecated config
+ """
+ return os.environ.get(self.env_var) is not None
+
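+# Illustrative example (hypothetical variable): EnvVarDeprecation("MY_OLD_VAR", "1.2.3", "Use MY_NEW_VAR instead.")
+# reports is_affected == True whenever MY_OLD_VAR is present in os.environ, regardless of its value.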
+
+#
+# List of deprecations
+#
+# Please make sure this is in-sync with https://docs.localstack.cloud/references/configuration/
+#
+DEPRECATIONS = [
+ # Since 0.11.3 - HTTP / HTTPS multiplexing
+ EnvVarDeprecation(
+ "USE_SSL",
+ "0.11.3",
+ "Each endpoint now supports multiplexing HTTP/HTTPS traffic over the same port. Please remove this environment variable.", # noqa
+ ),
+ # Since 0.12.8 - PORT_WEB_UI was removed
+ EnvVarDeprecation(
+ "PORT_WEB_UI",
+ "0.12.8",
+ "PORT_WEB_UI has been removed, and is not available anymore. Please remove this environment variable.",
+ ),
+ # Deprecated in 0.12.7, removed in 3.0.0
+ EnvVarDeprecation(
+ "USE_SINGLE_REGION",
+ "0.12.7",
+ "LocalStack now has full multi-region support. This option has no effect. Please remove it from your configuration.", # noqa
+ ),
+ # Deprecated in 0.12.7, removed in 3.0.0
+ EnvVarDeprecation(
+ "DEFAULT_REGION",
+ "0.12.7",
+ "LocalStack now has full multi-region support. This option has no effect. Please remove it from your configuration.", # noqa
+ ),
+ # Since 1.0.0 - New Persistence and file system
+ EnvVarDeprecation(
+ "DATA_DIR",
+ "1.0.0",
+ "Please use PERSISTENCE instead. The state will be stored in your LocalStack volume in the state/ directory.",
+ ),
+ EnvVarDeprecation(
+ "HOST_TMP_FOLDER",
+ "1.0.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "LEGACY_DIRECTORIES",
+ "1.0.0",
+ "This option has no effect anymore. Please migrate to the new filesystem layout (introduced with v1.0).",
+ ),
+ EnvVarDeprecation(
+ "TMPDIR", "1.0.0", "Please migrate to the new filesystem layout (introduced with v1.0)."
+ ),
+ EnvVarDeprecation(
+ "PERSISTENCE_SINGLE_FILE",
+ "1.0.0",
+ "The legacy persistence mechanism is not supported anymore, please migrate to the advanced persistence mechanism of LocalStack Pro.", # noqa
+ ),
+ # Since 1.0.0 - New ASF Gateway
+ EnvVarDeprecation(
+ "LEGACY_EDGE_PROXY",
+ "1.0.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ # Since 1.1.0 - Kinesalite removed with 1.3, only kinesis-mock is used as kinesis provider / backend
+ EnvVarDeprecation(
+ "KINESIS_PROVIDER",
+ "1.1.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ # Since 1.1.0 - Init dir has been deprecated in favor of pluggable init hooks
+ EnvVarDeprecation(
+ "LEGACY_INIT_DIR",
+ "1.1.0",
+ "This option has no effect anymore. "
+ "Please use the pluggable initialization hooks in /etc/localhost/init/.d instead.",
+ ),
+ EnvVarDeprecation(
+ "INIT_SCRIPTS_PATH",
+ "1.1.0",
+ "This option has no effect anymore. "
+ "Please use the pluggable initialization hooks in /etc/localhost/init/.d instead.",
+ ),
+ # Since 1.3.0 - Synchronous events break AWS parity
+ EnvVarDeprecation(
+ "SYNCHRONOUS_SNS_EVENTS",
+ "1.3.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "SYNCHRONOUS_SQS_EVENTS",
+ "1.3.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "SYNCHRONOUS_API_GATEWAY_EVENTS",
+ "1.3.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "SYNCHRONOUS_KINESIS_EVENTS",
+ "1.3.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "SYNCHRONOUS_DYNAMODB_EVENTS",
+ "1.3.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ # Since 1.3.0 - All non-pre-seeded infra is downloaded asynchronously
+ EnvVarDeprecation(
+ "SKIP_INFRA_DOWNLOADS",
+ "1.3.0",
+ "Infra downloads are triggered on-demand now. Please remove this environment variable.",
+ ),
+ # Since 1.3.0 - Mocking for unimplemented operations will be removed
+ EnvVarDeprecation(
+ "MOCK_UNIMPLEMENTED",
+ "1.3.0",
+ "This feature is not supported anymore. Please remove this environment variable.",
+ ),
+ # Since 1.4.0 - The Edge Forwarding is only used for legacy HTTPS proxying and will be removed
+ EnvVarDeprecation(
+ "EDGE_FORWARD_URL",
+ "1.4.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ # Deprecated in 1.4.0, removed in 3.0.0
+ EnvVarDeprecation(
+ "KMS_PROVIDER",
+ "1.4.0",
+ "This option has no effect. Please remove it from your configuration.",
+ ),
+ # Since 2.0.0 - HOSTNAME_EXTERNAL will be replaced with LOCALSTACK_HOST
+ EnvVarDeprecation(
+ "HOSTNAME_EXTERNAL",
+ "2.0.0",
+ "This configuration will be migrated to LOCALSTACK_HOST",
+ ),
+ # Since 2.0.0 - LOCALSTACK_HOSTNAME will be replaced with LOCALSTACK_HOST
+ EnvVarDeprecation(
+ "LOCALSTACK_HOSTNAME",
+ "2.0.0",
+ "This configuration will be migrated to LOCALSTACK_HOST",
+ ),
+ # Since 2.0.0 - redefined as GATEWAY_LISTEN
+ EnvVarDeprecation(
+ "EDGE_BIND_HOST",
+ "2.0.0",
+ "This configuration will be migrated to GATEWAY_LISTEN",
+ ),
+ # Since 2.0.0 - redefined as GATEWAY_LISTEN
+ EnvVarDeprecation(
+ "EDGE_PORT",
+ "2.0.0",
+ "This configuration will be migrated to GATEWAY_LISTEN",
+ ),
+ # Since 2.0.0 - redefined as GATEWAY_LISTEN
+ EnvVarDeprecation(
+ "EDGE_PORT_HTTP",
+ "2.0.0",
+ "This configuration will be migrated to GATEWAY_LISTEN",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_EXECUTOR",
+ "2.0.0",
+ "This configuration is obsolete with the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2\n"
+ "Please mount the Docker socket /var/run/docker.sock as a volume when starting LocalStack.",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_STAY_OPEN_MODE",
+ "2.0.0",
+ "Stay open mode is the default behavior in the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_REMOTE_DOCKER",
+ "2.0.0",
+ "The new lambda provider copies zip files by default and automatically configures hot reloading "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_CODE_EXTRACT_TIME",
+ "2.0.0",
+ "Function creation now happens asynchronously in the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_CONTAINER_REGISTRY",
+ "2.0.0",
+ "The new lambda provider uses LAMBDA_RUNTIME_IMAGE_MAPPING instead "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_FALLBACK_URL",
+ "2.0.0",
+ "This feature is not supported in the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_FORWARD_URL",
+ "2.0.0",
+ "This feature is not supported in the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_XRAY_INIT",
+ "2.0.0",
+ "The X-Ray daemon is always initialized in the new lambda provider "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "KINESIS_INITIALIZE_STREAMS",
+ "1.4.0",
+ "This option has no effect anymore. Please use the AWS client and init hooks instead.",
+ ),
+ EnvVarDeprecation(
+ "SQS_PORT_EXTERNAL",
+ "1.0.0",
+ "This option has no effect anymore. Please use LOCALSTACK_HOST instead.",
+ ),
+ EnvVarDeprecation(
+ "PROVIDER_OVERRIDE_LAMBDA",
+ "3.0.0",
+ "This option is ignored because the legacy Lambda provider (v1) has been removed since 3.0.0. "
+ "Please remove PROVIDER_OVERRIDE_LAMBDA and migrate to our new Lambda provider (v2): "
+ "https://docs.localstack.cloud/user-guide/aws/lambda/#migrating-to-lambda-v2",
+ ),
+ EnvVarDeprecation(
+ "ES_CUSTOM_BACKEND",
+ "0.14.0",
+ "This option has no effect anymore. Please use OPENSEARCH_CUSTOM_BACKEND instead.",
+ ),
+ EnvVarDeprecation(
+ "ES_MULTI_CLUSTER",
+ "0.14.0",
+ "This option has no effect anymore. Please use OPENSEARCH_MULTI_CLUSTER instead.",
+ ),
+ EnvVarDeprecation(
+ "ES_ENDPOINT_STRATEGY",
+ "0.14.0",
+ "This option has no effect anymore. Please use OPENSEARCH_ENDPOINT_STRATEGY instead.",
+ ),
+ EnvVarDeprecation(
+ "PERSIST_ALL",
+ "2.3.2",
+ "LocalStack treats backends and assets the same with respect to persistence. Please remove PERSIST_ALL.",
+ ),
+ EnvVarDeprecation(
+ "DNS_LOCAL_NAME_PATTERNS",
+ "3.0.0",
+ "This option was confusingly named. Please use DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM "
+ "instead.",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_EVENTS_INTERNAL_SQS",
+ "4.0.0",
+ "This option is ignored because the LocalStack SQS dependency for event invokes has been removed since 4.0.0"
+ " in favor of a lightweight Lambda-internal SQS implementation.",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_EVENT_SOURCE_MAPPING",
+ "4.0.0",
+ "This option has no effect anymore. Please remove this environment variable.",
+ ),
+ EnvVarDeprecation(
+ "LAMBDA_SQS_EVENT_SOURCE_MAPPING_INTERVAL_SEC",
+ "4.0.0",
+ "This option is not supported by the new Lambda Event Source Mapping v2 implementation."
+ " Please create a GitHub issue if you experience any performance challenges.",
+ ),
+ EnvVarDeprecation(
+ "PROVIDER_OVERRIDE_STEPFUNCTIONS",
+ "4.0.0",
+ "This option is ignored because the legacy StepFunctions provider (v1) has been removed since 4.0.0."
+ " Please remove PROVIDER_OVERRIDE_STEPFUNCTIONS.",
+ ),
+ EnvVarDeprecation(
+ "EVENT_RULE_ENGINE",
+ "4.0.3",
+ "This option is ignored because the Java-based event ruler has been removed since 4.1.0."
+ " Our latest Python-native implementation introduced in 4.0.3"
+ " is faster, achieves great AWS parity, and fixes compatibility issues with the StepFunctions JSONata feature."
+ " Please remove EVENT_RULE_ENGINE.",
+ ),
+]
+
+
+def collect_affected_deprecations(
+ deprecations: Optional[List[EnvVarDeprecation]] = None,
+) -> List[EnvVarDeprecation]:
+ """
+ Collects all deprecations which are used in the OS environ.
+ :param deprecations: List of deprecations to check. Uses DEPRECATIONS list by default.
+ :return: List of deprecations which are used in the current environment
+ """
+ if deprecations is None:
+ deprecations = DEPRECATIONS
+ return [deprecation for deprecation in deprecations if deprecation.is_affected]
+
+
+def log_env_warning(deprecations: List[EnvVarDeprecation]) -> None:
+ """
+ Logs warnings for the given deprecations.
+ :param deprecations: list of affected deprecations to show a warning for
+ """
+ """
+ Logs a warning if a given environment variable is set (no matter what the value is).
+ :param env_var: to check
+ :param deprecation_version: version with which the env variable has been deprecated
+ """
+ if deprecations:
+ env_vars = []
+
+ # Print warnings for the env vars and collect them (for the analytics event)
+ for deprecation in deprecations:
+ LOG.warning(
+ "%s is deprecated (since %s) and will be removed in upcoming releases of LocalStack! %s",
+ deprecation.env_var,
+ deprecation.deprecation_version,
+ deprecation.deprecation_path,
+ )
+ env_vars.append(deprecation.env_var)
+
+ # Log an event if deprecated env vars are used
+ log.event(event="deprecated_env_usage", payload={"deprecated_env_vars": env_vars})
+
+
+def log_deprecation_warnings(deprecations: Optional[List[EnvVarDeprecation]] = None) -> None:
+ affected_deprecations = collect_affected_deprecations(deprecations)
+ log_env_warning(affected_deprecations)
+
+ provider_override_events = os.environ.get("PROVIDER_OVERRIDE_EVENTS")
+ if provider_override_events and provider_override_events in ["v1", "legacy"]:
+ env_var_value = f"PROVIDER_OVERRIDE_EVENTS={provider_override_events}"
+ deprecation_version = "4.0.0"
+ deprecation_path = f"Remove {env_var_value} to use the new EventBridge implementation."
+ LOG.warning(
+ "%s is deprecated (since %s) and will be removed in upcoming releases of LocalStack! %s",
+ env_var_value,
+ deprecation_version,
+ deprecation_path,
+ )
+
+
+def deprecated_endpoint(
+ endpoint: Callable, previous_path: str, deprecation_version: str, new_path: str
+) -> Callable:
+ """
+ Wrapper function which logs a warning (and a deprecation path) whenever a deprecated URL is invoked by the router.
+
+ :param endpoint: to wrap (log a warning whenever it is invoked)
+ :param previous_path: route path it is triggered by
+ :param deprecation_version: version of LocalStack with which this endpoint is deprecated
+ :param new_path: new route path which should be used instead
+ :return: wrapped function which can be registered for a route
+ """
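+# A DNS_ADDRESS value contained in FALSE_STRINGS (e.g. DNS_ADDRESS=0) disables the embedded DNS server,
+# as checked by use_custom_dns() below.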
+
+ def deprecated_wrapper(*args, **kwargs):
+ LOG.warning(
+ "%s is deprecated (since %s) and will be removed in upcoming releases of LocalStack! Use %s instead.",
+ previous_path,
+ deprecation_version,
+ new_path,
+ )
+ return endpoint(*args, **kwargs)
+
+ return deprecated_wrapper
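+
+
+# Usage sketch (hypothetical handler and paths; the surrounding router API is assumed, not defined here):
+#   wrapped = deprecated_endpoint(
+#       new_handler, previous_path="/_localstack/old", deprecation_version="3.0.0", new_path="/_localstack/new"
+#   )
+#   # register `wrapped` for the "/_localstack/old" route so callers see a warning until they migrate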
diff --git a/localstack/services/apigateway/__init__.py b/localstack-core/localstack/dev/__init__.py
similarity index 100%
rename from localstack/services/apigateway/__init__.py
rename to localstack-core/localstack/dev/__init__.py
diff --git a/localstack/services/awslambda/__init__.py b/localstack-core/localstack/dev/debugger/__init__.py
similarity index 100%
rename from localstack/services/awslambda/__init__.py
rename to localstack-core/localstack/dev/debugger/__init__.py
diff --git a/localstack-core/localstack/dev/debugger/plugins.py b/localstack-core/localstack/dev/debugger/plugins.py
new file mode 100644
index 0000000000000..aa1d163f57b85
--- /dev/null
+++ b/localstack-core/localstack/dev/debugger/plugins.py
@@ -0,0 +1,25 @@
+import logging
+
+from localstack import config, constants
+from localstack.runtime import hooks
+
+LOG = logging.getLogger(__name__)
+
+
+def enable_debugger():
+ from localstack.packages.debugpy import debugpy_package
+
+ debugpy_package.install()
+ import debugpy # noqa: T100
+
+ LOG.info("Starting debug server at: %s:%s", constants.BIND_HOST, config.DEVELOP_PORT)
+ debugpy.listen((constants.BIND_HOST, config.DEVELOP_PORT)) # noqa: T100
+
+ if config.WAIT_FOR_DEBUGGER:
+ debugpy.wait_for_client() # noqa: T100
+
+
+@hooks.on_infra_start()
+def conditionally_enable_debugger():
+ if config.DEVELOP:
+ enable_debugger()
diff --git a/localstack/services/cloudformation/__init__.py b/localstack-core/localstack/dev/kubernetes/__init__.py
similarity index 100%
rename from localstack/services/cloudformation/__init__.py
rename to localstack-core/localstack/dev/kubernetes/__init__.py
diff --git a/localstack-core/localstack/dev/kubernetes/__main__.py b/localstack-core/localstack/dev/kubernetes/__main__.py
new file mode 100644
index 0000000000000..cf326a4dc3404
--- /dev/null
+++ b/localstack-core/localstack/dev/kubernetes/__main__.py
@@ -0,0 +1,249 @@
+import os
+
+import click
+import yaml
+
+from localstack import version as localstack_version
+
+
+def generate_k8s_cluster_config(pro: bool = False, mount_moto: bool = False, port: int = 4566):
+ volumes = []
+ root_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
+ localstack_code_path = os.path.join(root_path, "localstack-core", "localstack")
+ volumes.append(
+ {
+ "volume": f"{os.path.normpath(localstack_code_path)}:/code/localstack",
+ "nodeFilters": ["server:*", "agent:*"],
+ }
+ )
+
+ egg_path = os.path.join(
+ root_path, "localstack-core", "localstack_core.egg-info/entry_points.txt"
+ )
+ volumes.append(
+ {
+ "volume": f"{os.path.normpath(egg_path)}:/code/entry_points_community",
+ "nodeFilters": ["server:*", "agent:*"],
+ }
+ )
+ if pro:
+ pro_path = os.path.join(root_path, "..", "localstack-ext")
+ pro_code_path = os.path.join(pro_path, "localstack-pro-core", "localstack", "pro", "core")
+ volumes.append(
+ {
+ "volume": f"{os.path.normpath(pro_code_path)}:/code/localstack_ext",
+ "nodeFilters": ["server:*", "agent:*"],
+ }
+ )
+
+ egg_path = os.path.join(
+ pro_path, "localstack-pro-core", "localstack_ext.egg-info/entry_points.txt"
+ )
+ volumes.append(
+ {
+ "volume": f"{os.path.normpath(egg_path)}:/code/entry_points_ext",
+ "nodeFilters": ["server:*", "agent:*"],
+ }
+ )
+
+ if mount_moto:
+ moto_path = os.path.join(root_path, "..", "moto", "moto")
+ volumes.append(
+ {"volume": f"{moto_path}:/code/moto", "nodeFilters": ["server:*", "agent:*"]}
+ )
+
+ ports = [{"port": f"{port}:31566", "nodeFilters": ["server:0"]}]
+
+ config = {"apiVersion": "k3d.io/v1alpha5", "kind": "Simple", "volumes": volumes, "ports": ports}
+
+ return config
+
+
+def snake_to_kebab_case(string: str):
+ return string.lower().replace("_", "-")
+
+
+def generate_k8s_cluster_overrides(
+ pro: bool = False, cluster_config: dict = None, env: list[str] | None = None
+):
+ volumes = []
+ for volume in cluster_config["volumes"]:
+ name = snake_to_kebab_case(volume["volume"].split(":")[-1].split("/")[-1])
+ volume_type = "Directory" if name != "entry-points" else "File"
+ volumes.append(
+ {
+ "name": name,
+ "hostPath": {"path": volume["volume"].split(":")[-1], "type": volume_type},
+ }
+ )
+
+ volume_mounts = []
+ target_path = "/opt/code/localstack/"
+ venv_path = os.path.join(target_path, ".venv", "lib", "python3.11", "site-packages")
+ for volume in volumes:
+ if volume["name"] == "entry-points":
+ entry_points_path = os.path.join(
+ target_path, "localstack_core.egg-info", "entry_points.txt"
+ )
+ if pro:
+ project = "localstack_ext-"
+ version = localstack_version.__version__
+ dist_info = f"{project}{version}0.dist-info"
+ entry_points_path = os.path.join(venv_path, dist_info, "entry_points.txt")
+
+ volume_mounts.append(
+ {
+ "name": volume["name"],
+ "readOnly": True,
+ "mountPath": entry_points_path,
+ }
+ )
+ continue
+
+ volume_mounts.append(
+ {
+ "name": volume["name"],
+ "readOnly": True,
+ "mountPath": os.path.join(venv_path, volume["hostPath"]["path"].split("/")[-1]),
+ }
+ )
+
+ extra_env_vars = []
+ if env:
+ for env_variable in env:
+ lhs, _, rhs = env_variable.partition("=")
+ extra_env_vars.append(
+ {
+ "name": lhs,
+ "value": rhs,
+ }
+ )
+
+ if pro:
+ extra_env_vars.append(
+ {
+ "name": "LOCALSTACK_AUTH_TOKEN",
+ "value": "test",
+ }
+ )
+
+ image_repository = "localstack/localstack-pro" if pro else "localstack/localstack"
+
+ overrides = {
+ "debug": True,
+ "volumes": volumes,
+ "volumeMounts": volume_mounts,
+ "extraEnvVars": extra_env_vars,
+ "image": {"repository": image_repository},
+ }
+
+ return overrides
+
+
+def write_file(content: dict, output_path: str, file_name: str):
+ path = os.path.join(output_path, file_name)
+ with open(path, "w") as f:
+ f.write(yaml.dump(content))
+ print(f"Generated file at {path}")
+
+
+def print_file(content: dict, file_name: str):
+ print(f"Generated file:\t{file_name}")
+ print("=====================================")
+ print(yaml.dump(content))
+ print("=====================================")
+
+
+@click.command("run")
+@click.option(
+ "--pro", is_flag=True, default=None, help="Mount the localstack-pro code into the cluster."
+)
+@click.option(
+ "--mount-moto", is_flag=True, default=None, help="Mount the moto code into the cluster."
+)
+@click.option(
+ "--write",
+ is_flag=True,
+ default=None,
+ help="Write the configuration and overrides to files.",
+)
+@click.option(
+ "--output-dir",
+ "-o",
+ type=click.Path(exists=True, file_okay=False, resolve_path=True),
+ help="Output directory for generated files.",
+)
+@click.option(
+ "--overrides-file",
+ "-of",
+ default=None,
+ help="Name of the overrides file (default: overrides.yml).",
+)
+@click.option(
+ "--config-file",
+ "-cf",
+ default=None,
+ help="Name of the configuration file (default: configuration.yml).",
+)
+@click.option(
+ "--env", "-e", default=None, help="Environment variable to set in the pod", multiple=True
+)
+@click.option(
+ "--port",
+ "-p",
+ default=4566,
+ help="Port to expose from the kubernetes node",
+ type=click.IntRange(0, 65535),
+)
+@click.argument("command", nargs=-1, required=False)
+def run(
+ pro: bool = None,
+ mount_moto: bool = False,
+ write: bool = False,
+ output_dir=None,
+ overrides_file: str = None,
+ config_file: str = None,
+ command: str = None,
+ env: list[str] = None,
+ port: int = None,
+):
+ """
+ A tool for localstack developers to generate the kubernetes cluster configuration file and the overrides to mount the localstack code into the cluster.
+ """
+
+ config = generate_k8s_cluster_config(pro=pro, mount_moto=mount_moto, port=port)
+
+ overrides = generate_k8s_cluster_overrides(pro, config, env=env)
+
+ output_dir = output_dir or os.getcwd()
+ overrides_file = overrides_file or "overrides.yml"
+ config_file = config_file or "configuration.yml"
+
+ if write:
+ write_file(config, output_dir, config_file)
+ write_file(overrides, output_dir, overrides_file)
+ else:
+ print_file(config, config_file)
+ print_file(overrides, overrides_file)
+
+ overrides_file_path = os.path.join(output_dir, overrides_file)
+ config_file_path = os.path.join(output_dir, config_file)
+
+ print("\nTo create a k3d cluster with the generated configuration, follow these steps:")
+ print("1. Run the following command to create the cluster:")
+ print(f"\n k3d cluster create --config {config_file_path}\n")
+
+ print("2. Once the cluster is created, start LocalStack with the generated overrides:")
+ print("\n helm repo add localstack https://localstack.github.io/helm-charts # (if required)")
+ print(
+ f"\n helm upgrade --install localstack localstack/localstack -f {overrides_file_path}\n"
+ )
+
+
+def main():
+ run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/localstack/services/dynamodb/__init__.py b/localstack-core/localstack/dev/run/__init__.py
similarity index 100%
rename from localstack/services/dynamodb/__init__.py
rename to localstack-core/localstack/dev/run/__init__.py
diff --git a/localstack-core/localstack/dev/run/__main__.py b/localstack-core/localstack/dev/run/__main__.py
new file mode 100644
index 0000000000000..39ab236c9e3c2
--- /dev/null
+++ b/localstack-core/localstack/dev/run/__main__.py
@@ -0,0 +1,408 @@
+import dataclasses
+import os
+from typing import Iterable, Tuple
+
+import click
+from rich.rule import Rule
+
+from localstack import config
+from localstack.cli import console
+from localstack.runtime import hooks
+from localstack.utils.bootstrap import Container, ContainerConfigurators
+from localstack.utils.container_utils.container_client import (
+ ContainerConfiguration,
+ PortMappings,
+ VolumeMappings,
+)
+from localstack.utils.container_utils.docker_cmd_client import CmdDockerClient
+from localstack.utils.files import cache_dir
+from localstack.utils.run import run_interactive
+from localstack.utils.strings import short_uid
+
+from .configurators import (
+ ConfigEnvironmentConfigurator,
+ DependencyMountConfigurator,
+ EntryPointMountConfigurator,
+ ImageConfigurator,
+ PortConfigurator,
+ SourceVolumeMountConfigurator,
+)
+from .paths import HOST_PATH_MAPPINGS, HostPaths
+
+
+@click.command("run")
+@click.option(
+ "--image",
+ type=str,
+ required=False,
+ help="Overwrite the container image to be used (defaults to localstack/localstack or "
+ "localstack/localstack-pro).",
+)
+@click.option(
+ "--volume-dir",
+ type=click.Path(file_okay=False, dir_okay=True),
+ required=False,
+ help="The localstack volume on the host, default: ~/.cache/localstack/volume",
+)
+@click.option(
+ "--pro/--community",
+ is_flag=True,
+ default=None,
+ help="Whether to start localstack pro or community. If not set, it will guess from the current directory",
+)
+@click.option(
+ "--develop/--no-develop",
+ is_flag=True,
+ default=False,
+ help="Install debugpy and expose port 5678",
+)
+@click.option(
+ "--randomize",
+ is_flag=True,
+ default=False,
+ help="Randomize container name and ports to start multiple instances",
+)
+@click.option(
+ "--mount-source/--no-mount-source",
+ is_flag=True,
+ default=True,
+ help="Mount source files from localstack and localstack-ext. Use --local-packages for optional dependencies such as moto.",
+)
+@click.option(
+ "--mount-dependencies/--no-mount-dependencies",
+ is_flag=True,
+ default=False,
+ help="Whether to mount the dependencies of the current .venv directory into the container. Note this only works if the dependencies are compatible with the python and platform version from the venv and the container.",
+)
+@click.option(
+ "--mount-entrypoints/--no-mount-entrypoints",
+ is_flag=True,
+ default=False,
+ help="Mount entrypoints",
+)
+@click.option("--mount-docker-socket/--no-docker-socket", is_flag=True, default=True)
+@click.option(
+ "--env",
+ "-e",
+ help="Additional environment variables that are passed to the LocalStack container",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--volume",
+ "-v",
+ help="Additional volume mounts that are passed to the LocalStack container",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--publish",
+ "-p",
+ help="Additional ports that are published to the host",
+ multiple=True,
+ required=False,
+)
+@click.option(
+ "--entrypoint",
+ type=str,
+ required=False,
+ help="Additional entrypoint flag passed to docker",
+)
+@click.option(
+ "--network",
+ type=str,
+ required=False,
+ help="Docker network to start the container in",
+)
+@click.option(
+ "--local-packages",
+ "-l",
+ multiple=True,
+ required=False,
+ type=click.Choice(HOST_PATH_MAPPINGS.keys(), case_sensitive=False),
+ help="Mount specified packages into the container",
+)
+@click.argument("command", nargs=-1, required=False)
+def run(
+ image: str = None,
+ volume_dir: str = None,
+ pro: bool = None,
+ develop: bool = False,
+ randomize: bool = False,
+ mount_source: bool = True,
+ mount_dependencies: bool = False,
+ mount_entrypoints: bool = False,
+ mount_docker_socket: bool = True,
+ env: Tuple = (),
+ volume: Tuple = (),
+ publish: Tuple = (),
+ entrypoint: str = None,
+ network: str = None,
+ local_packages: list[str] | None = None,
+ command: str = None,
+):
+ """
+ A tool for localstack developers to start localstack containers. Run this in your localstack or
+ localstack-ext source tree to mount local source files or dependencies into the container.
+ Here are some examples::
+
+ \b
+ python -m localstack.dev.run
+ python -m localstack.dev.run -e DEBUG=1 -e LOCALSTACK_AUTH_TOKEN=test
+ python -m localstack.dev.run -- bash -c 'echo "hello"'
+
+ Explanations and more examples:
+
+ Start a normal localstack container. If you run this from the localstack-ext repo,
+ it will start localstack-pro::
+
+ python -m localstack.dev.run
+
+ If you start localstack-pro, you might also want to add the API key as an environment variable::
+
+ python -m localstack.dev.run -e DEBUG=1 -e LOCALSTACK_AUTH_TOKEN=test
+
+ If your local changes are making modifications to plux plugins (e.g., adding new providers or hooks),
+ then you also want to mount the newly generated entry_point.txt files into the container::
+
+ python -m localstack.dev.run --mount-entrypoints
+
+ Start a new container with randomized gateway and service ports, and randomized container name::
+
+ python -m localstack.dev.run --randomize
+
+ You can also run custom commands::
+
+ python -m localstack.dev.run bash -c 'echo "hello"'
+
+ Or use custom entrypoints::
+
+ python -m localstack.dev.run --entrypoint /bin/bash -- echo "hello"
+
+ You can import and expose debugpy::
+
+ python -m localstack.dev.run --develop
+
+ You can also mount local dependencies (e.g., pytest and other test dependencies) and then use them
+ in the container::
+
+ \b
+ python -m localstack.dev.run --mount-dependencies \\
+ -v $PWD/tests:/opt/code/localstack/tests \\
+ -- .venv/bin/python -m pytest tests/unit/http_/
+
+ The script generally assumes that you are executing in either localstack or localstack-ext source
+ repositories that are organized like this::
+
+ \b
+ somedir <- your workspace directory
+ ├── localstack <- execute script in here
+ │   ├── ...
+ │   ├── localstack-core
+ │   │   ├── localstack <- will be mounted into the container
+ │   │   └── localstack_core.egg-info
+ │   ├── pyproject.toml
+ │   ├── tests
+ │   └── ...
+ ├── localstack-ext <- or execute script in here
+ │   ├── ...
+ │   ├── localstack-pro-core
+ │   │   ├── localstack
+ │   │   │   └── pro
+ │   │   │       └── core <- will be mounted into the container
+ │   │   ├── localstack_ext.egg-info
+ │   │   ├── pyproject.toml
+ │   │   └── tests
+ │   └── ...
+ ├── moto
+ │   ├── AUTHORS.md
+ │   ├── ...
+ │   ├── moto <- will be mounted into the container
+ │   ├── moto_ext.egg-info
+ │   ├── pyproject.toml
+ │   ├── tests
+ │   └── ...
+
+ You can choose which local source repositories are mounted into the container. For example, if `moto` and
+ `rolo` are both present, mount only `rolo`::
+
+ \b
+ python -m localstack.dev.run --local-packages rolo
+
+ If both `rolo` and `moto` are available and both should be mounted, use the flag twice::
+
+ \b
+ python -m localstack.dev.run --local-packages rolo --local-packages moto
+ """
+ with console.status("Configuring") as status:
+ env_vars = parse_env_vars(env)
+ configure_licensing_credentials_environment(env_vars)
+
+ # run all prepare_host hooks
+ hooks.prepare_host.run()
+
+ # set the VOLUME_DIR config variable like in the CLI
+ if not os.environ.get("LOCALSTACK_VOLUME_DIR", "").strip():
+ config.VOLUME_DIR = str(cache_dir() / "volume")
+
+ # setup important paths on the host
+ host_paths = HostPaths(
+ # we assume that python -m localstack.dev.run is always executed in the repo source
+ workspace_dir=os.path.abspath(os.path.join(os.getcwd(), "..")),
+ volume_dir=volume_dir or config.VOLUME_DIR,
+ )
+
+ # auto-set pro flag
+ if pro is None:
+ if os.getcwd().endswith("localstack-ext"):
+ pro = True
+ else:
+ pro = False
+
+ # setup base configuration
+ container_config = ContainerConfiguration(
+ image_name=image,
+ name=config.MAIN_CONTAINER_NAME if not randomize else f"localstack-{short_uid()}",
+ remove=True,
+ interactive=True,
+ tty=True,
+ env_vars=dict(),
+ volumes=VolumeMappings(),
+ ports=PortMappings(),
+ network=network,
+ )
+
+ # replicate pro startup
+ if pro:
+ try:
+ from localstack.pro.core.plugins import modify_gateway_listen_config
+
+ modify_gateway_listen_config(config)
+ except ImportError:
+ pass
+
+ # setup configurators
+ configurators = [
+ ImageConfigurator(pro, image),
+ PortConfigurator(randomize),
+ ConfigEnvironmentConfigurator(pro),
+ ContainerConfigurators.mount_localstack_volume(host_paths.volume_dir),
+ ContainerConfigurators.config_env_vars,
+ ]
+
+ # create stub container with configuration to apply
+ c = Container(container_config=container_config)
+
+ # apply existing hooks first that can later be overwritten
+ hooks.configure_localstack_container.run(c)
+
+ if command:
+ configurators.append(ContainerConfigurators.custom_command(list(command)))
+ if entrypoint:
+ container_config.entrypoint = entrypoint
+ if mount_docker_socket:
+ configurators.append(ContainerConfigurators.mount_docker_socket)
+ if mount_source:
+ configurators.append(
+ SourceVolumeMountConfigurator(
+ host_paths=host_paths,
+ pro=pro,
+ chosen_packages=local_packages,
+ )
+ )
+ if mount_entrypoints:
+ configurators.append(EntryPointMountConfigurator(host_paths=host_paths, pro=pro))
+ if mount_dependencies:
+ configurators.append(DependencyMountConfigurator(host_paths=host_paths))
+ if develop:
+ configurators.append(ContainerConfigurators.develop)
+
+ # make sure anything coming from CLI arguments has priority
+ configurators.extend(
+ [
+ ContainerConfigurators.volume_cli_params(volume),
+ ContainerConfigurators.port_cli_params(publish),
+ ContainerConfigurators.env_cli_params(env),
+ ]
+ )
+
+ # run configurators
+ for configurator in configurators:
+ configurator(container_config)
+ # print the config
+ print_config(container_config)
+
+ # run the container
+ docker = CmdDockerClient()
+ status.update("Creating container")
+ container_id = docker.create_container_from_config(container_config)
+
+ rule = Rule(f"Interactive session with {container_id[:12]} π»")
+ console.print(rule)
+ try:
+ cmd = [*docker._docker_cmd(), "start", "--interactive", "--attach", container_id]
+ run_interactive(cmd)
+ finally:
+ if container_config.remove:
+ try:
+ if docker.is_container_running(container_id):
+ docker.stop_container(container_id)
+ docker.remove_container(container_id)
+ except Exception:
+ pass
+
+
+def print_config(cfg: ContainerConfiguration):
+ d = dataclasses.asdict(cfg)
+
+ d["volumes"] = [v.to_str() for v in d["volumes"].mappings]
+ d["ports"] = [p for p in d["ports"].to_list() if p != "-p"]
+
+ for k in list(d.keys()):
+ if d[k] is None:
+ d.pop(k)
+
+ console.print(d)
+
+
+def parse_env_vars(params: Iterable[str] = None) -> dict[str, str]:
+ env = {}
+
+ if not params:
+ return env
+
+ for e in params:
+ if "=" in e:
+ k, v = e.split("=", maxsplit=1)
+ env[k] = v
+ else:
+ # there's currently no way in our abstraction to only pass the variable name (as
+ # you can do in docker) so we resolve the value here.
+ env[e] = os.getenv(e)
+
+ return env
+
+
+def configure_licensing_credentials_environment(env_vars: dict[str, str]):
+ """
+ If an api key or auth token is set in the parsed CLI parameters, then we also set them into the OS environment
+ unless they are already set. This is just convenience so you don't have to set them twice.
+
+ :param env_vars: the environment variables parsed from the CLI parameters
+ """
+ if os.environ.get("LOCALSTACK_API_KEY"):
+ return
+ if os.environ.get("LOCALSTACK_AUTH_TOKEN"):
+ return
+ if api_key := env_vars.get("LOCALSTACK_API_KEY"):
+ os.environ["LOCALSTACK_API_KEY"] = api_key
+ if api_key := env_vars.get("LOCALSTACK_AUTH_TOKEN"):
+ os.environ["LOCALSTACK_AUTH_TOKEN"] = api_key
+
+
+def main():
+ run()
+
+
+if __name__ == "__main__":
+ main()
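
As a small usage note, ``parse_env_vars`` defined above accepts both ``KEY=VALUE`` pairs and bare variable names, resolving the latter from the host environment; a minimal sketch:

    import os

    from localstack.dev.run.__main__ import parse_env_vars

    os.environ["DEBUG"] = "1"
    # explicit assignment plus a bare name that is resolved from the host environment
    env = parse_env_vars(["LOCALSTACK_AUTH_TOKEN=test", "DEBUG"])
    assert env == {"LOCALSTACK_AUTH_TOKEN": "test", "DEBUG": "1"}
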
diff --git a/localstack-core/localstack/dev/run/configurators.py b/localstack-core/localstack/dev/run/configurators.py
new file mode 100644
index 0000000000000..2c3b253965e87
--- /dev/null
+++ b/localstack-core/localstack/dev/run/configurators.py
@@ -0,0 +1,375 @@
+"""
+Several ContainerConfigurator implementations to set up a development version of a localstack container.
+"""
+
+import gzip
+import os
+from pathlib import Path, PurePosixPath
+from tempfile import gettempdir
+
+from localstack import config, constants
+from localstack.utils.bootstrap import ContainerConfigurators
+from localstack.utils.container_utils.container_client import (
+ ContainerClient,
+ ContainerConfiguration,
+ VolumeBind,
+ VolumeMappings,
+)
+from localstack.utils.docker_utils import DOCKER_CLIENT
+from localstack.utils.files import get_user_cache_dir
+from localstack.utils.run import run
+from localstack.utils.strings import md5
+
+from .paths import (
+ HOST_PATH_MAPPINGS,
+ CommunityContainerPaths,
+ ContainerPaths,
+ HostPaths,
+ ProContainerPaths,
+)
+
+
+class ConfigEnvironmentConfigurator:
+ """Configures the environment variables from the localstack and localstack-pro config."""
+
+ def __init__(self, pro: bool):
+ self.pro = pro
+
+ def __call__(self, cfg: ContainerConfiguration):
+ if cfg.env_vars is None:
+ cfg.env_vars = {}
+
+ if self.pro:
+ # importing localstack.pro.core.config extends the list of config vars
+ from localstack.pro.core import config as config_pro # noqa
+
+ ContainerConfigurators.config_env_vars(cfg)
+
+
+class PortConfigurator:
+ """
+ Configures the port mappings. Can be randomized to run multiple localstack instances.
+ """
+
+ def __init__(self, randomize: bool = True):
+ self.randomize = randomize
+
+ def __call__(self, cfg: ContainerConfiguration):
+ cfg.ports.bind_host = config.GATEWAY_LISTEN[0].host
+
+ if self.randomize:
+ ContainerConfigurators.random_gateway_port(cfg)
+ ContainerConfigurators.random_service_port_range()(cfg)
+ else:
+ ContainerConfigurators.gateway_listen(config.GATEWAY_LISTEN)(cfg)
+ ContainerConfigurators.service_port_range(cfg)
+
+
+class ImageConfigurator:
+ """
+ Sets the container image to use for the container (by default either localstack/localstack or
+ localstack/localstack-pro)
+ """
+
+ def __init__(self, pro: bool, image_name: str | None):
+ self.pro = pro
+ self.image_name = image_name
+
+ def __call__(self, cfg: ContainerConfiguration):
+ if self.image_name:
+ cfg.image_name = self.image_name
+ else:
+ if self.pro:
+ cfg.image_name = constants.DOCKER_IMAGE_NAME_PRO
+ else:
+ cfg.image_name = constants.DOCKER_IMAGE_NAME
+
+
+class CustomEntryPointConfigurator:
+ """
+ Creates a ``docker-entrypoint-<hash>.sh`` script from the given source and mounts it into the container.
+ It also configures the container to then use that entrypoint.
+ """
+
+ def __init__(self, script: str, tmp_dir: str = None):
+ self.script = script.lstrip(os.linesep)
+ self.container_paths = ProContainerPaths()
+ self.tmp_dir = tmp_dir
+
+ def __call__(self, cfg: ContainerConfiguration):
+ h = md5(self.script)
+ tempdir = gettempdir() if not self.tmp_dir else self.tmp_dir
+ file_name = f"docker-entrypoint-{h}.sh"
+
+ file = Path(tempdir, file_name)
+ if not file.exists():
+ # newline separator should be '\n' independent of the os, since the entrypoint is executed in the container
+ # encoding needs to be "utf-8" since scripts could include emojis
+ file.write_text(self.script, newline="\n", encoding="utf-8")
+ file.chmod(0o777)
+ cfg.volumes.add(VolumeBind(str(file), f"/tmp/{file.name}"))
+ cfg.entrypoint = f"/tmp/{file.name}"
+
+
+class SourceVolumeMountConfigurator:
+ """
+ Mounts source code of localstack, localstack_ext, and moto into the container. It does this by assuming
+ that there is a "workspace" directory into which the source repositories are checked out.
+ Depending on whether we want to start the pro container, the source paths for localstack are different.
+ """
+
+ def __init__(
+ self,
+ *,
+ host_paths: HostPaths = None,
+ pro: bool = False,
+ chosen_packages: list[str] | None = None,
+ ):
+ self.host_paths = host_paths or HostPaths()
+ self.container_paths = ProContainerPaths() if pro else CommunityContainerPaths()
+ self.pro = pro
+ self.chosen_packages = chosen_packages or []
+
+ def __call__(self, cfg: ContainerConfiguration):
+ # localstack source code if available
+ source = self.host_paths.aws_community_package_dir
+ if source.exists():
+ cfg.volumes.add(
+ # read_only=False is a temporary workaround to make the mounting of the pro source work
+ # this can be reverted once we don't need the nested mounting anymore
+ VolumeBind(str(source), self.container_paths.localstack_source_dir, read_only=False)
+ )
+
+ # ext source code if available
+ if self.pro:
+ source = self.host_paths.aws_pro_package_dir
+ if source.exists():
+ cfg.volumes.add(
+ VolumeBind(
+ str(source), self.container_paths.localstack_pro_source_dir, read_only=True
+ )
+ )
+
+ # mount local code checkouts if possible
+ for package_name in self.chosen_packages:
+ # Unconditional lookup because the CLI rejects incorrect items
+ extractor = HOST_PATH_MAPPINGS[package_name]
+ self.try_mount_to_site_packages(cfg, extractor(self.host_paths))
+
+ # docker entrypoint
+ if self.pro:
+ source = self.host_paths.localstack_pro_project_dir / "bin" / "docker-entrypoint.sh"
+ else:
+ source = self.host_paths.localstack_project_dir / "bin" / "docker-entrypoint.sh"
+ if source.exists():
+ cfg.volumes.add(
+ VolumeBind(str(source), self.container_paths.docker_entrypoint, read_only=True)
+ )
+
+ def try_mount_to_site_packages(self, cfg: ContainerConfiguration, sources_path: Path):
+ """
+ Attempts to mount something like `~/workspace/plux/plugin` on the host into
+ ``.venv/.../site-packages/plugin``.
+
+ :param cfg:
+ :param sources_path:
+ :return:
+ """
+ if sources_path.exists():
+ cfg.volumes.add(
+ VolumeBind(
+ str(sources_path),
+ self.container_paths.dependency_source(sources_path.name),
+ read_only=True,
+ )
+ )
+
+
+class EntryPointMountConfigurator:
+ """
+ Mounts ``entry_points.txt`` files of localstack and dependencies into the venv in the container.
+
+ For example, when starting the pro container, the entrypoints of localstack-ext on the host would be in
+ ``~/workspace/localstack-ext/localstack-pro-core/localstack_ext.egg-info/entry_points.txt``
+ which needs to be mounted into the distribution info of the installed dependency within the container:
+ ``/opt/code/localstack/.venv/.../site-packages/localstack_ext-2.1.0.dev0.dist-info/entry_points.txt``.
+ """
+
+ entry_point_glob = (
+ "/opt/code/localstack/.venv/lib/python3.*/site-packages/*.dist-info/entry_points.txt"
+ )
+ localstack_community_entry_points = (
+ "/opt/code/localstack/localstack_core.egg-info/entry_points.txt"
+ )
+
+ def __init__(
+ self,
+ *,
+ host_paths: HostPaths = None,
+ container_paths: ContainerPaths = None,
+ pro: bool = False,
+ ):
+ self.host_paths = host_paths or HostPaths()
+ self.pro = pro
+ self.container_paths = container_paths or None
+
+ def __call__(self, cfg: ContainerConfiguration):
+ # special case for community code
+ if not self.pro:
+ host_path = self.host_paths.aws_community_package_dir
+ if host_path.exists():
+ cfg.volumes.append(
+ VolumeBind(
+ str(host_path), self.localstack_community_entry_points, read_only=True
+ )
+ )
+
+ # locate all relevant entry_point.txt files within the container
+ pattern = self.entry_point_glob
+ files = _list_files_in_container_image(DOCKER_CLIENT, cfg.image_name)
+ paths = [PurePosixPath(f) for f in files]
+ paths = [p for p in paths if p.match(pattern)]
+
+ # then, check whether they exist in some form on the host within the workspace directory
+ for container_path in paths:
+ dep_path = container_path.parent.name.removesuffix(".dist-info")
+ dep, ver = dep_path.split("-")
+
+ if dep == "localstack_core":
+ host_path = (
+ self.host_paths.localstack_project_dir
+ / "localstack-core"
+ / "localstack_core.egg-info"
+ / "entry_points.txt"
+ )
+ if host_path.is_file():
+ cfg.volumes.add(
+ VolumeBind(
+ str(host_path),
+ str(container_path),
+ read_only=True,
+ )
+ )
+ continue
+ elif dep == "localstack_ext":
+ host_path = (
+ self.host_paths.localstack_pro_project_dir
+ / "localstack-pro-core"
+ / "localstack_ext.egg-info"
+ / "entry_points.txt"
+ )
+ if host_path.is_file():
+ cfg.volumes.add(
+ VolumeBind(
+ str(host_path),
+ str(container_path),
+ read_only=True,
+ )
+ )
+ continue
+ for host_path in self.host_paths.workspace_dir.glob(
+ f"*/{dep}.egg-info/entry_points.txt"
+ ):
+ cfg.volumes.add(VolumeBind(str(host_path), str(container_path), read_only=True))
+ break
+
+
+class DependencyMountConfigurator:
+ """
+ Mounts source folders from your host's .venv directory into the container's .venv.
+ """
+
+ dependency_glob = "/opt/code/localstack/.venv/lib/python3.*/site-packages/*"
+
+ # skip mounting dependencies with incompatible binaries (e.g., on macOS)
+ skipped_dependencies = ["cryptography", "psutil", "rpds"]
+
+ def __init__(
+ self,
+ *,
+ host_paths: HostPaths = None,
+ container_paths: ContainerPaths = None,
+ pro: bool = False,
+ ):
+ self.host_paths = host_paths or HostPaths()
+ self.pro = pro
+ self.container_paths = container_paths or (
+ ProContainerPaths() if pro else CommunityContainerPaths()
+ )
+
+ def __call__(self, cfg: ContainerConfiguration):
+ # locate all relevant dependency directories
+ pattern = self.dependency_glob
+ files = _list_files_in_container_image(DOCKER_CLIENT, cfg.image_name)
+ paths = [PurePosixPath(f) for f in files]
+ # builds an index of "jinja2: /opt/code/.../site-packages/jinja2"
+ container_path_index = {p.name: p for p in paths if p.match(pattern)}
+
+ # find dependencies from the host
+ for dep_path in self.host_paths.venv_dir.glob("lib/python3.*/site-packages/*"):
+ # filter out everything that heuristically cannot be a source path
+ if not self._can_be_source_path(dep_path):
+ continue
+ if dep_path.name.endswith(".dist-info"):
+ continue
+ if dep_path.name == "__pycache__":
+ continue
+
+ if dep_path.name in self.skipped_dependencies:
+ continue
+
+ if dep_path.name in container_path_index:
+ # find the target path in the index if it exists
+ target_path = str(container_path_index[dep_path.name])
+ else:
+ # if the given dependency is not in the container, then we mount it anyway
+ # FIXME: we should also mount the dist-info directory. perhaps this method should be
+ # re-written completely
+ target_path = self.container_paths.dependency_source(dep_path.name)
+
+ if self._has_mount(cfg.volumes, target_path):
+ continue
+
+ cfg.volumes.append(VolumeBind(str(dep_path), target_path))
+
+ def _can_be_source_path(self, path: Path) -> bool:
+ return path.is_dir() or (path.name.endswith(".py") and not path.name.startswith("__"))
+
+ def _has_mount(self, volumes: VolumeMappings, target_path: str) -> bool:
+ return bool(volumes.find_target_mapping(target_path))
+
+
+def _list_files_in_container_image(container_client: ContainerClient, image_name: str) -> list[str]:
+ """
+ Uses ``docker export | tar -t`` to list all files in a given docker image. It caches the result, keyed by
+ the image ID, in a gzipped file under ``~/.cache/localstack-dev-cli`` to (significantly) speed up
+ subsequent calls.
+
+ :param container_client: the container client to use
+ :param image_name: the container image to analyze
+ :return: a list of file paths
+ """
+ if not image_name:
+ raise ValueError("missing image name")
+
+ image_id = container_client.inspect_image(image_name)["Id"]
+
+ cache_dir = get_user_cache_dir() / "localstack-dev-cli"
+ cache_dir.mkdir(exist_ok=True, parents=True)
+ cache_file = cache_dir / f"{image_id}.files.txt.gz"
+
+ if not cache_file.exists():
+ container_id = container_client.create_container(image_name=image_name)
+ try:
+ # docker export yields paths without prefixed slashes, so we add them here
+ # since the file is pretty big (~4MB for community, ~7MB for pro) we gzip it
+ cmd = "docker export %s | tar -t | awk '{ print \"/\" $0 }' | gzip > %s" % (
+ container_id,
+ cache_file,
+ )
+ run(cmd, shell=True)
+ finally:
+ container_client.remove_container(container_id)
+
+ with gzip.open(cache_file, mode="rt") as fd:
+ return fd.read().splitlines(keepends=False)
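
Since each configurator in this module is simply a callable that mutates a ``ContainerConfiguration``, ad-hoc configurators compose naturally with the built-in ones. A minimal sketch, in which the ``add_debug_env`` helper is illustrative and not part of the module:

    from localstack.dev.run.configurators import ImageConfigurator
    from localstack.utils.container_utils.container_client import (
        ContainerConfiguration,
        PortMappings,
        VolumeMappings,
    )


    def add_debug_env(cfg: ContainerConfiguration) -> None:
        # illustrative ad-hoc configurator: enable debug logging in the container
        cfg.env_vars["DEBUG"] = "1"


    # minimal configuration; remaining fields keep their defaults
    cfg = ContainerConfiguration(
        image_name=None,
        env_vars={},
        volumes=VolumeMappings(),
        ports=PortMappings(),
    )
    for configurator in (ImageConfigurator(pro=False, image_name=None), add_debug_env):
        configurator(cfg)

    print(cfg.image_name)  # expected: the community default image (localstack/localstack)
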
diff --git a/localstack-core/localstack/dev/run/paths.py b/localstack-core/localstack/dev/run/paths.py
new file mode 100644
index 0000000000000..b1fe9a95f24fd
--- /dev/null
+++ b/localstack-core/localstack/dev/run/paths.py
@@ -0,0 +1,94 @@
+"""Utilities to resolve important paths on the host and in the container."""
+
+import os
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+
+class HostPaths:
+ workspace_dir: Path
+ """We assume all repositories live in a workspace directory, e.g., ``~/workspace/ls/localstack``,
+ ``~/workspace/ls/localstack-ext``, ..."""
+
+ localstack_project_dir: Path
+ localstack_pro_project_dir: Path
+ moto_project_dir: Path
+ postgresql_proxy: Path
+ rolo_dir: Path
+ volume_dir: Path
+ venv_dir: Path
+
+ def __init__(
+ self,
+ workspace_dir: Union[os.PathLike, str] = None,
+ volume_dir: Union[os.PathLike, str] = None,
+ venv_dir: Union[os.PathLike, str] = None,
+ ):
+ self.workspace_dir = Path(workspace_dir or os.path.abspath(os.path.join(os.getcwd(), "..")))
+ self.localstack_project_dir = self.workspace_dir / "localstack"
+ self.localstack_pro_project_dir = self.workspace_dir / "localstack-ext"
+ self.moto_project_dir = self.workspace_dir / "moto"
+ self.postgresql_proxy = self.workspace_dir / "postgresql-proxy"
+ self.rolo_dir = self.workspace_dir / "rolo"
+ self.volume_dir = Path(volume_dir or "/tmp/localstack")
+ self.venv_dir = Path(
+ venv_dir
+ or os.getenv("VIRTUAL_ENV")
+ or os.getenv("VENV_DIR")
+ or os.path.join(os.getcwd(), ".venv")
+ )
+
+ @property
+ def aws_community_package_dir(self) -> Path:
+ return self.localstack_project_dir / "localstack-core" / "localstack"
+
+ @property
+ def aws_pro_package_dir(self) -> Path:
+ return (
+ self.localstack_pro_project_dir / "localstack-pro-core" / "localstack" / "pro" / "core"
+ )
+
+
+# Type representing how to extract a specific path from a common root path, typically a lambda function
+PathMappingExtractor = Callable[[HostPaths], Path]
+
+# Declaration of which local packages can be mounted into the container, and their locations on the host
+HOST_PATH_MAPPINGS: dict[
+ str,
+ PathMappingExtractor,
+] = {
+ "moto": lambda paths: paths.moto_project_dir / "moto",
+ "postgresql_proxy": lambda paths: paths.postgresql_proxy / "postgresql_proxy",
+ "rolo": lambda paths: paths.rolo_dir / "rolo",
+ "plux": lambda paths: paths.workspace_dir / "plux" / "plugin",
+}
+
+
+class ContainerPaths:
+ """Important paths in the container"""
+
+ project_dir: str = "/opt/code/localstack"
+ site_packages_target_dir: str = "/opt/code/localstack/.venv/lib/python3.11/site-packages"
+ docker_entrypoint: str = "/usr/local/bin/docker-entrypoint.sh"
+ localstack_supervisor: str = "/usr/local/bin/localstack-supervisor"
+ localstack_source_dir: str
+ localstack_pro_source_dir: Optional[str]
+
+ def dependency_source(self, name: str) -> str:
+ """Returns path of the given source dependency in the site-packages directory."""
+ return self.site_packages_target_dir + f"/{name}"
+
+
+class CommunityContainerPaths(ContainerPaths):
+ """In the community image, code is copied into /opt/code/localstack/localstack-core/localstack"""
+
+ def __init__(self):
+ self.localstack_source_dir = f"{self.project_dir}/localstack-core/localstack"
+
+
+class ProContainerPaths(ContainerPaths):
+ """In the pro image, localstack and ext are installed into the venv as dependency"""
+
+ def __init__(self):
+ self.localstack_source_dir = self.dependency_source("localstack")
+ self.localstack_pro_source_dir = self.dependency_source("localstack") + "/pro/core"
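
To illustrate how these helpers fit together, the sketch below resolves where the mountable ``moto`` sources are expected on the host and where they would end up inside a community container (the workspace path is only an example):

    from localstack.dev.run.paths import HOST_PATH_MAPPINGS, CommunityContainerPaths, HostPaths

    # example workspace layout; adjust to your own checkout
    host_paths = HostPaths(workspace_dir="/home/user/workspace")

    # host location of the mountable `moto` package
    moto_source = HOST_PATH_MAPPINGS["moto"](host_paths)
    print(moto_source)  # /home/user/workspace/moto/moto

    # corresponding target inside the community container's site-packages
    container_paths = CommunityContainerPaths()
    print(container_paths.dependency_source(moto_source.name))
    # /opt/code/localstack/.venv/lib/python3.11/site-packages/moto
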
diff --git a/localstack/services/dynamodbstreams/__init__.py b/localstack-core/localstack/dns/__init__.py
similarity index 100%
rename from localstack/services/dynamodbstreams/__init__.py
rename to localstack-core/localstack/dns/__init__.py
diff --git a/localstack-core/localstack/dns/models.py b/localstack-core/localstack/dns/models.py
new file mode 100644
index 0000000000000..6df70bf6e0d86
--- /dev/null
+++ b/localstack-core/localstack/dns/models.py
@@ -0,0 +1,175 @@
+import dataclasses
+from enum import Enum, auto
+from typing import Callable, Protocol
+
+
+class RecordType(Enum):
+ A = auto()
+ AAAA = auto()
+ CNAME = auto()
+ TXT = auto()
+ MX = auto()
+ SOA = auto()
+ NS = auto()
+ SRV = auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class NameRecord:
+ """
+ Dataclass of a stored record
+ """
+
+ record_type: RecordType
+ record_id: str | None = None
+
+
+@dataclasses.dataclass(frozen=True)
+class _TargetRecordBase:
+ """
+ Dataclass of a stored record
+ """
+
+ target: str
+
+
+@dataclasses.dataclass(frozen=True)
+class TargetRecord(NameRecord, _TargetRecordBase):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class _SOARecordBase:
+ m_name: str
+ r_name: str
+
+
+@dataclasses.dataclass(frozen=True)
+class SOARecord(NameRecord, _SOARecordBase):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class AliasTarget:
+ target: str
+ alias_id: str | None = None
+ health_check: Callable[[], bool] | None = None
+
+
+@dataclasses.dataclass(frozen=True)
+class _DynamicRecordBase:
+ """
+ Dataclass of a record that is dynamically determined at query time to return the IP address
+ of the LocalStack container
+ """
+
+ record_type: RecordType
+
+
+@dataclasses.dataclass(frozen=True)
+class DynamicRecord(NameRecord, _DynamicRecordBase):
+ pass
+
+
+# TODO decide if we need the whole concept of multiple zones in our DNS implementation
+class DnsServerProtocol(Protocol):
+ def add_host(self, name: str, record: NameRecord) -> None:
+ """
+ Add a host resolution to the DNS server.
+ This will resolve the given host to the record provided, if it matches.
+
+ :param name: Name pattern to add resolution for. Can be an arbitrary regex.
+ :param record: Record, consisting of a record type, an optional record id, and the attached data.
+ Has to be a subclass of a NameRecord, not a NameRecord itself to contain some data.
+ """
+ pass
+
+ def delete_host(self, name: str, record: NameRecord) -> None:
+ """
+ Deletes a host resolution from the DNS server.
+ Only the name, the record type, and optionally the given record id will be used to find entries to delete.
+ All matching entries will be deleted.
+
+ :param name: Name pattern, identical to the one registered with `add_host`
+ :param record: Record, ideally identical to the one registered with `add_host`; only record_type and
+ record_id have to match to find the record.
+
+ :raises ValueError: If no record that was previously registered with `add_host` was found which matches the provided record
+ """
+ pass
+
+ def add_host_pointing_to_localstack(self, name: str) -> None:
+ """
+ Add a dns name which should be pointing to LocalStack when resolved.
+
+ :param name: Name which should be pointing to LocalStack when resolved
+ """
+ pass
+
+ def delete_host_pointing_to_localstack(self, name: str) -> None:
+ """
+ Removes a dns name from pointing to LocalStack
+
+ :param name: Name to be removed
+ :raises ValueError: If the host pointing to LocalStack was not previously registered using `add_host_pointing_to_localstack`
+ """
+ pass
+
+ def add_alias(self, source_name: str, record_type: RecordType, target: AliasTarget) -> None:
+ """
+ Adds an alias to the DNS, with an optional healthcheck callback.
+ When a request matching `source_name` comes in, the DNS server checks the aliases, and if the health check
+ (if provided) succeeds, the resolution result for the alias target is returned instead.
+ If multiple aliases are registered for the same (source_name, record_type) tuple, and no health checks interfere,
+ the server processes requests with the first alias that was added.
+
+ :param source_name: Alias name
+ :param record_type: Record type of the alias
+ :param target: Target of the alias
+ """
+ pass
+
+ def delete_alias(self, source_name: str, record_type: RecordType, target: AliasTarget) -> None:
+ """
+ Removes an alias from the DNS.
+ Only the name, the record type, and optionally the given alias id will be used to find entries to delete.
+ All matching entries will be deleted.
+
+ :param source_name: Alias name
+ :param record_type: Record type of the alias to remove
+ :param target: Target of the alias. Only relevant data for deletion will be its id.
+ :raises ValueError: If the alias was not previously registered using `add_alias`
+ """
+ pass
+
+ # TODO: support regex or wildcard?
+ # need to update when custom cloudpod destination is enabled
+ # has standard list of skips: localstack.services.dns_server.SKIP_PATTERNS
+ def add_skip(self, skip_pattern: str) -> None:
+ """
+ Add a skip pattern to the DNS server.
+
+ A skip pattern will prevent the DNS server from resolving a matching request against its internal zones or
+ aliases, and will directly contact an upstream DNS for resolution.
+
+ This is usually helpful if AWS endpoints are overwritten by internal entries, but we have to reach AWS for
+ some reason. (Often used for cloudpods or installers).
+
+ :param skip_pattern: Skip pattern to add. Can be a valid regex.
+ """
+ pass
+
+ def delete_skip(self, skip_pattern: str) -> None:
+ """
+ Removes a skip pattern from the DNS server.
+
+ :param skip_pattern: Skip pattern to remove
+ :raises ValueError: If the skip pattern was not previously registered using `add_skip`
+ """
+ pass
+
+ def clear(self):
+ """
+ Removes all runtime configurations.
+ """
+ pass
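
As a brief illustration of the record model above, entries are plain frozen dataclasses that a ``DnsServerProtocol`` implementation later receives via ``add_host``; a minimal sketch:

    from localstack.dns.models import DynamicRecord, RecordType, TargetRecord

    # static A record pointing a name at a fixed address
    a_record = TargetRecord(target="127.0.0.1", record_type=RecordType.A, record_id="my-id")

    # marker record whose target is resolved to the LocalStack container IP at query time
    dynamic_record = DynamicRecord(record_type=RecordType.A)

    # frozen dataclasses are hashable and compare by value, so they are safe to keep in shared state
    assert a_record != dynamic_record
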
diff --git a/localstack-core/localstack/dns/plugins.py b/localstack-core/localstack/dns/plugins.py
new file mode 100644
index 0000000000000..05566573cfec8
--- /dev/null
+++ b/localstack-core/localstack/dns/plugins.py
@@ -0,0 +1,45 @@
+import logging
+
+from localstack import config
+from localstack.runtime import hooks
+
+LOG = logging.getLogger(__name__)
+
+# Note: Don't want to introduce a possible import order conflict by importing SERVICE_SHUTDOWN_PRIORITY
+# TODO: consider extracting these priorities into some static configuration
+DNS_SHUTDOWN_PRIORITY = -30
+"""Make sure the DNS server is shut down after the ON_AFTER_SERVICE_SHUTDOWN_HANDLERS, which in turn is after
+SERVICE_SHUTDOWN_PRIORITY. Currently this value needs to be less than -20"""
+
+
+@hooks.on_infra_start(priority=10)
+def start_dns_server():
+ try:
+ from localstack.dns import server
+
+ server.start_dns_server(port=config.DNS_PORT, asynchronous=True)
+ except Exception as e:
+ LOG.warning("Unable to start DNS: %s", e)
+
+
+@hooks.on_infra_start()
+def setup_dns_configuration_on_host():
+ try:
+ from localstack.dns import server
+
+ if server.is_server_running():
+ # Prepare network interfaces for DNS server for the infra.
+ server.setup_network_configuration()
+ except Exception as e:
+ LOG.warning("error setting up dns server: %s", e)
+
+
+@hooks.on_infra_shutdown(priority=DNS_SHUTDOWN_PRIORITY)
+def stop_server():
+ try:
+ from localstack.dns import server
+
+ server.revert_network_configuration()
+ server.stop_servers()
+ except Exception as e:
+ LOG.warning("Unable to stop DNS servers: %s", e)
diff --git a/localstack-core/localstack/dns/server.py b/localstack-core/localstack/dns/server.py
new file mode 100644
index 0000000000000..6cf61ec0b0937
--- /dev/null
+++ b/localstack-core/localstack/dns/server.py
@@ -0,0 +1,980 @@
+import argparse
+import copy
+import logging
+import os
+import re
+import textwrap
+import threading
+from datetime import datetime
+from functools import cache
+from ipaddress import IPv4Address, IPv4Interface
+from pathlib import Path
+from socket import AddressFamily
+from typing import Iterable, Literal, Tuple
+
+import psutil
+from cachetools import TTLCache, cached
+from dnslib import (
+ AAAA,
+ CNAME,
+ MX,
+ NS,
+ QTYPE,
+ RCODE,
+ RD,
+ RDMAP,
+ RR,
+ SOA,
+ TXT,
+ A,
+ DNSHeader,
+ DNSLabel,
+ DNSQuestion,
+ DNSRecord,
+)
+from dnslib.server import DNSHandler, DNSServer
+from psutil._common import snicaddr
+
+import dns.flags
+import dns.message
+import dns.query
+import dns.resolver
+from dns.exception import Timeout
+
+# Note: avoid adding additional imports here, to avoid import issues when running the CLI
+from localstack import config
+from localstack.constants import LOCALHOST_HOSTNAME, LOCALHOST_IP
+from localstack.dns.models import (
+ AliasTarget,
+ DnsServerProtocol,
+ DynamicRecord,
+ NameRecord,
+ RecordType,
+ SOARecord,
+ TargetRecord,
+)
+from localstack.services.edge import run_module_as_sudo
+from localstack.utils import iputils
+from localstack.utils.net import Port, port_can_be_bound
+from localstack.utils.platform import in_docker
+from localstack.utils.serving import Server
+from localstack.utils.strings import to_bytes, to_str
+from localstack.utils.sync import sleep_forever
+
+EPOCH = datetime(1970, 1, 1)
+SERIAL = int((datetime.utcnow() - EPOCH).total_seconds())
+
+DEFAULT_FALLBACK_DNS_SERVER = "8.8.8.8"
+FALLBACK_DNS_LOCK = threading.RLock()
+VERIFICATION_DOMAIN = config.DNS_VERIFICATION_DOMAIN
+
+RCODE_REFUSED = 5
+
+DNS_SERVER: "DnsServerProtocol" = None
+PREVIOUS_RESOLV_CONF_FILE: str | None = None
+
+REQUEST_TIMEOUT_SECS = 7
+
+TYPE_LOOKUP = {
+ A: QTYPE.A,
+ AAAA: QTYPE.AAAA,
+ CNAME: QTYPE.CNAME,
+ MX: QTYPE.MX,
+ NS: QTYPE.NS,
+ SOA: QTYPE.SOA,
+ TXT: QTYPE.TXT,
+}
+
+LOG = logging.getLogger(__name__)
+
+THREAD_LOCAL = threading.local()
+
+# Type of the value given by DNSHandler.client_address
+# in the form (ip, port) e.g. ("127.0.0.1", 58291)
+ClientAddress = Tuple[str, int]
+
+psutil_cache = TTLCache(maxsize=100, ttl=10)
+
+
+# TODO: update route53 provider to use this util
+def normalise_dns_name(name: DNSLabel | str) -> str:
+ name = str(name)
+ if not name.endswith("."):
+ return f"{name}."
+
+ return name
+
+
+@cached(cache=psutil_cache)
+def list_network_interface_details() -> dict[str, list[snicaddr]]:
+ return psutil.net_if_addrs()
+
+
+class Record:
+ def __init__(self, rdata_type, *args, **kwargs):
+ rtype = kwargs.get("rtype")
+ rname = kwargs.get("rname")
+ ttl = kwargs.get("ttl")
+
+ if isinstance(rdata_type, RD):
+ # actually an instance, not a type
+ self._rtype = TYPE_LOOKUP[rdata_type.__class__]
+ rdata = rdata_type
+ else:
+ self._rtype = TYPE_LOOKUP[rdata_type]
+ if rdata_type == SOA and len(args) == 2:
+ # add sensible times to SOA
+ args += (
+ (
+ SERIAL, # serial number
+ 60 * 60 * 1, # refresh
+ 60 * 60 * 3, # retry
+ 60 * 60 * 24, # expire
+ 60 * 60 * 1, # minimum
+ ),
+ )
+ rdata = rdata_type(*args)
+
+ if rtype:
+ self._rtype = rtype
+ self._rname = rname
+ self.kwargs = dict(rdata=rdata, ttl=self.sensible_ttl() if ttl is None else ttl, **kwargs)
+
+ def try_rr(self, q):
+ if q.qtype == QTYPE.ANY or q.qtype == self._rtype:
+ return self.as_rr(q.qname)
+
+ def as_rr(self, alt_rname):
+ return RR(rname=self._rname or alt_rname, rtype=self._rtype, **self.kwargs)
+
+ def sensible_ttl(self):
+ if self._rtype in (QTYPE.NS, QTYPE.SOA):
+ return 60 * 60 * 24
+ else:
+ return 300
+
+ @property
+ def is_soa(self):
+ return self._rtype == QTYPE.SOA
+
+ def __str__(self):
+ return f"{QTYPE[self._rtype]}({self.kwargs})"
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class RecordConverter:
+ """
+ Handles returning the correct DNS record for the stored name_record.
+
+ In particular, if the record is a DynamicRecord, the IP address is looked up dynamically at query time.
+ """
+
+ def __init__(self, request: DNSRecord, client_address: ClientAddress):
+ self.request = request
+ self.client_address = client_address
+
+ def to_record(self, name_record: NameRecord) -> Record:
+ """
+ :param name_record: Internal representation of the name entry
+ :return: Record type for the associated name record
+ """
+ match name_record:
+ case TargetRecord(target=target, record_type=record_type):
+ return Record(RDMAP.get(record_type.name), target)
+ case SOARecord(m_name=m_name, r_name=r_name, record_type=_):
+ return Record(SOA, m_name, r_name)
+ case DynamicRecord(record_type=record_type):
+ # Marker indicating that the target of the domain name lookup should be resolved
+ # dynamically at query time to the most suitable LocalStack container IP address
+ ip = self._determine_best_ip()
+ # TODO: be more dynamic with IPv6
+ if record_type == RecordType.AAAA:
+ ip = "::1"
+ return Record(RDMAP.get(record_type.name), ip)
+ case _:
+ raise NotImplementedError(f"Record type '{type(name_record)}' not implemented")
+
+ def _determine_best_ip(self) -> str:
+ client_ip, _ = self.client_address
+ # allow for overriding if required
+ if config.DNS_RESOLVE_IP != LOCALHOST_IP:
+ return config.DNS_RESOLVE_IP
+
+ # Look up best matching ip address for the client
+ interfaces = self._fetch_interfaces()
+ for interface in interfaces:
+ subnet = interface.network
+ ip_address = IPv4Address(client_ip)
+ if ip_address in subnet:
+ # check if the request has come from the gateway or not. If so
+ # assume the request has come from the host, and return
+ # 127.0.0.1
+ if config.is_in_docker and self._is_gateway(ip_address):
+ return LOCALHOST_IP
+
+ return str(interface.ip)
+
+ # no best solution found
+ LOG.warning(
+ "could not determine subnet-matched IP address for %s, falling back to %s",
+ self.request.q.qname,
+ LOCALHOST_IP,
+ )
+ return LOCALHOST_IP
+
+ @staticmethod
+ def _is_gateway(ip: IPv4Address) -> bool:
+ """
+ Look up the gateways that this container has, and return True if the
+ supplied IP address is in that list.
+ """
+ return ip == iputils.get_default_gateway()
+
+ @staticmethod
+ def _fetch_interfaces() -> Iterable[IPv4Interface]:
+ interfaces = list_network_interface_details()
+ for _, addresses in interfaces.items():
+ for address in addresses:
+ if address.family != AddressFamily.AF_INET:
+ # TODO: IPv6
+ continue
+
+ # argument is of the form e.g. 127.0.0.1/255.0.0.0
+ net = IPv4Interface(f"{address.address}/{address.netmask}")
+ yield net
+
+
+class NonLoggingHandler(DNSHandler):
+ """Subclass of DNSHandler that avoids logging to stdout on error"""
+
+ def handle(self, *args, **kwargs):
+ try:
+ THREAD_LOCAL.client_address = self.client_address
+ THREAD_LOCAL.server = self.server
+ THREAD_LOCAL.request = self.request
+ return super(NonLoggingHandler, self).handle(*args, **kwargs)
+ except Exception:
+ pass
+
+
+NAME_PATTERNS_POINTING_TO_LOCALSTACK = [
+ f".*{LOCALHOST_HOSTNAME}",
+]
+
+
+def exclude_from_resolution(domain_regex: str):
+ """
+ Excludes the given domain pattern from being resolved to LocalStack.
+ Currently only works in docker, since in host mode dns is started as separate process
+ :param domain_regex: Domain regex string
+ """
+ if DNS_SERVER:
+ DNS_SERVER.add_skip(domain_regex)
+
+
+def revert_exclude_from_resolution(domain_regex: str):
+ """
+ Reverts the exclusion of the given domain pattern
+ :param domain_regex: Domain regex string
+ """
+ try:
+ if DNS_SERVER:
+ DNS_SERVER.delete_skip(domain_regex)
+ except ValueError:
+ pass
+
+
+def _should_delete_zone(record_to_delete: NameRecord, record_to_check: NameRecord):
+ """
+ Helper function to check if we should delete the record_to_check from the list we are iterating over
+ :param record_to_delete: Record which we got from the delete request
+ :param record_to_check: Record to be checked if it should be included in the records after delete
+ :return:
+ """
+ if record_to_delete == record_to_check:
+ return True
+ return (
+ record_to_delete.record_type == record_to_check.record_type
+ and record_to_delete.record_id == record_to_check.record_id
+ )
+
+
+def _should_delete_alias(alias_to_delete: AliasTarget, alias_to_check: AliasTarget):
+ """
+ Helper function to check if we should delete the alias_to_check from the list we are iterating over
+ :param alias_to_delete: Alias which we got from the delete request
+ :param alias_to_check: Alias to be checked if it should be included in the records after delete
+ :return:
+ """
+ return alias_to_delete.alias_id == alias_to_check.alias_id
+
+
+class NoopLogger:
+ """
+ Necessary helper class to avoid logging of any dns records by dnslib
+ """
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def log_pass(self, *args, **kwargs):
+ pass
+
+ def log_prefix(self, *args, **kwargs):
+ pass
+
+ def log_recv(self, *args, **kwargs):
+ pass
+
+ def log_send(self, *args, **kwargs):
+ pass
+
+ def log_request(self, *args, **kwargs):
+ pass
+
+ def log_reply(self, *args, **kwargs):
+ pass
+
+ def log_truncated(self, *args, **kwargs):
+ pass
+
+ def log_error(self, *args, **kwargs):
+ pass
+
+ def log_data(self, *args, **kwargs):
+ pass
+
+
+class Resolver(DnsServerProtocol):
+ # Upstream DNS server
+ upstream_dns: str
+ # List of patterns which will be skipped for local resolution and always forwarded to upstream
+ skip_patterns: list[str]
+ # Dict of zones: (domain name or pattern) -> list[dns records]
+ zones: dict[str, list[NameRecord]]
+ # Alias map (source_name, record_type) => target_name (target name then still has to be resolved!)
+ aliases: dict[tuple[DNSLabel, RecordType], list[AliasTarget]]
+ # Lock to prevent issues due to concurrent modifications
+ lock: threading.RLock
+
+ def __init__(self, upstream_dns: str):
+ self.upstream_dns = upstream_dns
+ self.skip_patterns = []
+ self.zones = {}
+ self.aliases = {}
+ self.lock = threading.RLock()
+
+ def resolve(self, request: DNSRecord, handler: DNSHandler) -> DNSRecord | None:
+ """
+ Resolve a given request, by either checking locally registered records, or forwarding to the defined
+ upstream DNS server.
+
+ :param request: DNS Request
+ :param handler: Unused.
+ :return: DNS Reply
+ """
+ reply = request.reply()
+ found = False
+
+ try:
+ if not self._skip_local_resolution(request):
+ found = self._resolve_name(request, reply, handler.client_address)
+ except Exception as e:
+ LOG.info("Unable to get DNS result: %s", e)
+
+ if found:
+ return reply
+
+ # If we did not find a matching record in our local zones, we forward to our upstream dns
+ try:
+ req_parsed = dns.message.from_wire(bytes(request.pack()))
+ r = dns.query.udp(req_parsed, self.upstream_dns, timeout=REQUEST_TIMEOUT_SECS)
+ result = self._map_response_dnspython_to_dnslib(r)
+ return result
+ except Exception as e:
+ LOG.info(
+ "Unable to get DNS result from upstream server %s for domain %s: %s",
+ self.upstream_dns,
+ str(request.q.qname),
+ e,
+ )
+
+ # if we cannot reach upstream dns, return SERVFAIL
+ if not reply.rr and reply.header.get_rcode() == RCODE.NOERROR:
+ # setting this return code will cause commands like 'host' to try the next nameserver
+ reply.header.set_rcode(RCODE.SERVFAIL)
+ return None
+
+ return reply
+
+ def _skip_local_resolution(self, request) -> bool:
+ """
+ Check whether we should skip local resolution for the given request, and directly contact upstream
+
+ :param request: DNS Request
+ :return: Whether the request local resolution should be skipped
+ """
+ request_name = to_str(str(request.q.qname))
+ for p in self.skip_patterns:
+ if re.match(p, request_name):
+ return True
+ return False
+
+ def _resolve_alias(
+ self, request: DNSRecord, reply: DNSRecord, client_address: ClientAddress
+ ) -> bool:
+ if request.q.qtype in (QTYPE.A, QTYPE.AAAA, QTYPE.CNAME):
+ key = (DNSLabel(to_bytes(request.q.qname)), RecordType[QTYPE[request.q.qtype]])
+ # check if we have aliases defined for our given qname/qtype pair
+ if aliases := self.aliases.get(key):
+ for alias in aliases:
+ # if there is no health check, or the healthcheck is successful, we will consider this alias
+ # take the first alias passing this check
+ if not alias.health_check or alias.health_check():
+ request_copy: DNSRecord = copy.deepcopy(request)
+ request_copy.q.qname = alias.target
+ # check if we can resolve the alias
+ found = self._resolve_name_from_zones(request_copy, reply, client_address)
+ if found:
+ LOG.debug(
+ "Found entry for AliasTarget '%s' ('%s')", request.q.qname, alias
+ )
+ # change the replaced rr-DNS names back to the original request
+ for rr in reply.rr:
+ rr.set_rname(request.q.qname)
+ else:
+ reply.header.set_rcode(RCODE.REFUSED)
+ return True
+ return False
+
+ def _resolve_name(
+ self, request: DNSRecord, reply: DNSRecord, client_address: ClientAddress
+ ) -> bool:
+ if alias_found := self._resolve_alias(request, reply, client_address):
+ LOG.debug("Alias found: %s", request.q.qname)
+ return alias_found
+ return self._resolve_name_from_zones(request, reply, client_address)
+
+ def _resolve_name_from_zones(
+ self, request: DNSRecord, reply: DNSRecord, client_address: ClientAddress
+ ) -> bool:
+ found = False
+
+ converter = RecordConverter(request, client_address)
+
+ # check for direct (not regex based) response
+ zone = self.zones.get(normalise_dns_name(request.q.qname))
+ if zone is not None:
+ for zone_records in zone:
+ rr = converter.to_record(zone_records).try_rr(request.q)
+ if rr:
+ found = True
+ reply.add_answer(rr)
+ else:
+ # no direct zone so look for an SOA record for a higher level zone
+ for zone_label, zone_records in self.zones.items():
+ # try regex match
+ pattern = re.sub(r"(^|[^.])\*", ".*", str(zone_label))
+ if re.match(pattern, str(request.q.qname)):
+ for record in zone_records:
+ rr = converter.to_record(record).try_rr(request.q)
+ if rr:
+ found = True
+ reply.add_answer(rr)
+ # try suffix match
+ elif request.q.qname.matchSuffix(to_bytes(zone_label)):
+ try:
+ soa_record = next(r for r in zone_records if converter.to_record(r).is_soa)
+ except StopIteration:
+ continue
+ else:
+ found = True
+ reply.add_answer(converter.to_record(soa_record).as_rr(zone_label))
+ break
+ return found
+
+ def _parse_section(self, section: str) -> list[RR]:
+ result = []
+ for line in section.split("\n"):
+ line = line.strip()
+ if line:
+ if line.startswith(";"):
+ # section ended, stop parsing
+ break
+ else:
+ result += RR.fromZone(line)
+ return result
+
+ def _map_response_dnspython_to_dnslib(self, response):
+ """Map response object from dnspython to dnslib (looks like we cannot
+ simply export/import the raw messages from the wire)"""
+ flags = dns.flags.to_text(response.flags)
+
+ def flag(f):
+ return 1 if f.upper() in flags else 0
+
+ questions = []
+ for q in response.question:
+ questions.append(DNSQuestion(qname=str(q.name), qtype=q.rdtype, qclass=q.rdclass))
+
+ result = DNSRecord(
+ DNSHeader(
+ qr=flag("qr"), aa=flag("aa"), ra=flag("ra"), id=response.id, rcode=response.rcode()
+ ),
+ q=questions[0],
+ )
+
+ # extract answers
+ answer_parts = str(response).partition(";ANSWER")
+ result.add_answer(*self._parse_section(answer_parts[2]))
+ # extract authority information
+ authority_parts = str(response).partition(";AUTHORITY")
+ result.add_auth(*self._parse_section(authority_parts[2]))
+ return result
+
+ def add_host(self, name: str, record: NameRecord):
+ LOG.debug("Adding host %s with record %s", name, record)
+ name = normalise_dns_name(name)
+ with self.lock:
+ self.zones.setdefault(name, [])
+ self.zones[name].append(record)
+
+ def delete_host(self, name: str, record: NameRecord):
+ LOG.debug("Deleting host %s with record %s", name, record)
+ name = normalise_dns_name(name)
+ with self.lock:
+ if not self.zones.get(name):
+ raise ValueError("Could not find entry %s for name %s in zones", record, name)
+ self.zones.setdefault(name, [])
+ current_zones = self.zones[name]
+ self.zones[name] = [
+ zone for zone in self.zones[name] if not _should_delete_zone(record, zone)
+ ]
+ if self.zones[name] == current_zones:
+ raise ValueError("Could not find entry %s for name %s in zones", record, name)
+ # if we deleted the last entry, clean up
+ if not self.zones[name]:
+ del self.zones[name]
+
+ def add_alias(self, source_name: str, record_type: RecordType, target: AliasTarget):
+ LOG.debug("Adding alias %s with record type %s target %s", source_name, record_type, target)
+ label = (DNSLabel(to_bytes(source_name)), record_type)
+ with self.lock:
+ self.aliases.setdefault(label, [])
+ self.aliases[label].append(target)
+
+ def delete_alias(self, source_name: str, record_type: RecordType, target: AliasTarget):
+ LOG.debug(
+ "Deleting alias %s with record type %s",
+ source_name,
+ record_type,
+ )
+ label = (DNSLabel(to_bytes(source_name)), record_type)
+ with self.lock:
+ if not self.aliases.get(label):
+ raise ValueError(
+ f"Could not find entry {target} for name {source_name}, record type {record_type} in aliases"
+ )
+ self.aliases.setdefault(label, [])
+ current_aliases = self.aliases[label]
+ self.aliases[label] = [
+ alias for alias in self.aliases[label] if not _should_delete_alias(target, alias)
+ ]
+ if self.aliases[label] == current_aliases:
+ raise ValueError(
+ f"Could not find entry {target} for name {source_name}, record type {record_type} in aliases"
+ )
+ # if we deleted the last entry, clean up
+ if not self.aliases[label]:
+ del self.aliases[label]
+
+ def add_host_pointing_to_localstack(self, name: str):
+ LOG.debug("Adding host %s pointing to LocalStack", name)
+ self.add_host(name, DynamicRecord(record_type=RecordType.A))
+ if config.DNS_RESOLVE_IP == config.LOCALHOST_IP:
+ self.add_host(name, DynamicRecord(record_type=RecordType.AAAA))
+
+ def delete_host_pointing_to_localstack(self, name: str):
+ LOG.debug("Deleting host %s pointing to LocalStack", name)
+ self.delete_host(name, DynamicRecord(record_type=RecordType.A))
+ if config.DNS_RESOLVE_IP == config.LOCALHOST_IP:
+ self.delete_host(name, DynamicRecord(record_type=RecordType.AAAA))
+
+ def add_skip(self, skip_pattern: str):
+ LOG.debug("Adding skip pattern %s", skip_pattern)
+ self.skip_patterns.append(skip_pattern)
+
+ def delete_skip(self, skip_pattern: str):
+ LOG.debug("Deleting skip pattern %s", skip_pattern)
+ self.skip_patterns.remove(skip_pattern)
+
+ def clear(self):
+ LOG.debug("Clearing DNS zones")
+ self.skip_patterns.clear()
+ self.zones.clear()
+ self.aliases.clear()
+
+
+class DnsServer(Server, DnsServerProtocol):
+ servers: list[DNSServer]
+ resolver: Resolver | None
+
+ def __init__(
+ self,
+ port: int,
+ protocols: list[Literal["udp", "tcp"]],
+ upstream_dns: str,
+ host: str = "0.0.0.0",
+ ) -> None:
+ super().__init__(port, host)
+ self.resolver = Resolver(upstream_dns=upstream_dns)
+ self.protocols = protocols
+ self.servers = []
+ self.handler_class = NonLoggingHandler
+
+ def _get_servers(self) -> list[DNSServer]:
+ servers = []
+ for protocol in self.protocols:
+ # TODO add option to use normal logger instead of NoopLogger for verbose debug mode
+ servers.append(
+ DNSServer(
+ self.resolver,
+ handler=self.handler_class,
+ logger=NoopLogger(),
+ port=self.port,
+ address=self.host,
+ tcp=protocol == "tcp",
+ )
+ )
+ return servers
+
+ @property
+ def protocol(self):
+ return "udp"
+
+ def health(self):
+ """
+ Runs a health check on the server by issuing a test DNS query against it.
+ """
+ try:
+ request = dns.message.make_query("localhost.localstack.cloud", "A")
+ answers = dns.query.udp(request, "127.0.0.1", port=self.port, timeout=0.5).answer
+ return len(answers) > 0
+ except Exception:
+ return False
+
+ def do_run(self):
+ self.servers = self._get_servers()
+ for server in self.servers:
+ server.start_thread()
+ LOG.debug("DNS Server started")
+ for server in self.servers:
+ server.thread.join()
+
+ def do_shutdown(self):
+ for server in self.servers:
+ server.stop()
+
+ def add_host(self, name: str, record: NameRecord):
+ self.resolver.add_host(name, record)
+
+ def delete_host(self, name: str, record: NameRecord):
+ self.resolver.delete_host(name, record)
+
+ def add_alias(self, source_name: str, record_type: RecordType, target: AliasTarget):
+ self.resolver.add_alias(source_name, record_type, target)
+
+ def delete_alias(self, source_name: str, record_type: RecordType, target: AliasTarget):
+ self.resolver.delete_alias(source_name, record_type, target)
+
+ def add_host_pointing_to_localstack(self, name: str):
+ self.resolver.add_host_pointing_to_localstack(name)
+
+ def delete_host_pointing_to_localstack(self, name: str):
+ self.resolver.delete_host_pointing_to_localstack(name)
+
+ def add_skip(self, skip_pattern: str):
+ self.resolver.add_skip(skip_pattern)
+
+ def delete_skip(self, skip_pattern: str):
+ self.resolver.delete_skip(skip_pattern)
+
+ def clear(self):
+ self.resolver.clear()
+
+
+class SeparateProcessDNSServer(Server, DnsServerProtocol):
+ def __init__(
+ self,
+ port: int = 53,
+ host: str = "0.0.0.0",
+ ) -> None:
+ super().__init__(port, host)
+
+ @property
+ def protocol(self):
+ return "udp"
+
+ def health(self):
+ """
+ Runs a health check on the server by issuing a test DNS query against it.
+ """
+ try:
+ request = dns.message.make_query("localhost.localstack.cloud", "A")
+ answers = dns.query.udp(request, "127.0.0.1", port=self.port, timeout=0.5).answer
+ return len(answers) > 0
+ except Exception:
+ return False
+
+ def do_start_thread(self):
+ # For host mode
+ env_vars = {}
+ for env_var in config.CONFIG_ENV_VARS:
+ if env_var.startswith("DNS_"):
+ value = os.environ.get(env_var, None)
+ if value is not None:
+ env_vars[env_var] = value
+
+ # note: running in a separate process breaks integration with Route53 (to be fixed for local dev mode!)
+ thread = run_module_as_sudo(
+ "localstack.dns.server",
+ asynchronous=True,
+ env_vars=env_vars,
+ arguments=["-p", str(self.port)],
+ )
+ return thread
+
+
+def get_fallback_dns_server():
+ return config.DNS_SERVER or get_available_dns_server()
+
+
+@cache
+def get_available_dns_server():
+ # TODO check if more loop-checks are necessary than just not using our own DNS server
+ with FALLBACK_DNS_LOCK:
+ resolver = dns.resolver.Resolver()
+ # we do not want to include localhost here, or a loop might happen
+ candidates = [r for r in resolver.nameservers if r != "127.0.0.1"]
+ result = None
+ candidates.append(DEFAULT_FALLBACK_DNS_SERVER)
+ for ns in candidates:
+ resolver.nameservers = [ns]
+ try:
+ try:
+ answer = resolver.resolve(VERIFICATION_DOMAIN, "a", lifetime=3)
+ answer = [
+ res.to_text() for answers in answer.response.answer for res in answers.items
+ ]
+ except Timeout:
+ answer = None
+ if not answer:
+ continue
+ result = ns
+ break
+ except Exception:
+ pass
+
+ if result:
+ LOG.debug("Determined fallback dns: %s", result)
+ else:
+ LOG.info(
+ "Unable to determine fallback DNS. Please check if '%s' is reachable by your configured DNS servers"
+ "DNS fallback will be disabled.",
+ VERIFICATION_DOMAIN,
+ )
+ return result
+
+
+# ###### LEGACY METHODS ######
+def add_resolv_entry(file_path: Path | str = Path("/etc/resolv.conf")):
+ global PREVIOUS_RESOLV_CONF_FILE
+ # never overwrite the host configuration without the user's permission
+ if not in_docker():
+ LOG.warning("Incorrectly attempted to alter host networking config")
+ return
+
+ LOG.debug("Overwriting container DNS server to point to localhost")
+ content = textwrap.dedent(
+ """
+ # The following line is required by LocalStack
+ nameserver 127.0.0.1
+ """
+ )
+ file_path = Path(file_path)
+ try:
+ with file_path.open("r+") as outfile:
+ PREVIOUS_RESOLV_CONF_FILE = outfile.read()
+ previous_resolv_conf_without_nameservers = [
+ line
+ for line in PREVIOUS_RESOLV_CONF_FILE.splitlines()
+ if not line.startswith("nameserver")
+ ]
+ outfile.seek(0)
+ outfile.write(content)
+ outfile.write("\n".join(previous_resolv_conf_without_nameservers))
+ outfile.truncate()
+ except Exception:
+ LOG.warning(
+ "Could not update container DNS settings", exc_info=LOG.isEnabledFor(logging.DEBUG)
+ )
+
+
+def revert_resolv_entry(file_path: Path | str = Path("/etc/resolv.conf")):
+ # never overwrite the host configuration without the user's permission
+ if not in_docker():
+ LOG.warning("Incorrectly attempted to alter host networking config")
+ return
+
+ if not PREVIOUS_RESOLV_CONF_FILE:
+ LOG.warning("resolv.conf file to restore not found.")
+ return
+
+ LOG.debug("Reverting container DNS config")
+ file_path = Path(file_path)
+ try:
+ with file_path.open("w") as outfile:
+ outfile.write(PREVIOUS_RESOLV_CONF_FILE)
+ except Exception:
+ LOG.warning(
+ "Could not revert container DNS settings", exc_info=LOG.isEnabledFor(logging.DEBUG)
+ )
+
+
+def setup_network_configuration():
+ # check if DNS is disabled
+ if not config.use_custom_dns():
+ return
+
+ # add entry to /etc/resolv.conf
+ if in_docker():
+ add_resolv_entry()
+
+
+def revert_network_configuration():
+ # check if DNS is disabled
+ if not config.use_custom_dns():
+ return
+
+ # revert the entry in /etc/resolv.conf
+ if in_docker():
+ revert_resolv_entry()
+
+
+def start_server(upstream_dns: str, host: str, port: int = config.DNS_PORT):
+ global DNS_SERVER
+
+ if DNS_SERVER:
+ # already started - bail
+ LOG.debug("DNS servers are already started. Avoid starting again.")
+ return
+
+ LOG.debug("Starting DNS servers (tcp/udp port %s on %s)...", port, host)
+ dns_server = DnsServer(port, protocols=["tcp", "udp"], host=host, upstream_dns=upstream_dns)
+
+ for name in NAME_PATTERNS_POINTING_TO_LOCALSTACK:
+ dns_server.add_host_pointing_to_localstack(name)
+ if config.LOCALSTACK_HOST.host != LOCALHOST_HOSTNAME:
+ dns_server.add_host_pointing_to_localstack(f".*{config.LOCALSTACK_HOST.host}")
+
+ # support both DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM and DNS_LOCAL_NAME_PATTERNS
+ # until the next major version change
+ # TODO(srw): remove the usage of DNS_LOCAL_NAME_PATTERNS
+ skip_local_resolution = " ".join(
+ [
+ config.DNS_NAME_PATTERNS_TO_RESOLVE_UPSTREAM,
+ config.DNS_LOCAL_NAME_PATTERNS,
+ ]
+ ).strip()
+ if skip_local_resolution:
+ for skip_pattern in re.split(r"[,;\s]+", skip_local_resolution):
+ dns_server.add_skip(skip_pattern.strip(" \"'"))
+
+ dns_server.start()
+ if not dns_server.wait_is_up(timeout=5):
+ LOG.warning("DNS server did not come up within 5 seconds.")
+ dns_server.shutdown()
+ return
+ DNS_SERVER = dns_server
+ LOG.debug("DNS server startup finished.")
+
+
+def stop_servers():
+ if DNS_SERVER:
+ DNS_SERVER.shutdown()
+
+
+def start_dns_server_as_sudo(port: int):
+ global DNS_SERVER
+ LOG.debug(
+ "Starting the DNS on its privileged port (%s) needs root permissions. Trying to start DNS with sudo.",
+ config.DNS_PORT,
+ )
+
+ dns_server = SeparateProcessDNSServer(port)
+ dns_server.start()
+
+ if not dns_server.wait_is_up(timeout=5):
+ LOG.warning("DNS server did not come up within 5 seconds.")
+ dns_server.shutdown()
+ return
+
+ DNS_SERVER = dns_server
+ LOG.debug("DNS server startup finished (as sudo).")
+
+
+def start_dns_server(port: int, asynchronous: bool = False, standalone: bool = False):
+ if DNS_SERVER:
+ # already started - bail
+ LOG.error("DNS servers are already started. Avoid starting again.")
+ return
+
+ # check if DNS server is disabled
+ if not config.use_custom_dns():
+ LOG.debug("Not starting DNS. DNS_ADDRESS=%s", config.DNS_ADDRESS)
+ return
+
+ upstream_dns = get_fallback_dns_server()
+ if not upstream_dns:
+ LOG.warning("Error starting the DNS server: No upstream dns server found.")
+ return
+
+ # host to bind the DNS server to. In docker we always want to bind to "0.0.0.0"
+ host = config.DNS_ADDRESS
+ if in_docker():
+ host = "0.0.0.0"
+
+ if port_can_be_bound(Port(port, "udp"), address=host):
+ start_server(port=port, host=host, upstream_dns=upstream_dns)
+ if not asynchronous:
+ sleep_forever()
+ return
+
+ if standalone:
+ LOG.debug("Already in standalone mode and port binding still fails.")
+ return
+
+ start_dns_server_as_sudo(port)
+
+
+def get_dns_server() -> DnsServerProtocol:
+ return DNS_SERVER
+
+
+def is_server_running() -> bool:
+ return DNS_SERVER is not None
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-p", "--port", required=False, default=53, type=int)
+ args = parser.parse_args()
+
+ start_dns_server(asynchronous=False, port=args.port, standalone=True)
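+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of how other components could use the DnsServerProtocol API exposed above
+# once the server is running. The host name and skip pattern are made-up example values.
+def _example_register_custom_names():
+    if not is_server_running():
+        return
+    dns_server = get_dns_server()
+    # resolve "myservice.internal.example" to LocalStack itself
+    dns_server.add_host_pointing_to_localstack("myservice.internal.example")
+    # let anything under "corp.example" bypass LocalStack and be resolved upstream instead
+    dns_server.add_skip(".*\\.corp\\.example")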
diff --git a/localstack-core/localstack/extensions/__init__.py b/localstack-core/localstack/extensions/__init__.py
new file mode 100644
index 0000000000000..3b52add044d38
--- /dev/null
+++ b/localstack-core/localstack/extensions/__init__.py
@@ -0,0 +1,3 @@
+"""Extensions are third-party software modules to customize localstack."""
+
+name = "extensions"
diff --git a/localstack-core/localstack/extensions/api/__init__.py b/localstack-core/localstack/extensions/api/__init__.py
new file mode 100644
index 0000000000000..9335bae5fe7c2
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/__init__.py
@@ -0,0 +1,7 @@
+"""Public facing API for users to build LocalStack extensions."""
+
+from .extension import Extension
+
+name = "api"
+
+__all__ = ["Extension"]
diff --git a/localstack-core/localstack/extensions/api/aws.py b/localstack-core/localstack/extensions/api/aws.py
new file mode 100644
index 0000000000000..120bf4958e72b
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/aws.py
@@ -0,0 +1,31 @@
+from localstack.aws.api import (
+ CommonServiceException,
+ RequestContext,
+ ServiceException,
+ ServiceRequest,
+ ServiceResponse,
+)
+from localstack.aws.chain import (
+ CompositeExceptionHandler,
+ CompositeHandler,
+ CompositeResponseHandler,
+ ExceptionHandler,
+ HandlerChain,
+)
+from localstack.aws.chain import Handler as RequestHandler
+from localstack.aws.chain import Handler as ResponseHandler
+
+__all__ = [
+ "RequestContext",
+ "ServiceRequest",
+ "ServiceResponse",
+ "ServiceException",
+ "CommonServiceException",
+ "RequestHandler",
+ "ResponseHandler",
+ "HandlerChain",
+ "CompositeHandler",
+ "ExceptionHandler",
+ "CompositeResponseHandler",
+ "CompositeExceptionHandler",
+]
diff --git a/localstack-core/localstack/extensions/api/extension.py b/localstack-core/localstack/extensions/api/extension.py
new file mode 100644
index 0000000000000..080735c4ae3a3
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/extension.py
@@ -0,0 +1,99 @@
+from plux import Plugin
+
+from .aws import CompositeExceptionHandler, CompositeHandler, CompositeResponseHandler
+from .http import RouteHandler, Router
+
+
+class BaseExtension(Plugin):
+ """
+ Base extension.
+ """
+
+ def load(self, *args, **kwargs):
+ """
+ Provided to plux to load the plugin. Do NOT override! PluginManagers managing extensions expect the load method to return the Extension itself.
+
+ :param args: load arguments
+ :param kwargs: load keyword arguments
+ :return: this extension object
+ """
+ return self
+
+ def on_extension_load(self, *args, **kwargs):
+ """
+ Called when LocalStack loads the extension.
+ """
+ raise NotImplementedError
+
+
+class Extension(BaseExtension):
+ """
+ An extension that is loaded into LocalStack dynamically.
+
+ The method execution order of an extension is as follows:
+
+ - on_extension_load
+ - on_platform_start
+ - update_gateway_routes
+ - update_request_handlers
+ - update_response_handlers
+ - on_platform_ready
+ """
+
+ namespace = "localstack.extensions"
+
+ def on_extension_load(self):
+ """
+ Called when LocalStack loads the extension.
+ """
+ pass
+
+ def on_platform_start(self):
+ """
+ Called when LocalStack starts the main runtime.
+ """
+ pass
+
+ def update_gateway_routes(self, router: Router[RouteHandler]):
+ """
+ Called with the Router attached to the LocalStack gateway. Override this to add or update routes.
+
+ :param router: the Router attached in the gateway
+ """
+ pass
+
+ def update_request_handlers(self, handlers: CompositeHandler):
+ """
+ Called with the custom request handlers of the LocalStack gateway. Override this to add or update handlers.
+
+ :param handlers: custom request handlers of the gateway
+ """
+ pass
+
+ def update_response_handlers(self, handlers: CompositeResponseHandler):
+ """
+ Called with the custom response handlers of the LocalStack gateway. Override this to add or update handlers.
+
+ :param handlers: custom response handlers of the gateway
+ """
+ pass
+
+ def update_exception_handlers(self, handlers: CompositeExceptionHandler):
+ """
+ Called with the custom exception handlers of the LocalStack gateway. Override this to add or update handlers.
+
+ :param handlers: custom exception handlers of the gateway
+ """
+ pass
+
+ def on_platform_ready(self):
+ """
+ Called when LocalStack is ready and the Ready marker has been printed.
+ """
+ pass
+
+ def on_platform_shutdown(self):
+ """
+ Called when LocalStack is shutting down. Can be used to close any resources (threads, processes, sockets, etc.).
+ """
+ pass
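+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of an extension implementing some of the hooks above. The name and route
+# are made-up example values; real extensions live in their own distribution and are
+# registered under the "localstack.extensions" entry point namespace.
+class _ExampleExtension(Extension):
+    name = "example-extension"
+
+    def on_platform_start(self):
+        print("example extension: LocalStack is starting")
+
+    def update_gateway_routes(self, router: Router):
+        # serve a small JSON document under /_example on the gateway
+        router.add("/_example", lambda request: {"status": "ok"})
+
+    def on_platform_ready(self):
+        print("example extension: LocalStack is ready")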
diff --git a/localstack-core/localstack/extensions/api/http.py b/localstack-core/localstack/extensions/api/http.py
new file mode 100644
index 0000000000000..5845856625206
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/http.py
@@ -0,0 +1,16 @@
+from localstack.http import Request, Response, Router
+from localstack.http.client import HttpClient, SimpleRequestsClient
+from localstack.http.dispatcher import Handler as RouteHandler
+from localstack.http.proxy import Proxy, ProxyHandler, forward
+
+__all__ = [
+ "Request",
+ "Response",
+ "Router",
+ "HttpClient",
+ "SimpleRequestsClient",
+ "Proxy",
+ "ProxyHandler",
+ "forward",
+ "RouteHandler",
+]
diff --git a/localstack-core/localstack/extensions/api/runtime.py b/localstack-core/localstack/extensions/api/runtime.py
new file mode 100644
index 0000000000000..426036659c951
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/runtime.py
@@ -0,0 +1,3 @@
+from localstack.utils.analytics import get_session_id
+
+__all__ = ["get_session_id"]
diff --git a/localstack-core/localstack/extensions/api/services.py b/localstack-core/localstack/extensions/api/services.py
new file mode 100644
index 0000000000000..c41152ef0d121
--- /dev/null
+++ b/localstack-core/localstack/extensions/api/services.py
@@ -0,0 +1,5 @@
+from localstack.utils.common import external_service_ports
+
+__all__ = [
+ "external_service_ports",
+]
diff --git a/localstack/services/es/__init__.py b/localstack-core/localstack/extensions/patterns/__init__.py
similarity index 100%
rename from localstack/services/es/__init__.py
rename to localstack-core/localstack/extensions/patterns/__init__.py
diff --git a/localstack-core/localstack/extensions/patterns/webapp.py b/localstack-core/localstack/extensions/patterns/webapp.py
new file mode 100644
index 0000000000000..ab69d935d729c
--- /dev/null
+++ b/localstack-core/localstack/extensions/patterns/webapp.py
@@ -0,0 +1,333 @@
+import importlib
+import logging
+import mimetypes
+import typing as t
+from functools import cached_property
+
+from rolo.gateway import HandlerChain
+from rolo.router import RuleAdapter, WithHost
+from werkzeug.routing import Submount
+
+from localstack import config
+from localstack.aws.api import RequestContext
+from localstack.extensions.api import Extension, http
+
+if t.TYPE_CHECKING:
+ # although jinja2 is included transitively via moto, let's make sure jinja2 stays optional
+ import jinja2
+
+LOG = logging.getLogger(__name__)
+
+_default = object()
+
+
+class WebAppExtension(Extension):
+ """
+ EXPERIMENTAL! This class is experimental and the API may change without notice.
+
+ A webapp extension serves routes, templates, and static files via a submount and a subdomain through
+ localstack.
+
+ It assumes you have the following directory layout::
+
+ my_extension
+ ├── extension.py
+ ├── __init__.py
+ ├── static <-- make sure static resources get packaged!
+ │   ├── __init__.py
+ │   ├── favicon.ico
+ │   └── style.css
+ └── templates <-- jinja2 templates
+     └── index.html
+
+ Given this layout, you can define your extensions in ``my_extension.extension`` like this. Routes defined in the
+ extension itself are automatically registered::
+
+ class MyExtension(WebAppExtension):
+ name = "my-extension"
+
+ @route("/")
+ def index(self, request: Request) -> Response:
+ # reference `static/style.css` to serve the static file from your package
+ return self.render_template_response("index.html")
+
+ @route("/hello")
+ def hello(self, request: Request):
+ return {"message": "Hello World!"}
+
+ This will create an extension that localstack serves via:
+
+ * Submount: https://localhost.localstack.cloud:4566/_extension/my-extension
+ * Subdomain: https://my-extension.localhost.localstack.cloud:4566/
+
+ Both are created for full flexibility:
+
+ * Subdomains: create a domain namespace that can be helpful for some extensions, especially when
+ running on the local machine
+ * Submounts: for some environments, like in ephemeral instances where subdomains are harder to control,
+ submounts are more convenient
+
+ Any routes added by the extension will be served relative to these URLs.
+ """
+
+ def __init__(
+ self,
+ mount: str = None,
+ submount: str | None = _default,
+ subdomain: str | None = _default,
+ template_package_path: str | None = _default,
+ static_package_path: str | None = _default,
+ static_url_path: str = None,
+ ):
+ """
+ Override to customize your extension. For example, you can disable certain behavior by calling
+ ``super().__init__(subdomain=None, static_package_path=None)``, which will disable serving through
+ a subdomain, and disable static file serving.
+
+ :param mount: the "mount point" which will be used as default value for the submount and
+ subdomain, i.e., ``<mount>.localhost.localstack.cloud`` and
+ ``localhost.localstack.cloud/_extension/<mount>``. Defaults to the extension name. Note that,
+ in case the mount name clashes with another extension, extensions may overwrite each other's
+ routes.
+ :param submount: the submount path, needs to start with a leading slash (default
+ ``/_extension/<mount>``)
+ :param subdomain: the subdomain (defaults to the value of ``mount``)
+ :param template_package_path: the path to the templates within the module (defaults to
+ ``templates``, which expands to ``<extension-module>.templates``)
+ :param static_package_path: the package serving static files (defaults to ``static``, which expands to
+ ``<extension-module>.static``)
+ :param static_url_path: the URL path to serve static files from (defaults to `/static`)
+ """
+ mount = mount or self.name
+
+ self.submount = f"/_extension/{mount}" if submount is _default else submount
+ self.subdomain = mount if subdomain is _default else subdomain
+
+ self.template_package_path = (
+ "templates" if template_package_path is _default else template_package_path
+ )
+ self.static_package_path = (
+ "static" if static_package_path is _default else static_package_path
+ )
+ self.static_url_path = static_url_path or "/static"
+
+ self.static_resource_module = None
+
+ def collect_routes(self, routes: list[t.Any]):
+ """
+ This method can be overridden to add more routes to the controller. Everything in ``routes`` will
+ be added to a ``RuleAdapter`` and subsequently mounted into the gateway router.
+
+ Here are some examples::
+
+ class MyRoutes:
+ @route("/hello")
+ def hello(request):
+ return "Hello World!"
+
+ class MyExtension(WebAppExtension):
+ name = "my-extension"
+
+ def collect_routes(self, routes: list[t.Any]):
+
+ # scans all routes of MyRoutes
+ routes.append(MyRoutes())
+ # use rule adapters to add routes without decorators
+ routes.append(RuleAdapter("/say-hello", self.say_hello))
+
+ # no idea why you would want to do this, but you can :-)
+ @route("/empty-dict")
+ def _inline_handler(request: Request) -> Response:
+ return Response.for_json({})
+ routes.append(_inline_handler)
+
+ def say_hello(request: Request):
+ return {"message": "Hello World!"}
+
+ This creates the following routes available through both subdomain and submount.
+
+ With subdomain:
+
+ * ``my-extension.localhost.localstack.cloud:4566/hello``
+ * ``my-extension.localhost.localstack.cloud:4566/say-hello``
+ * ``my-extension.localhost.localstack.cloud:4566/empty-dict``
+ * ``my-extension.localhost.localstack.cloud:4566/static`` <- automatically added static file endpoint
+
+ With submount:
+
+ * ``localhost.localstack.cloud:4566/_extension/my-extension/hello``
+ * ``localhost.localstack.cloud:4566/_extension/my-extension/say-hello``
+ * ``localhost.localstack.cloud:4566/_extension/my-extension/empty-dict``
+ * ``localhost.localstack.cloud:4566/_extension/my-extension/static`` <- auto-added static file serving
+
+ :param routes: the routes being collected
+ """
+ pass
+
+ @cached_property
+ def template_env(self) -> t.Optional["jinja2.Environment"]:
+ """
+ Returns the singleton jinja2 template environment. By default, the environment uses a
+ ``PackageLoader`` that loads from ``my_extension.templates`` (where ``my_extension`` is the root
+ module of the extension, and ``templates`` refers to ``self.template_package_path``,
+ which is ``templates`` by default).
+
+ :return: a template environment
+ """
+ if self.template_package_path:
+ return self._create_template_env()
+ return None
+
+ def _create_template_env(self) -> "jinja2.Environment":
+ """
+ Factory method to create the jinja2 template environment.
+ :return: a new jinja2 environment
+ """
+ import jinja2
+
+ return jinja2.Environment(
+ loader=jinja2.PackageLoader(
+ self.get_extension_module_root(), self.template_package_path
+ ),
+ autoescape=jinja2.select_autoescape(),
+ )
+
+ def render_template(self, template_name, **context) -> str:
+ """
+ Uses the ``template_env`` to render a template and return the string value.
+
+ :param template_name: the template name
+ :param context: template context
+ :return: the rendered result
+ """
+ template = self.template_env.get_template(template_name)
+ return template.render(**context)
+
+ def render_template_response(self, template_name, **context) -> http.Response:
+ """
+ Uses the ``template_env`` to render a template into an HTTP response. It guesses the mimetype from the
+ template's file name.
+
+ :param template_name: the template name
+ :param context: template context
+ :return: the rendered result as response
+ """
+ template = self.template_env.get_template(template_name)
+
+ mimetype = mimetypes.guess_type(template.filename)
+ mimetype = mimetype[0] if mimetype and mimetype[0] else "text/plain"
+
+ return http.Response(response=template.render(**context), mimetype=mimetype)
+
+ def on_extension_load(self):
+ logging.getLogger(self.get_extension_module_root()).setLevel(
+ logging.DEBUG if config.DEBUG else logging.INFO
+ )
+
+ if self.static_package_path and not self.static_resource_module:
+ try:
+ self.static_resource_module = importlib.import_module(
+ self.get_extension_module_root() + "." + self.static_package_path
+ )
+ except ModuleNotFoundError:
+ LOG.warning("disabling static resources for extension %s", self.name)
+
+ def _preprocess_request(
+ self, chain: HandlerChain, context: RequestContext, _response: http.Response
+ ):
+ """
+ Default pre-processor, which redirects requests that hit the submount without a trailing slash. For
+ instance, ``/_extension/my-extension`` is redirected to ``/_extension/my-extension/``. This way,
+ relative paths in your HTML resolve correctly, and it works with both subdomain and submount.
+ """
+ path = context.request.path
+
+ if path == self.submount.rstrip("/"):
+ chain.respond(301, headers={"Location": context.request.url + "/"})
+
+ def update_gateway_routes(self, router: http.Router[http.RouteHandler]):
+ from localstack.aws.handlers import preprocess_request
+
+ if self.submount:
+ preprocess_request.append(self._preprocess_request)
+
+ # adding self here makes sure that any ``@route``-decorated methods of the extension are mapped automatically
+ routes = [self]
+
+ if self.static_resource_module:
+ routes.append(
+ RuleAdapter(f"{self.static_url_path}/", self._serve_static_file)
+ )
+
+ self.collect_routes(routes)
+
+ app = RuleAdapter(routes)
+
+ if self.submount:
+ router.add(Submount(self.submount, [app]))
+ LOG.info(
+ "%s extension available at %s%s",
+ self.name,
+ config.external_service_url(),
+ self.submount,
+ )
+
+ if self.subdomain:
+ router.add(WithHost(f"{self.subdomain}.<__host__>", [app]))
+ self._configure_cors_for_subdomain()
+ LOG.info(
+ "%s extension available at %s",
+ self.name,
+ config.external_service_url(subdomains=self.subdomain),
+ )
+
+ def _serve_static_file(self, _request: http.Request, path: str):
+ """Route for serving static files, for ``/_extension/my-extension/static/``."""
+ return http.Response.for_resource(self.static_resource_module, path)
+
+ def _configure_cors_for_subdomain(self):
+ """
+ Automatically configures CORS for the subdomain, for both HTTP and HTTPS.
+ """
+ from localstack.aws.handlers.cors import ALLOWED_CORS_ORIGINS
+
+ for protocol in ("http", "https"):
+ url = self.get_subdomain_url(protocol)
+ LOG.debug("adding %s to ALLOWED_CORS_ORIGINS", url)
+ ALLOWED_CORS_ORIGINS.append(url)
+
+ def get_subdomain_url(self, protocol: str = "https") -> str:
+ """
+ Returns the URL that serves the extension under its subdomain
+ ``https://my-extension.localhost.localstack.cloud:4566/``.
+
+ :return: a URL this extension is served at
+ """
+ if not self.subdomain:
+ raise ValueError(f"Subdomain for extension {self.name} is not set")
+ return config.external_service_url(subdomains=self.subdomain, protocol=protocol)
+
+ def get_submount_url(self, protocol: str = "https") -> str:
+ """
+ Returns the URL that serves the extension under its submount
+ ``https://localhost.localstack.cloud:4566/_extension/my-extension``.
+
+ :return: a URL this extension is served at
+ """
+
+ if not self.submount:
+ raise ValueError(f"Submount for extension {self.name} is not set")
+
+ return f"{config.external_service_url(protocol=protocol)}{self.submount}"
+
+ @classmethod
+ def get_extension_module_root(cls) -> str:
+ """
+ Returns the root of the extension module. For instance, if the extension lives in
+ ``my_extension/plugins/extension.py``, then this will return ``my_extension``. Used to set up the
+ logger as well as the template environment and the static file module.
+
+ :return: the root module the extension lives in
+ """
+ return cls.__module__.split(".")[0]
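+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of a WebAppExtension that customizes the defaults described in ``__init__``:
+# it disables the subdomain and serves static files from a package named ``assets``. All names
+# here are made-up example values.
+class _ExampleWebApp(WebAppExtension):
+    name = "example-webapp"
+
+    def __init__(self):
+        # serve only via the submount /_extension/example-webapp, without a subdomain
+        super().__init__(subdomain=None, static_package_path="assets")
+
+    def collect_routes(self, routes: list[t.Any]):
+        # routes can also be added without decorators by using rule adapters
+        routes.append(RuleAdapter("/", self.index))
+
+    def index(self, request: http.Request) -> http.Response:
+        # renders templates/index.html from this extension's package
+        return self.render_template_response("index.html")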
diff --git a/localstack-core/localstack/http/__init__.py b/localstack-core/localstack/http/__init__.py
new file mode 100644
index 0000000000000..d72ef9d669d66
--- /dev/null
+++ b/localstack-core/localstack/http/__init__.py
@@ -0,0 +1,6 @@
+from .request import Request
+from .resource import Resource, resource
+from .response import Response
+from .router import Router, route
+
+__all__ = ["route", "resource", "Resource", "Router", "Response", "Request"]
diff --git a/localstack-core/localstack/http/asgi.py b/localstack-core/localstack/http/asgi.py
new file mode 100644
index 0000000000000..8ba3dd3454bd3
--- /dev/null
+++ b/localstack-core/localstack/http/asgi.py
@@ -0,0 +1,21 @@
+from rolo.asgi import (
+ ASGIAdapter,
+ ASGILifespanListener,
+ RawHTTPRequestEventStreamAdapter,
+ WebSocketEnvironment,
+ WebSocketListener,
+ WsgiStartResponse,
+ create_wsgi_input,
+ populate_wsgi_environment,
+)
+
+__all__ = [
+ "WebSocketEnvironment",
+ "populate_wsgi_environment",
+ "create_wsgi_input",
+ "RawHTTPRequestEventStreamAdapter",
+ "WsgiStartResponse",
+ "ASGILifespanListener",
+ "WebSocketListener",
+ "ASGIAdapter",
+]
diff --git a/localstack-core/localstack/http/client.py b/localstack-core/localstack/http/client.py
new file mode 100644
index 0000000000000..cb8f4b33aee31
--- /dev/null
+++ b/localstack-core/localstack/http/client.py
@@ -0,0 +1,7 @@
+from rolo.client import HttpClient, SimpleRequestsClient, make_request
+
+__all__ = [
+ "HttpClient",
+ "SimpleRequestsClient",
+ "make_request",
+]
diff --git a/localstack-core/localstack/http/dispatcher.py b/localstack-core/localstack/http/dispatcher.py
new file mode 100644
index 0000000000000..308450fbd3296
--- /dev/null
+++ b/localstack-core/localstack/http/dispatcher.py
@@ -0,0 +1,25 @@
+from json import JSONEncoder
+from typing import Type
+
+from rolo.routing.handler import Handler, ResultValue
+from rolo.routing.handler import handler_dispatcher as _handler_dispatcher
+from rolo.routing.router import Dispatcher
+
+from localstack.utils.json import CustomEncoder
+
+__all__ = [
+ "ResultValue",
+ "Handler",
+ "handler_dispatcher",
+]
+
+
+def handler_dispatcher(json_encoder: Type[JSONEncoder] = None) -> Dispatcher[Handler]:
+ """
+ Replacement for ``rolo.dispatcher.handler_dispatcher`` that uses LocalStack's CustomEncoder by default for
+ serializing JSON documents.
+
+ :param json_encoder: the encoder to use
+ :return: a Dispatcher that dispatches to instances of a Handler
+ """
+ return _handler_dispatcher(json_encoder or CustomEncoder)
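+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of wiring this dispatcher into a Router: endpoints may return plain dicts,
+# which are serialized to JSON with LocalStack's CustomEncoder (so values such as datetime
+# objects are handled). The route and payload are made-up example values.
+def _example_dispatch():
+    from datetime import datetime
+
+    from localstack.http import Request, Router
+
+    router = Router(dispatcher=handler_dispatcher())
+    router.add("/_example/time", lambda request: {"now": datetime.now()})
+
+    response = router.dispatch(Request(method="GET", path="/_example/time"))
+    assert response.status_code == 200
+    return response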
diff --git a/localstack-core/localstack/http/duplex_socket.py b/localstack-core/localstack/http/duplex_socket.py
new file mode 100644
index 0000000000000..8006f398668e5
--- /dev/null
+++ b/localstack-core/localstack/http/duplex_socket.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+import logging
+import socket
+import ssl
+from asyncio.selector_events import BaseSelectorEventLoop
+
+from localstack.utils.asyncio import run_sync
+from localstack.utils.objects import singleton_factory
+from localstack.utils.patch import Patch, patch
+
+# set up logger
+LOG = logging.getLogger(__name__)
+
+
+class DuplexSocket(ssl.SSLSocket):
+ """Simple duplex socket wrapper that allows serving HTTP/HTTPS over the same port."""
+
+ def accept(self):
+ newsock, addr = socket.socket.accept(self)
+ if DuplexSocket.is_ssl_socket(newsock) is not False:
+ newsock = self.context.wrap_socket(
+ newsock,
+ do_handshake_on_connect=self.do_handshake_on_connect,
+ suppress_ragged_eofs=self.suppress_ragged_eofs,
+ server_side=True,
+ )
+
+ return newsock, addr
+
+ @staticmethod
+ def is_ssl_socket(newsock):
+ """Returns True/False if the socket uses SSL or not, or None if the status cannot be
+ determined"""
+
+ def peek_ssl_header():
+ peek_bytes = 5
+ first_bytes = newsock.recv(peek_bytes, socket.MSG_PEEK)
+ if len(first_bytes or "") != peek_bytes:
+ return
+ first_byte = first_bytes[0]
+ return first_byte < 32 or first_byte >= 127
+
+ try:
+ return peek_ssl_header()
+ except Exception:
+ # Fix for "[Errno 11] Resource temporarily unavailable" - This can
+ # happen if we're using a non-blocking socket in a blocking thread.
+ newsock.setblocking(1)
+ newsock.settimeout(1)
+ try:
+ return peek_ssl_header()
+ except Exception:
+ return False
+
+
+@singleton_factory
+def enable_duplex_socket():
+ """
+ Replaces ``ssl.SSLContext.sslsocket_class`` with ``DuplexSocket``, enabling the server to accept both
+ HTTP and HTTPS connections on a single port.
+ """
+
+ # set globally defined SSL socket implementation class
+ Patch(ssl.SSLContext, "sslsocket_class", DuplexSocket).apply()
+
+ if hasattr(BaseSelectorEventLoop, "_accept_connection2"):
+
+ @patch(BaseSelectorEventLoop._accept_connection2)
+ async def _accept_connection2(
+ fn, self, protocol_factory, conn, extra, sslcontext, *args, **kwargs
+ ):
+ is_ssl_socket = await run_sync(DuplexSocket.is_ssl_socket, conn)
+ if is_ssl_socket is False:
+ sslcontext = None
+ result = await fn(self, protocol_factory, conn, extra, sslcontext, *args, **kwargs)
+ return result
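+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of how the patch above can be used: once ``enable_duplex_socket()`` has been
+# called, a server whose listening socket is wrapped via ``ssl.SSLContext`` also accepts plain
+# HTTP connections on the same port. The address and certificate paths are made-up example values.
+def _example_serve_http_and_https_on_one_port():
+    import http.server
+
+    enable_duplex_socket()
+
+    server = http.server.HTTPServer(("127.0.0.1", 8443), http.server.SimpleHTTPRequestHandler)
+    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+    context.load_cert_chain("server.crt", "server.key")  # assumed certificate/key paths
+    server.socket = context.wrap_socket(server.socket, server_side=True)
+    # with the DuplexSocket patch, both http:// and https:// requests to port 8443 are served
+    server.serve_forever()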
diff --git a/localstack-core/localstack/http/hypercorn.py b/localstack-core/localstack/http/hypercorn.py
new file mode 100644
index 0000000000000..e14f2e167c797
--- /dev/null
+++ b/localstack-core/localstack/http/hypercorn.py
@@ -0,0 +1,146 @@
+import asyncio
+import threading
+from asyncio import AbstractEventLoop
+
+from hypercorn import Config
+from hypercorn.asyncio import serve
+from hypercorn.typing import ASGIFramework
+
+from localstack.aws.gateway import Gateway
+from localstack.aws.handlers.proxy import ProxyHandler
+from localstack.aws.serving.asgi import AsgiGateway
+from localstack.config import HostAndPort
+from localstack.logging.setup import setup_hypercorn_logger
+from localstack.utils.collections import ensure_list
+from localstack.utils.functions import call_safe
+from localstack.utils.serving import Server
+from localstack.utils.ssl import create_ssl_cert, install_predefined_cert_if_available
+
+
+class HypercornServer(Server):
+ """
+ A sync wrapper around Hypercorn that implements the ``Server`` interface.
+ """
+
+ def __init__(self, app: ASGIFramework, config: Config, loop: AbstractEventLoop = None):
+ """
+ Create a new Hypercorn server instance. Note that, if you pass an event loop to the constructor,
+ you are yielding control of that event loop to the server, as it will invoke `run_until_complete` and
+ shut down the loop.
+
+ :param app: the ASGI3 app
+ :param config: the hypercorn config
+ :param loop: optionally the event loop, otherwise ``asyncio.new_event_loop`` will be called
+ """
+ self.app = app
+ self.config = config
+ self.loop = loop or asyncio.new_event_loop()
+
+ self._close = asyncio.Event()
+ self._closed = threading.Event()
+
+ parts = config.bind[0].split(":")
+ if len(parts) == 1:
+ # check ssl
+ host = parts[0]
+ port = 443 if config.ssl_enabled else 80
+ else:
+ host, port = parts[0], int(parts[1])
+
+ super().__init__(port, host)
+
+ @property
+ def protocol(self):
+ return "https" if self.config.ssl_enabled else "http"
+
+ def do_run(self):
+ self.loop.run_until_complete(
+ serve(self.app, self.config, shutdown_trigger=self._shutdown_trigger)
+ )
+ self._closed.set()
+
+ def do_shutdown(self):
+ asyncio.run_coroutine_threadsafe(self._set_closed(), self.loop)
+ self._closed.wait(timeout=10)
+ asyncio.run_coroutine_threadsafe(self.loop.shutdown_asyncgens(), self.loop)
+ self.loop.shutdown_default_executor()
+ self.loop.stop()
+ call_safe(self.loop.close)
+
+ async def _set_closed(self):
+ self._close.set()
+
+ async def _shutdown_trigger(self):
+ await self._close.wait()
+
+
+class GatewayServer(HypercornServer):
+ """
+ A Hypercorn-based server implementation which serves a given Gateway.
+ It can be used to easily spawn new gateway servers, defining their individual request-, response-, and
+ exception-handlers.
+ """
+
+ def __init__(
+ self,
+ gateway: Gateway,
+ listen: HostAndPort | list[HostAndPort],
+ use_ssl: bool = False,
+ threads: int | None = None,
+ ):
+ """
+ Creates a new GatewayServer instance.
+
+ :param gateway: the Gateway which will be served by this server
+ :param listen: the address and port pairs this server binds to. Can be a list of host and port pairs.
+ :param use_ssl: True if the LocalStack cert should be loaded and HTTP/HTTPS multiplexing should be enabled.
+ :param threads: Number of worker threads the gateway will use.
+ """
+ # build server config
+ config = Config()
+ config.h11_pass_raw_headers = True
+ setup_hypercorn_logger(config)
+
+ listens = ensure_list(listen)
+ config.bind = [str(host_and_port) for host_and_port in listens]
+
+ if use_ssl:
+ install_predefined_cert_if_available()
+ serial_number = listens[0].port
+ _, cert_file_name, key_file_name = create_ssl_cert(serial_number=serial_number)
+ config.certfile = cert_file_name
+ config.keyfile = key_file_name
+
+ # build gateway
+ loop = asyncio.new_event_loop()
+ app = AsgiGateway(gateway, event_loop=loop, threads=threads)
+
+ # start serving gateway
+ super().__init__(app, config, loop)
+
+ def do_shutdown(self):
+ super().do_shutdown()
+ self.app.close() # noqa (app will be of type AsgiGateway)
+
+
+class ProxyServer(GatewayServer):
+ """
+ Proxy server implementation which uses the localstack.http.proxy module.
+ These server instances can be spawned easily, while implementing HTTP/HTTPS multiplexing (if enabled),
+ and just forward all incoming requests to a backend.
+ """
+
+ def __init__(
+ self, forward_base_url: str, listen: HostAndPort | list[HostAndPort], use_ssl: bool = False
+ ):
+ """
+ Creates a new ProxyServer instance.
+
+ :param forward_base_url: URL of the backend system all requests this server receives should be forwarded to
+ :param listen: the address and port pairs this server binds to. Can be a single host/port pair or a list of them.
+ :param use_ssl: True if the LocalStack cert should be loaded and HTTP/HTTPS multiplexing should be enabled.
+ """
+ gateway = Gateway()
+ gateway.request_handlers.append(ProxyHandler(forward_base_url=forward_base_url))
+ super().__init__(gateway, listen, use_ssl)
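+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of spinning up a GatewayServer with a single request handler. The port and
+# handler are made-up example values; the LocalStack runtime wires this up with its real
+# gateway and handler chain instead.
+def _example_run_gateway_server():
+    def _hello(chain, context, response):
+        response.status_code = 200
+        response.set_json({"message": "hello from the gateway"})
+        chain.stop()
+
+    gateway = Gateway()
+    gateway.request_handlers.append(_hello)
+
+    server = GatewayServer(gateway, HostAndPort(host="127.0.0.1", port=4567))
+    server.start()  # runs the hypercorn event loop in a background thread
+    server.wait_is_up(timeout=10)
+    # ... issue requests against http://127.0.0.1:4567 ...
+    server.shutdown()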
diff --git a/localstack-core/localstack/http/proxy.py b/localstack-core/localstack/http/proxy.py
new file mode 100644
index 0000000000000..35cf74719277a
--- /dev/null
+++ b/localstack-core/localstack/http/proxy.py
@@ -0,0 +1,7 @@
+from rolo.proxy import Proxy, ProxyHandler, forward
+
+__all__ = [
+ "forward",
+ "Proxy",
+ "ProxyHandler",
+]
diff --git a/localstack-core/localstack/http/request.py b/localstack-core/localstack/http/request.py
new file mode 100644
index 0000000000000..411ead4ab6bde
--- /dev/null
+++ b/localstack-core/localstack/http/request.py
@@ -0,0 +1,21 @@
+from rolo.request import (
+ Request,
+ dummy_wsgi_environment,
+ get_full_raw_path,
+ get_raw_base_url,
+ get_raw_current_url,
+ get_raw_path,
+ restore_payload,
+ set_environment_headers,
+)
+
+__all__ = [
+ "dummy_wsgi_environment",
+ "set_environment_headers",
+ "Request",
+ "get_raw_path",
+ "get_full_raw_path",
+ "get_raw_base_url",
+ "get_raw_current_url",
+ "restore_payload",
+]
diff --git a/localstack-core/localstack/http/resource.py b/localstack-core/localstack/http/resource.py
new file mode 100644
index 0000000000000..40db6d941b0aa
--- /dev/null
+++ b/localstack-core/localstack/http/resource.py
@@ -0,0 +1,6 @@
+from rolo.resource import Resource, resource
+
+__all__ = [
+ "resource",
+ "Resource",
+]
diff --git a/localstack/services/firehose/__init__.py b/localstack-core/localstack/http/resources/__init__.py
similarity index 100%
rename from localstack/services/firehose/__init__.py
rename to localstack-core/localstack/http/resources/__init__.py
diff --git a/localstack/services/kinesis/__init__.py b/localstack-core/localstack/http/resources/swagger/__init__.py
similarity index 100%
rename from localstack/services/kinesis/__init__.py
rename to localstack-core/localstack/http/resources/swagger/__init__.py
diff --git a/localstack-core/localstack/http/resources/swagger/endpoints.py b/localstack-core/localstack/http/resources/swagger/endpoints.py
new file mode 100644
index 0000000000000..f6cef4c9a33f8
--- /dev/null
+++ b/localstack-core/localstack/http/resources/swagger/endpoints.py
@@ -0,0 +1,25 @@
+import os
+
+from jinja2 import Environment, FileSystemLoader
+from rolo import Request, route
+
+from localstack.config import external_service_url
+from localstack.http import Response
+
+
+def _get_service_url(request: Request) -> str:
+ # special case for ephemeral instances
+ if "sandbox.localstack.cloud" in request.host:
+ return external_service_url(protocol="https", port=443)
+ return external_service_url(protocol=request.scheme)
+
+
+class SwaggerUIApi:
+ @route("/_localstack/swagger", methods=["GET"])
+ def server_swagger_ui(self, request: Request) -> Response:
+ init_path = f"{_get_service_url(request)}/openapi.yaml"
+ oas_path = os.path.join(os.path.dirname(__file__), "templates")
+ env = Environment(loader=FileSystemLoader(oas_path))
+ template = env.get_template("index.html")
+ rendered_template = template.render(swagger_url=init_path)
+ return Response(rendered_template, content_type="text/html")
diff --git a/localstack-core/localstack/http/resources/swagger/plugins.py b/localstack-core/localstack/http/resources/swagger/plugins.py
new file mode 100644
index 0000000000000..2e464f50deacd
--- /dev/null
+++ b/localstack-core/localstack/http/resources/swagger/plugins.py
@@ -0,0 +1,23 @@
+import werkzeug
+import yaml
+from rolo.routing import RuleAdapter
+
+from localstack.http.resources.swagger.endpoints import SwaggerUIApi
+from localstack.runtime import hooks
+from localstack.services.edge import ROUTER
+from localstack.services.internal import get_internal_apis
+from localstack.utils.openapi import get_localstack_openapi_spec
+
+
+@hooks.on_infra_start()
+def register_swagger_endpoints():
+ get_internal_apis().add(SwaggerUIApi())
+
+ def _serve_openapi_spec(_request):
+ spec = get_localstack_openapi_spec()
+ response_body = yaml.dump(spec)
+ return werkzeug.Response(
+ response_body, content_type="application/yaml", direct_passthrough=True
+ )
+
+ ROUTER.add(RuleAdapter("/openapi.yaml", _serve_openapi_spec))
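+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of registering an additional internal endpoint with the same mechanism used
+# above. The path and payload are made-up example values; a real plugin would call this from an
+# ``@hooks.on_infra_start()`` hook.
+def _example_register_internal_endpoint():
+    from rolo import Request, route
+
+    from localstack.http import Response
+
+    class _ExampleInternalApi:
+        @route("/_localstack/example", methods=["GET"])
+        def example(self, request: Request) -> Response:
+            return Response("example endpoint", content_type="text/plain")
+
+    get_internal_apis().add(_ExampleInternalApi())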
diff --git a/localstack-core/localstack/http/resources/swagger/templates/index.html b/localstack-core/localstack/http/resources/swagger/templates/index.html
new file mode 100644
index 0000000000000..a852b132deb56
--- /dev/null
+++ b/localstack-core/localstack/http/resources/swagger/templates/index.html
@@ -0,0 +1,22 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="utf-8"/>
+    <meta name="viewport" content="width=device-width, initial-scale=1"/>
+    <meta name="description" content="SwaggerUI"/>
+    <title>SwaggerUI</title>
+    <link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist/swagger-ui.css"/>
+</head>
+<body>
+<div id="swagger-ui"></div>
+<script src="https://unpkg.com/swagger-ui-dist/swagger-ui-bundle.js" crossorigin></script>
+<script>
+    window.onload = () => {
+        window.ui = SwaggerUIBundle({
+            url: '{{ swagger_url }}',
+            dom_id: '#swagger-ui',
+        });
+    };
+</script>
+</body>
+</html>
new file mode 100644
index 0000000000000..66863c147d370
--- /dev/null
+++ b/localstack-core/localstack/http/response.py
@@ -0,0 +1,22 @@
+from json import JSONEncoder
+from typing import Any, Type
+
+from rolo import Response as RoloResponse
+
+from localstack.utils.common import CustomEncoder
+
+
+class Response(RoloResponse):
+ """
+ An HTTP Response object, which simply extends werkzeug's Response object with a few convenience methods.
+ """
+
+ def set_json(self, doc: Any, cls: Type[JSONEncoder] = CustomEncoder):
+ """
+ Serializes the given dictionary using localstack's ``CustomEncoder`` into a json response, and sets the
+ mimetype automatically to ``application/json``.
+
+ :param doc: the response dictionary to be serialized as JSON
+ :param cls: the json encoder used
+ """
+ return super().set_json(doc, cls or CustomEncoder)
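+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of the convenience this subclass adds: values such as datetime objects are
+# serialized via CustomEncoder without passing an encoder explicitly. The payload is a made-up
+# example value.
+def _example_set_json():
+    from datetime import datetime
+
+    response = Response(status=200)
+    response.set_json({"created": datetime.now(), "status": "ok"})
+    return response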
diff --git a/localstack-core/localstack/http/router.py b/localstack-core/localstack/http/router.py
new file mode 100644
index 0000000000000..da3bcdfe043c0
--- /dev/null
+++ b/localstack-core/localstack/http/router.py
@@ -0,0 +1,52 @@
+from typing import (
+ Any,
+ Mapping,
+ TypeVar,
+)
+
+from rolo.routing import (
+ PortConverter,
+ RegexConverter,
+ Router,
+ RuleAdapter,
+ RuleGroup,
+ WithHost,
+ route,
+)
+from rolo.routing.router import Dispatcher, call_endpoint
+from werkzeug.routing import PathConverter
+
+HTTP_METHODS = ("GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS", "TRACE")
+
+E = TypeVar("E")
+RequestArguments = Mapping[str, Any]
+
+
+class GreedyPathConverter(PathConverter):
+ """
+ This converter makes sure that the path ``/mybucket//mykey`` can be matched to the pattern
+ ``/<Bucket>/<path:Key>`` and will result in `Key` being `/mykey`.
+ """
+
+ regex = ".*?"
+
+ part_isolating = False
+ """From the werkzeug docs: If a custom converter can match a forward slash, /, it should have the
+ attribute part_isolating set to False. This will ensure that rules using the custom converter are
+ correctly matched."""
+
+
+__all__ = [
+ "RequestArguments",
+ "HTTP_METHODS",
+ "RegexConverter",
+ "PortConverter",
+ "Dispatcher",
+ "route",
+ "call_endpoint",
+ "Router",
+ "RuleAdapter",
+ "WithHost",
+ "RuleGroup",
+ "GreedyPathConverter",
+]
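+
+
+# --- illustrative usage sketch -------------------------------------------------------------
+# A minimal sketch of using the GreedyPathConverter: it is registered under a converter name on
+# a Router and then referenced in a rule, assuming the Router accepts a ``converters`` mapping
+# in its constructor. The rule and request path are made-up example values resembling the S3
+# case from the class docstring.
+def _example_greedy_path_matching():
+    from localstack.http import Request
+
+    router = Router(converters={"greedy_path": GreedyPathConverter})
+    router.add(
+        "/<bucket>/<greedy_path:key>",
+        lambda request, bucket, key: f"bucket={bucket}, key={key}",
+    )
+
+    # the leading slash of "//mykey" is preserved in the converted "key" argument
+    return router.dispatch(Request(method="GET", path="/mybucket//mykey"))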
diff --git a/localstack-core/localstack/http/trace.py b/localstack-core/localstack/http/trace.py
new file mode 100644
index 0000000000000..7d52b9ebf36dc
--- /dev/null
+++ b/localstack-core/localstack/http/trace.py
@@ -0,0 +1,348 @@
+import dataclasses
+import inspect
+import logging
+import time
+from typing import Any, Callable
+
+from rolo import Response
+from rolo.gateway import ExceptionHandler, Handler, HandlerChain, RequestContext
+from werkzeug.datastructures import Headers
+
+from localstack.utils.patch import Patch, Patches
+
+LOG = logging.getLogger(__name__)
+
+
+class Action:
+ """
+ Encapsulates something that the handler performed on the request context, request, or response objects.
+ """
+
+ name: str
+
+ def __init__(self, name: str):
+ self.name = name
+
+ def __repr__(self):
+ return self.name
+
+
+class SetAttributeAction(Action):
+ """
+ The handler set an attribute of the request context or something else.
+ """
+
+ key: str
+ value: Any | None
+
+ def __init__(self, key: str, value: Any | None = None):
+ super().__init__("set")
+ self.key = key
+ self.value = value
+
+ def __repr__(self):
+ if self.value is None:
+ return f"set {self.key}"
+ return f"set {self.key} = {self.value!r}"
+
+
+class ModifyHeadersAction(Action):
+ """
+ The handler modified headers in some way, either adding, updating, or removing headers.
+ """
+
+ def __init__(self, name: str, before: Headers, after: Headers):
+ super().__init__(name)
+ self.before = before
+ self.after = after
+
+ @property
+ def header_actions(self) -> list[Action]:
+ after = self.after
+ before = self.before
+
+ actions = []
+
+ headers_set = dict(set(after.items()) - set(before.items()))
+ headers_removed = {k: v for k, v in before.items() if k not in after}
+
+ for k, v in headers_set.items():
+ actions.append(Action(f"set '{k}: {v}'"))
+ for k, v in headers_removed.items():
+ actions.append(Action(f"del '{k}: {v}'"))
+
+ return actions
+
+
+@dataclasses.dataclass
+class HandlerTrace:
+ handler: Handler
+ """The handler"""
+ duration_ms: float
+ """The runtime duration of the handler in milliseconds"""
+ actions: list[Action]
+ """The actions the handler chain performed"""
+
+ @property
+ def handler_module(self):
+ return self.handler.__module__
+
+ @property
+ def handler_name(self):
+ if inspect.isfunction(self.handler):
+ return self.handler.__name__
+ else:
+ return self.handler.__class__.__name__
+
+
+def _log_method_call(name: str, actions: list[Action]):
+ """Creates a wrapper around the original method `_fn`. It appends an action to the `actions`
+ list indicating that the function was called and then returns the original function."""
+
+ def _proxy(self, _fn, *args, **kwargs):
+ actions.append(Action(f"call {name}"))
+ return _fn(*args, **kwargs)
+
+ return _proxy
+
+
+class TracingHandlerBase:
+ """
+ This class is a Handler that records a trace of the execution of another request handler. It has two
+ attributes: `trace`, which stores the tracing information, and `delegate`, which is the handler or
+ exception handler that will be traced.
+ """
+
+ trace: HandlerTrace | None
+ delegate: Handler | ExceptionHandler
+
+ def __init__(self, delegate: Handler | ExceptionHandler):
+ self.trace = None
+ self.delegate = delegate
+
+ def do_trace_call(
+ self, fn: Callable, chain: HandlerChain, context: RequestContext, response: Response
+ ):
+ """
+ Wraps the function call with the tracing functionality and records a HandlerTrace.
+
+ The method determines changes made by the request handler to specific aspects of the request.
+ Changes made to the request context and the response headers/status by the request handler are then
+ examined, and appropriate actions are added to the `actions` list of the trace.
+
+ :param fn: which is the function to be traced, which is the request/response/exception handler
+ :param chain: the handler chain
+ :param context: the request context
+ :param response: the response object
+ """
+ then = time.perf_counter()
+
+ actions = []
+
+ prev_context = dict(context.__dict__)
+ prev_stopped = chain.stopped
+ prev_request_identity = id(context.request)
+ prev_terminated = chain.terminated
+ prev_request_headers = context.request.headers.copy()
+ prev_response_headers = response.headers.copy()
+ prev_response_status = response.status_code
+
+ # add patches to log invocations or certain functions
+ patches = Patches(
+ [
+ Patch.function(
+ context.request.get_data,
+ _log_method_call("request.get_data", actions),
+ ),
+ Patch.function(
+ context.request._load_form_data,
+ _log_method_call("request._load_form_data", actions),
+ ),
+ Patch.function(
+ response.get_data,
+ _log_method_call("response.get_data", actions),
+ ),
+ ]
+ )
+ patches.apply()
+
+ try:
+ return fn()
+ finally:
+ now = time.perf_counter()
+ # determine some basic things the handler changed in the context
+ patches.undo()
+
+ # chain
+ if chain.stopped and not prev_stopped:
+ actions.append(Action("stop chain"))
+ if chain.terminated and not prev_terminated:
+ actions.append(Action("terminate chain"))
+
+ # detect when attributes are set in the request context
+ context_args = dict(context.__dict__)
+ context_args.pop("request", None) # request is handled separately
+
+ for k, v in context_args.items():
+ if not v:
+ continue
+ if prev_context.get(k):
+ # TODO: we could introduce "ModifyAttributeAction(k,v)" with an additional check
+ # ``if v != prev_context.get(k)``
+ continue
+ actions.append(SetAttributeAction(k, v))
+
+ # request
+ if id(context.request) != prev_request_identity:
+ actions.append(Action("replaced request object"))
+
+ # response
+ if response.status_code != prev_response_status:
+ actions.append(SetAttributeAction("response stats_code", response.status_code))
+ if context.request.headers != prev_request_headers:
+ actions.append(
+ ModifyHeadersAction(
+ "modify request headers",
+ prev_request_headers,
+ context.request.headers.copy(),
+ )
+ )
+ if response.headers != prev_response_headers:
+ actions.append(
+ ModifyHeadersAction(
+ "modify response headers", prev_response_headers, response.headers.copy()
+ )
+ )
+
+ self.trace = HandlerTrace(
+ handler=self.delegate, duration_ms=(now - then) * 1000, actions=actions
+ )
+
+
+class TracingHandler(TracingHandlerBase):
+ delegate: Handler
+
+ def __init__(self, delegate: Handler):
+ super().__init__(delegate)
+
+ def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
+ def _call():
+ return self.delegate(chain, context, response)
+
+ return self.do_trace_call(_call, chain, context, response)
+
+
+class TracingExceptionHandler(TracingHandlerBase):
+ delegate: ExceptionHandler
+
+ def __init__(self, delegate: ExceptionHandler):
+ super().__init__(delegate)
+
+ def __call__(
+ self, chain: HandlerChain, exception: Exception, context: RequestContext, response: Response
+ ):
+ def _call():
+ return self.delegate(chain, exception, context, response)
+
+ return self.do_trace_call(_call, chain, context, response)
+
+
+class TracingHandlerChain(HandlerChain):
+ """
+ TracingHandlerChain - A subclass of HandlerChain for logging and tracing handlers.
+
+ Attributes:
+ - duration (float): Total time taken for handling request in milliseconds.
+ - request_handler_traces (list[HandlerTrace]): List of request handler traces.
+ - response_handler_traces (list[HandlerTrace]): List of response handler traces.
+ - finalizer_traces (list[HandlerTrace]): List of finalizer traces.
+ - exception_handler_traces (list[HandlerTrace]): List of exception handler traces.
+ """
+
+ duration: float
+ request_handler_traces: list[HandlerTrace]
+ response_handler_traces: list[HandlerTrace]
+ finalizer_traces: list[HandlerTrace]
+ exception_handler_traces: list[HandlerTrace]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.request_handler_traces = []
+ self.response_handler_traces = []
+ self.finalizer_traces = []
+ self.exception_handler_traces = []
+
+ def handle(self, context: RequestContext, response: Response):
+ """Overrides HandlerChain's handle method and adds tracing handler to request handlers. Logs the trace
+ report with request and response details."""
+ then = time.perf_counter()
+ try:
+ self.request_handlers = [TracingHandler(handler) for handler in self.request_handlers]
+ return super().handle(context, response)
+ finally:
+ self.duration = (time.perf_counter() - then) * 1000
+ self.request_handler_traces = [handler.trace for handler in self.request_handlers]
+ self._log_report()
+
+ def _call_response_handlers(self, response):
+ self.response_handlers = [TracingHandler(handler) for handler in self.response_handlers]
+ try:
+ return super()._call_response_handlers(response)
+ finally:
+ self.response_handler_traces = [handler.trace for handler in self.response_handlers]
+
+ def _call_finalizers(self, response):
+ self.finalizers = [TracingHandler(handler) for handler in self.finalizers]
+ try:
+ return super()._call_finalizers(response)
+ finally:
+ self.finalizer_traces = [handler.trace for handler in self.finalizers]
+
+ def _call_exception_handlers(self, e, response):
+ self.exception_handlers = [
+ TracingExceptionHandler(handler) for handler in self.exception_handlers
+ ]
+ try:
+ return super()._call_exception_handlers(e, response)
+ finally:
+ self.exception_handler_traces = [handler.trace for handler in self.exception_handlers]
+
+ def _log_report(self):
+ report = []
+ request = self.context.request
+ response = self.response
+
+ def _append_traces(traces: list[HandlerTrace]):
+ """Format and appends a list of traces to the report, and recursively append the trace's
+ actions (if any)."""
+
+ for trace in traces:
+ if trace is None:
+ continue
+
+ report.append(
+ f"{trace.handler_module:43s} {trace.handler_name:30s} {trace.duration_ms:8.2f}ms"
+ )
+ _append_actions(trace.actions, 46)
+
+ def _append_actions(actions: list[Action], indent: int):
+ for action in actions:
+ report.append((" " * indent) + f"- {action!r}")
+
+ if isinstance(action, ModifyHeadersAction):
+ _append_actions(action.header_actions, indent + 2)
+
+ report.append(f"request: {request.method} {request.url}")
+ report.append(f"response: {response.status_code}")
+ report.append("---- request handlers " + ("-" * 63))
+ _append_traces(self.request_handler_traces)
+ report.append("---- response handlers " + ("-" * 63))
+ _append_traces(self.response_handler_traces)
+ report.append("---- finalizers " + ("-" * 63))
+ _append_traces(self.finalizer_traces)
+ report.append("---- exception handlers " + ("-" * 63))
+ _append_traces(self.exception_handler_traces)
+ # Add a separator and total duration value to the end of the report
+ report.append(f"{'=' * 68} total {self.duration:8.2f}ms")
+
+ LOG.info("handler chain trace report:\n%s\n%s", "=" * 85, "\n".join(report))
diff --git a/localstack-core/localstack/http/websocket.py b/localstack-core/localstack/http/websocket.py
new file mode 100644
index 0000000000000..9bd92a927a998
--- /dev/null
+++ b/localstack-core/localstack/http/websocket.py
@@ -0,0 +1,15 @@
+from rolo.websocket.websocket import (
+ WebSocket,
+ WebSocketDisconnectedError,
+ WebSocketError,
+ WebSocketProtocolError,
+ WebSocketRequest,
+)
+
+__all__ = [
+ "WebSocketError",
+ "WebSocketDisconnectedError",
+ "WebSocketProtocolError",
+ "WebSocket",
+ "WebSocketRequest",
+]
diff --git a/localstack/services/s3/__init__.py b/localstack-core/localstack/logging/__init__.py
similarity index 100%
rename from localstack/services/s3/__init__.py
rename to localstack-core/localstack/logging/__init__.py
diff --git a/localstack-core/localstack/logging/format.py b/localstack-core/localstack/logging/format.py
new file mode 100644
index 0000000000000..09655928bc6f8
--- /dev/null
+++ b/localstack-core/localstack/logging/format.py
@@ -0,0 +1,162 @@
+"""Tools for formatting localstack logs."""
+
+import logging
+from functools import lru_cache
+from typing import Any, Dict
+
+from localstack.utils.numbers import format_bytes
+
+MAX_THREAD_NAME_LEN = 12
+MAX_NAME_LEN = 26
+
+LOG_FORMAT = f"%(asctime)s.%(msecs)03d %(ls_level)5s --- [%(ls_thread){MAX_THREAD_NAME_LEN}s] %(ls_name)-{MAX_NAME_LEN}s : %(message)s"
+LOG_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"
+LOG_INPUT_FORMAT = "%(input_type)s(%(input)s, headers=%(request_headers)s)"
+LOG_OUTPUT_FORMAT = "%(output_type)s(%(output)s, headers=%(response_headers)s)"
+LOG_CONTEXT_FORMAT = "%(account_id)s/%(region)s"
+
+CUSTOM_LEVEL_NAMES = {
+ 50: "FATAL",
+ 40: "ERROR",
+ 30: "WARN",
+ 20: "INFO",
+ 10: "DEBUG",
+}
+
+
+class DefaultFormatter(logging.Formatter):
+ """
+ A formatter that uses ``LOG_FORMAT`` and ``LOG_DATE_FORMAT``.
+ """
+
+ def __init__(self, fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT):
+ super(DefaultFormatter, self).__init__(fmt=fmt, datefmt=datefmt)
+
+
+class AddFormattedAttributes(logging.Filter):
+ """
+ Filter that adds three attributes to a log record:
+
+ - ls_level: the abbreviated loglevel that's max 5 characters long
+ - ls_name: the abbreviated name of the logger (e.g., `l.bootstrap.install`), trimmed to ``MAX_NAME_LEN``
+ - ls_thread: the abbreviated thread name (prefix trimmed, e.g., ``omeThread-108``)
+ """
+
+ max_name_len: int
+ max_thread_len: int
+
+ def __init__(self, max_name_len: int = None, max_thread_len: int = None):
+ super(AddFormattedAttributes, self).__init__()
+ self.max_name_len = max_name_len if max_name_len else MAX_NAME_LEN
+ self.max_thread_len = max_thread_len if max_thread_len else MAX_THREAD_NAME_LEN
+
+ def filter(self, record):
+ record.ls_level = CUSTOM_LEVEL_NAMES.get(record.levelno, record.levelname)
+ record.ls_name = self._get_compressed_logger_name(record.name)
+ record.ls_thread = record.threadName[-self.max_thread_len :]
+ return True
+
+ @lru_cache(maxsize=256)
+ def _get_compressed_logger_name(self, name):
+ return compress_logger_name(name, self.max_name_len)
+
+
+def compress_logger_name(name: str, length: int) -> str:
+ """
+ Creates a short version of a logger name. For example ``my.very.long.logger.name`` with length=17 turns into
+ ``m.v.l.logger.name``.
+
+ :param name: the logger name
+ :param length: the max length of the logger name
+ :return: the compressed name
+ """
+ if len(name) <= length:
+ return name
+
+ parts = name.split(".")
+ parts.reverse()
+
+ new_parts = []
+
+ # we start by assuming that all parts are collapsed
+ # x.x.x requires 5 = 2n - 1 characters
+ cur_length = (len(parts) * 2) - 1
+
+ for i in range(len(parts)):
+ # try to expand the current part and calculate the resulting length
+ part = parts[i]
+ next_len = cur_length + (len(part) - 1)
+
+ if next_len > length:
+ # if the resulting length would exceed the limit, add only the first letter of all remaining
+ # parts
+ new_parts += [p[0] for p in parts[i:]]
+
+ # but if this is the first item, that means we would display nothing, so at least display as much of the
+ # max length as possible
+ if i == 0:
+ remaining = length - cur_length
+ if remaining > 0:
+ new_parts[0] = part[: (remaining + 1)]
+
+ break
+
+ # expanding the current part, i.e., instead of using just the one character, we add the entire part
+ new_parts.append(part)
+ cur_length = next_len
+
+ new_parts.reverse()
+ return ".".join(new_parts)
+
+
+class TraceLoggingFormatter(logging.Formatter):
+ aws_trace_log_format = "; ".join([LOG_FORMAT, LOG_INPUT_FORMAT, LOG_OUTPUT_FORMAT])
+ bytes_length_display_threshold = 512
+
+ def __init__(self):
+ super().__init__(fmt=self.aws_trace_log_format, datefmt=LOG_DATE_FORMAT)
+
+ def _replace_large_payloads(self, input: Any) -> Any:
+ """
+ Replaces large payloads in the logs with placeholders to avoid cluttering the logs with huge bytes payloads.
+ :param input: Input/output extra passed when logging. If it is bytes, it will be replaced if larger than
+ bytes_length_display_threshold
+ :return: Input, unless it is bytes and longer than bytes_length_display_threshold, then `Bytes(length_of_input)`
+ """
+ if isinstance(input, bytes) and len(input) > self.bytes_length_display_threshold:
+ return f"Bytes({format_bytes(len(input))})"
+ return input
+
+ def format(self, record: logging.LogRecord) -> str:
+ record.input = self._replace_large_payloads(record.input)
+ record.output = self._replace_large_payloads(record.output)
+ return super().format(record=record)
+
+
+class AwsTraceLoggingFormatter(TraceLoggingFormatter):
+ aws_trace_log_format = "; ".join(
+ [LOG_FORMAT, LOG_CONTEXT_FORMAT, LOG_INPUT_FORMAT, LOG_OUTPUT_FORMAT]
+ )
+
+ def __init__(self):
+ super().__init__()
+
+ def _copy_service_dict(self, service_dict: Dict) -> Dict:
+ if not isinstance(service_dict, Dict):
+ return service_dict
+ result = {}
+ for key, value in service_dict.items():
+ if isinstance(value, dict):
+ result[key] = self._copy_service_dict(value)
+ elif isinstance(value, bytes) and len(value) > self.bytes_length_display_threshold:
+ result[key] = f"Bytes({format_bytes(len(value))})"
+ elif isinstance(value, list):
+ result[key] = [self._copy_service_dict(item) for item in value]
+ else:
+ result[key] = value
+ return result
+
+ def format(self, record: logging.LogRecord) -> str:
+ record.input = self._copy_service_dict(record.input)
+ record.output = self._copy_service_dict(record.output)
+ return super().format(record=record)
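
As a quick illustration of the formatting helpers above, here is a minimal editorial sketch (not part of the patch). It assumes the file is importable as localstack.logging.format, matching the localstack-core layout shown in this diff.

```python
# Editor's sketch, not part of the patch. Assumes the module path
# localstack.logging.format for the file shown above.
import logging

from localstack.logging.format import (
    AddFormattedAttributes,
    DefaultFormatter,
    compress_logger_name,
)

# logger-name compression keeps the tail of the dotted name readable within the budget
# (this is the example from the docstring above)
assert compress_logger_name("my.very.long.logger.name", 17) == "m.v.l.logger.name"

# wire the filter and formatter onto a handler; the filter adds ls_level/ls_name/ls_thread
handler = logging.StreamHandler()
handler.addFilter(AddFormattedAttributes())
handler.setFormatter(DefaultFormatter())
logging.getLogger("my.very.long.logger.name").addHandler(handler)
```
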
diff --git a/localstack-core/localstack/logging/setup.py b/localstack-core/localstack/logging/setup.py
new file mode 100644
index 0000000000000..444742083e687
--- /dev/null
+++ b/localstack-core/localstack/logging/setup.py
@@ -0,0 +1,130 @@
+import logging
+import sys
+import warnings
+
+from localstack import config, constants
+
+from .format import AddFormattedAttributes, DefaultFormatter
+
+# The log levels for modules are applied incrementally for increasing logging granularity:
+# the defaults below first, then the TRACE / TRACE_INTERNAL overrides further down. Hence,
+# each module below should define its least verbose level, which serves as the default.
+
+default_log_levels = {
+ "asyncio": logging.INFO,
+ "boto3": logging.INFO,
+ "botocore": logging.ERROR,
+ "docker": logging.WARNING,
+ "elasticsearch": logging.ERROR,
+ "hpack": logging.ERROR,
+ "moto": logging.WARNING,
+ "requests": logging.WARNING,
+ "s3transfer": logging.INFO,
+ "urllib3": logging.WARNING,
+ "werkzeug": logging.WARNING,
+ "rolo": logging.WARNING,
+ "parse": logging.WARNING,
+ "localstack.aws.accounts": logging.INFO,
+ "localstack.aws.protocol.serializer": logging.INFO,
+ "localstack.aws.serving.wsgi": logging.WARNING,
+ "localstack.request": logging.INFO,
+ "localstack.request.internal": logging.WARNING,
+ "localstack.state.inspect": logging.INFO,
+ "localstack_persistence": logging.INFO,
+}
+
+trace_log_levels = {
+ "rolo": logging.DEBUG,
+ "localstack.aws.protocol.serializer": logging.DEBUG,
+ "localstack.aws.serving.wsgi": logging.DEBUG,
+ "localstack.request": logging.DEBUG,
+ "localstack.request.internal": logging.INFO,
+ "localstack.state.inspect": logging.DEBUG,
+}
+
+trace_internal_log_levels = {
+ "localstack.aws.accounts": logging.DEBUG,
+ "localstack.request.internal": logging.DEBUG,
+}
+
+
+def setup_logging_for_cli(log_level=logging.INFO):
+ logging.basicConfig(level=log_level)
+
+ # set log levels of loggers
+ logging.root.setLevel(log_level)
+ logging.getLogger("localstack").setLevel(log_level)
+ for logger, level in default_log_levels.items():
+ logging.getLogger(logger).setLevel(level)
+
+
+def get_log_level_from_config():
+ # overriding the log level if LS_LOG has been set
+ if config.LS_LOG:
+ log_level = str(config.LS_LOG).upper()
+ if log_level.lower() in constants.TRACE_LOG_LEVELS:
+ log_level = "DEBUG"
+ log_level = logging._nameToLevel[log_level]
+ return log_level
+
+ return logging.DEBUG if config.DEBUG else logging.INFO
+
+
+def setup_logging_from_config():
+ log_level = get_log_level_from_config()
+ setup_logging(log_level)
+
+ if config.is_trace_logging_enabled():
+ for name, level in trace_log_levels.items():
+ logging.getLogger(name).setLevel(level)
+ if config.LS_LOG == constants.LS_LOG_TRACE_INTERNAL:
+ for name, level in trace_internal_log_levels.items():
+ logging.getLogger(name).setLevel(level)
+
+
+def create_default_handler(log_level: int):
+ log_handler = logging.StreamHandler(stream=sys.stderr)
+ log_handler.setLevel(log_level)
+ log_handler.setFormatter(DefaultFormatter())
+ log_handler.addFilter(AddFormattedAttributes())
+ return log_handler
+
+
+def setup_logging(log_level=logging.INFO) -> None:
+ """
+ Configures the python logging environment for LocalStack.
+
+ :param log_level: the optional log level.
+ """
+ # create a default handler for the root logger (basically logging.basicConfig, but explicit)
+ log_handler = create_default_handler(log_level)
+
+ # replace any existing handlers
+ logging.basicConfig(level=log_level, handlers=[log_handler])
+
+ # disable some logs and warnings
+ warnings.filterwarnings("ignore")
+ logging.captureWarnings(True)
+
+ # set log levels of loggers
+ logging.root.setLevel(log_level)
+ logging.getLogger("localstack").setLevel(log_level)
+ for logger, level in default_log_levels.items():
+ logging.getLogger(logger).setLevel(level)
+
+
+def setup_hypercorn_logger(hypercorn_config) -> None:
+ """
+ Configures the hypercorn loggers (which are created in a peculiar way) to use the LocalStack settings.
+
+ :param hypercorn_config: a hypercorn.Config object
+ """
+ logger = hypercorn_config.log.access_logger
+ if logger:
+ logger.handlers[0].addFilter(AddFormattedAttributes())
+ logger.handlers[0].setFormatter(DefaultFormatter())
+
+ logger = hypercorn_config.log.error_logger
+ if logger:
+ logger.handlers[0].addFilter(AddFormattedAttributes())
+ logger.handlers[0].setFormatter(DefaultFormatter())
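
As a rough usage sketch (not part of the patch), a runtime entrypoint would typically call the setup helpers above exactly once at startup; the exact call site is an assumption here.

```python
# Editor's sketch, not part of the patch: one-time logging initialization, assuming
# the module above is importable as localstack.logging.setup.
import logging

from localstack.logging.setup import setup_logging_from_config

# picks DEBUG/INFO from config.DEBUG or LS_LOG and applies the per-module defaults,
# plus the trace overrides when trace logging is enabled
setup_logging_from_config()

LOG = logging.getLogger("localstack.request")
LOG.info("effective level: %s", logging.getLevelName(LOG.getEffectiveLevel()))
```
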
diff --git a/localstack-core/localstack/openapi.yaml b/localstack-core/localstack/openapi.yaml
new file mode 100644
index 0000000000000..b3656c3f6f1af
--- /dev/null
+++ b/localstack-core/localstack/openapi.yaml
@@ -0,0 +1,1070 @@
+openapi: 3.1.0
+info:
+ contact:
+ email: info@localstack.cloud
+ name: LocalStack Support
+ url: https://www.localstack.cloud/contact
+ summary: The LocalStack REST API exposes functionality related to diagnostics, health
+ checks, plugins, initialisation hooks, service introspection, and more.
+ termsOfService: https://www.localstack.cloud/legal/tos
+ title: LocalStack REST API for Community
+ version: latest
+externalDocs:
+ description: LocalStack Documentation
+ url: https://docs.localstack.cloud
+servers:
+ - url: http://{host}:{port}
+ variables:
+ port:
+ default: '4566'
+ host:
+ default: 'localhost.localstack.cloud'
+components:
+ parameters:
+ SesIdFilter:
+ description: Filter for the `id` field in SES message
+ in: query
+ name: id
+ required: false
+ schema:
+ type: string
+ SesEmailFilter:
+ description: Filter for the `source` field in SES message
+ in: query
+ name: email
+ required: false
+ schema:
+ type: string
+ SnsAccountId:
+ description: '`accountId` field of the resource'
+ in: query
+ name: accountId
+ required: false
+ schema:
+ default: '000000000000'
+ type: string
+ SnsEndpointArn:
+ description: '`endpointArn` field of the resource'
+ in: query
+ name: endpointArn
+ required: false
+ schema:
+ type: string
+ SnsPhoneNumber:
+ description: '`phoneNumber` field of the resource'
+ in: query
+ name: phoneNumber
+ required: false
+ schema:
+ type: string
+ SnsRegion:
+ description: '`region` field of the resource'
+ in: query
+ name: region
+ required: false
+ schema:
+ default: us-east-1
+ type: string
+ schemas:
+ InitScripts:
+ additionalProperties: false
+ properties:
+ completed:
+ additionalProperties: false
+ properties:
+ BOOT:
+ type: boolean
+ READY:
+ type: boolean
+ SHUTDOWN:
+ type: boolean
+ START:
+ type: boolean
+ required:
+ - BOOT
+ - START
+ - READY
+ - SHUTDOWN
+ type: object
+ scripts:
+ items:
+ additionalProperties: false
+ properties:
+ name:
+ type: string
+ stage:
+ type: string
+ state:
+ type: string
+ required:
+ - stage
+ - name
+ - state
+ type: object
+ type: array
+ required:
+ - completed
+ - scripts
+ type: object
+ InitScriptsStage:
+ additionalProperties: false
+ properties:
+ completed:
+ type: boolean
+ scripts:
+ items:
+ additionalProperties: false
+ properties:
+ name:
+ type: string
+ stage:
+ type: string
+ state:
+ type: string
+ required:
+ - stage
+ - name
+ - state
+ type: object
+ type: array
+ required:
+ - completed
+ - scripts
+ type: object
+ SESDestination:
+ type: object
+ description: Possible destination of a SES message
+ properties:
+ ToAddresses:
+ type: array
+ items:
+ type: string
+ format: email
+ CcAddresses:
+ type: array
+ items:
+ type: string
+ format: email
+ BccAddresses:
+ type: array
+ items:
+ type: string
+ format: email
+ additionalProperties: false
+ SesSentEmail:
+ additionalProperties: false
+ properties:
+ Body:
+ additionalProperties: false
+ properties:
+ html_part:
+ type: string
+ text_part:
+ type: string
+ required:
+ - text_part
+ type: object
+ Destination:
+ $ref: '#/components/schemas/SESDestination'
+ Id:
+ type: string
+ RawData:
+ type: string
+ Region:
+ type: string
+ Source:
+ type: string
+ Subject:
+ type: string
+ Template:
+ type: string
+ TemplateData:
+ type: string
+ Timestamp:
+ type: string
+ required:
+ - Id
+ - Region
+ - Timestamp
+ - Source
+ type: object
+ SessionInfo:
+ additionalProperties: false
+ properties:
+ edition:
+ type: string
+ is_docker:
+ type: boolean
+ is_license_activated:
+ type: boolean
+ machine_id:
+ type: string
+ server_time_utc:
+ type: string
+ session_id:
+ type: string
+ system:
+ type: string
+ uptime:
+ type: integer
+ version:
+ type: string
+ required:
+ - version
+ - edition
+ - is_license_activated
+ - session_id
+ - machine_id
+ - system
+ - is_docker
+ - server_time_utc
+ - uptime
+ type: object
+ SnsSubscriptionTokenError:
+ additionalProperties: false
+ properties:
+ error:
+ type: string
+ subscription_arn:
+ type: string
+ required:
+ - error
+ - subscription_arn
+ type: object
+ SNSPlatformEndpointMessage:
+ type: object
+ description: Message sent to a platform endpoint via SNS
+ additionalProperties: false
+ properties:
+ TargetArn:
+ type: string
+ TopicArn:
+ type: string
+ Message:
+ type: string
+ MessageAttributes:
+ type: object
+ MessageStructure:
+ type: string
+ Subject:
+ type: [string, 'null']
+ MessageId:
+ type: string
+ SNSMessage:
+ type: object
+ description: Message sent via SNS
+ properties:
+ PhoneNumber:
+ type: string
+ TopicArn:
+ type: [string, 'null']
+ SubscriptionArn:
+ type: [string, 'null']
+ MessageId:
+ type: string
+ Message:
+ type: string
+ MessageAttributes:
+ type: object
+ MessageStructure:
+ type: [string, 'null']
+ Subject:
+ type: [string, 'null']
+ SNSPlatformEndpointMessages:
+ type: object
+ description: |
+ Messages sent to the platform endpoint retrieved via the retrospective endpoint.
+ The endpoint ARN is the key with a list of messages as value.
+ additionalProperties:
+ type: array
+ items:
+ $ref: '#/components/schemas/SNSPlatformEndpointMessage'
+ SMSMessages:
+ type: object
+ description: |
+ SMS messages retrieved via the retrospective endpoint.
+ The phone number is the key with a list of messages as value.
+ additionalProperties:
+ type: array
+ items:
+ $ref: '#/components/schemas/SNSMessage'
+ SNSPlatformEndpointResponse:
+ type: object
+ additionalProperties: false
+ description: Response payload for the /_aws/sns/platform-endpoint-messages endpoint
+ properties:
+ region:
+ type: string
+ description: "The AWS region, e.g., us-east-1"
+ platform_endpoint_messages:
+ $ref: '#/components/schemas/SNSPlatformEndpointMessages'
+ required:
+ - region
+ - platform_endpoint_messages
+ SNSSMSMessagesResponse:
+ type: object
+ additionalProperties: false
+ description: Response payload for the /_aws/sns/sms-messages endpoint
+ properties:
+ region:
+ type: string
+ description: "The AWS region, e.g., us-east-1"
+ sms_messages:
+ $ref: '#/components/schemas/SMSMessages'
+ required:
+ - region
+ - sms_messages
+ ReceiveMessageRequest:
+ type: object
+ description: https://github.com/boto/botocore/blob/develop/botocore/data/sqs/2012-11-05/service-2.json
+ required:
+ - QueueUrl
+ properties:
+ QueueUrl:
+ type: string
+ format: uri
+ AttributeNames:
+ type: array
+ items:
+ type: string
+ MessageSystemAttributeNames:
+ type: array
+ items:
+ type: string
+ MessageAttributeNames:
+ type: array
+ items:
+ type: string
+ MaxNumberOfMessages:
+ type: integer
+ VisibilityTimeout:
+ type: integer
+ WaitTimeSeconds:
+ type: integer
+ ReceiveRequestAttemptId:
+ type: string
+ ReceiveMessageResult:
+ type: object
+ description: https://github.com/boto/botocore/blob/develop/botocore/data/sqs/2012-11-05/service-2.json
+ properties:
+ Messages:
+ type: array
+ items:
+ $ref: '#/components/schemas/Message'
+ Message:
+ type: object
+ properties:
+ MessageId:
+ type: [string, 'null']
+ ReceiptHandle:
+ type: [string, 'null']
+ MD5OfBody:
+ type: [string, 'null']
+ Body:
+ type: [string, 'null']
+ Attributes:
+ type: object
+ MessageAttributes:
+ type: object
+ CloudWatchMetrics:
+ additionalProperties: false
+ properties:
+ metrics:
+ items:
+ additionalProperties: false
+ properties:
+ account:
+ description: Account ID
+ type: string
+ d:
+ description: Dimensions
+ items:
+ additionalProperties: false
+ properties:
+ n:
+ description: Dimension name
+ type: string
+ v:
+ description: Dimension value
+ oneOf:
+ - type: string
+ - type: integer
+ required:
+ - n
+ - v
+ type: object
+ type: array
+ n:
+ description: Metric name
+ type: string
+ ns:
+ description: Namespace
+ type: string
+ region:
+ description: Region name
+ type: string
+ t:
+ description: Timestamp
+ oneOf:
+ - type: string
+ format: date-time
+ - type: number
+ v:
+ description: Metric value
+ oneOf:
+ - type: string
+ - type: integer
+ required:
+ - ns
+ - n
+ - v
+ - t
+ - d
+ - account
+ - region
+ type: object
+ type: array
+ required:
+ - metrics
+ type: object
+paths:
+ /_aws/cloudwatch/metrics/raw:
+ get:
+ description: Retrieve CloudWatch metrics
+ operationId: get_cloudwatch_metrics
+ tags: [aws]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/CloudWatchMetrics'
+ description: CloudWatch metrics
+ /_aws/dynamodb/expired:
+ delete:
+ description: Delete expired items from TTL-enabled DynamoDB tables
+ operationId: delete_ddb_expired_items
+ tags: [aws]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ ExpiredItems:
+ description: Number of expired items that were deleted
+ type: integer
+ required:
+ - ExpiredItems
+ type: object
+ description: Operation was successful
+ /_aws/events/rules/{rule_arn}/trigger:
+ get:
+ description: Trigger a scheduled EventBridge rule
+ operationId: trigger_event_bridge_rule
+ tags: [aws]
+ parameters:
+ - description: EventBridge rule ARN
+ in: path
+ name: rule_arn
+ required: true
+ schema:
+ type: string
+ responses:
+ '200':
+ description: EventBridge rule was triggered
+ '404':
+ description: Not found
+ /_aws/lambda/init:
+ get:
+ description: Retrieve Lambda runtime init binary
+ operationId: get_lambda_init
+ tags: [aws]
+ responses:
+ '200':
+ content:
+ application/octet-stream: {}
+ description: Lambda runtime init binary
+ /_aws/lambda/runtimes:
+ get:
+ description: List available Lambda runtimes
+ operationId: get_lambda_runtimes
+ tags: [aws]
+ parameters:
+ - in: query
+ name: filter
+ required: false
+ schema:
+ default: supported
+ enum:
+ - all
+ - deprecated
+ - supported
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ Runtimes:
+ items:
+ type: string
+ type: array
+ required:
+ - Runtimes
+ type: object
+ description: Available Lambda runtimes
+ /_aws/ses:
+ delete:
+ description: Discard sent SES messages
+ operationId: discard_ses_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SesIdFilter'
+ responses:
+ '204':
+ description: Message was successfully discarded
+ get:
+ description: Retrieve sent SES messages
+ operationId: get_ses_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SesIdFilter'
+ - $ref: '#/components/parameters/SesEmailFilter'
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ messages:
+ items:
+ $ref: '#/components/schemas/SesSentEmail'
+ type: array
+ required:
+ - messages
+ type: object
+ description: List of sent messages
+ /_aws/sns/platform-endpoint-messages:
+ delete:
+ description: Discard the messages published to a platform endpoint via SNS
+ operationId: discard_sns_endpoint_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SnsAccountId'
+ - $ref: '#/components/parameters/SnsRegion'
+ - $ref: '#/components/parameters/SnsEndpointArn'
+ responses:
+ '204':
+ description: Platform endpoint message was discarded
+ get:
+ description: Retrieve the messages sent to a platform endpoint via SNS
+ operationId: get_sns_endpoint_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SnsAccountId'
+ - $ref: '#/components/parameters/SnsRegion'
+ - $ref: '#/components/parameters/SnsEndpointArn'
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/SNSPlatformEndpointResponse"
+ description: SNS messages via retrospective access
+ /_aws/sns/sms-messages:
+ delete:
+ description: Discard SNS SMS messages
+ operationId: discard_sns_sms_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SnsAccountId'
+ - $ref: '#/components/parameters/SnsRegion'
+ - $ref: '#/components/parameters/SnsPhoneNumber'
+ responses:
+ '204':
+ description: SMS message was discarded
+ get:
+ description: Retrieve SNS SMS messages
+ operationId: get_sns_sms_messages
+ tags: [aws]
+ parameters:
+ - $ref: '#/components/parameters/SnsAccountId'
+ - $ref: '#/components/parameters/SnsRegion'
+ - $ref: '#/components/parameters/SnsPhoneNumber'
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/SNSSMSMessagesResponse"
+ description: SNS messages via retrospective access
+ /_aws/sns/subscription-tokens/{subscription_arn}:
+ get:
+ description: Retrieve SNS subscription token for confirmation
+ operationId: get_sns_subscription_token
+ tags: [aws]
+ parameters:
+ - description: '`subscriptionArn` resource of subscription token'
+ in: path
+ name: subscription_arn
+ required: true
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ subscription_arn:
+ type: string
+ subscription_token:
+ type: string
+ required:
+ - subscription_token
+ - subscription_arn
+ type: object
+ description: Subscription token
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SnsSubscriptionTokenError'
+ description: Bad request
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SnsSubscriptionTokenError'
+ description: Not found
+ /_aws/sqs/messages:
+ get:
+ description: List SQS queue messages without side effects
+ operationId: list_all_sqs_messages
+ tags: [aws]
+ parameters:
+ - description: SQS queue URL
+ in: query
+ name: QueueUrl
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ text/xml:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageResult'
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageResult'
+ description: SQS queue messages
+ '400':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Bad request
+ '404':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Not found
+ post:
+ summary: Retrieves one or more messages from the specified queue.
+ description: |
+ This API receives messages from an SQS queue.
+ https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#API_ReceiveMessage_ResponseSyntax
+ operationId: receive_message
+ requestBody:
+ required: true
+ content:
+ application/x-www-form-urlencoded:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageRequest'
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageRequest'
+ responses:
+ '200':
+ content:
+ text/xml: {}
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageResult'
+ description: SQS queue messages
+ '400':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Bad request
+ '404':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Not found
+ /_aws/sqs/messages/{region}/{account_id}/{queue_name}:
+ get:
+ description: List SQS messages without side effects
+ operationId: list_sqs_messages
+ tags: [aws]
+ parameters:
+ - description: SQS queue region
+ in: path
+ name: region
+ required: true
+ schema:
+ type: string
+ - description: SQS queue account ID
+ in: path
+ name: account_id
+ required: true
+ schema:
+ type: string
+ - description: SQS queue name
+ in: path
+ name: queue_name
+ required: true
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ text/xml: {}
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ReceiveMessageResult'
+ description: SQS queue messages
+ '400':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Bad request
+ '404':
+ content:
+ text/xml: {}
+ application/json: {}
+ description: Not found
+ /_localstack/config:
+ get:
+ description: Get current LocalStack configuration
+ operationId: get_config
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ type: object
+ description: Current LocalStack configuration
+ post:
+ description: Configuration option to update with new value
+ operationId: update_config_option
+ tags: [localstack]
+ requestBody:
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ value:
+ type:
+ - number
+ - string
+ variable:
+ pattern: ^[_a-zA-Z0-9]+$
+ type: string
+ required:
+ - variable
+ - value
+ type: object
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ value:
+ type:
+ - number
+ - string
+ variable:
+ type: string
+ required:
+ - variable
+ - value
+ type: object
+ description: Configuration option is updated
+ '400':
+ content:
+ application/json: {}
+ description: Bad request
+ /_localstack/diagnose:
+ get:
+ description: Get diagnostics report
+ operationId: get_diagnostics
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ config:
+ type: object
+ docker-dependent-image-hosts:
+ type: object
+ docker-inspect:
+ type: object
+ file-tree:
+ type: object
+ important-endpoints:
+ type: object
+ info:
+ $ref: '#/components/schemas/SessionInfo'
+ logs:
+ additionalProperties: false
+ properties:
+ docker:
+ type: string
+ required:
+ - docker
+ type: object
+ services:
+ type: object
+ usage:
+ type: object
+ version:
+ additionalProperties: false
+ properties:
+ host:
+ additionalProperties: false
+ properties:
+ kernel:
+ type: string
+ required:
+ - kernel
+ type: object
+ image-version:
+ additionalProperties: false
+ properties:
+ created:
+ type: string
+ id:
+ type: string
+ sha256:
+ type: string
+ tag:
+ type: string
+ required:
+ - id
+ - sha256
+ - tag
+ - created
+ type: object
+ localstack-version:
+ additionalProperties: false
+ properties:
+ build-date:
+ type:
+ - string
+ - 'null'
+ build-git-hash:
+ type:
+ - string
+ - 'null'
+ build-version:
+ type:
+ - string
+ - 'null'
+ required:
+ - build-date
+ - build-git-hash
+ - build-version
+ type: object
+ required:
+ - image-version
+ - localstack-version
+ - host
+ type: object
+ required:
+ - version
+ - info
+ - services
+ - config
+ - docker-inspect
+ - docker-dependent-image-hosts
+ - file-tree
+ - important-endpoints
+ - logs
+ - usage
+ type: object
+ description: Diagnostics report
+ /_localstack/health:
+ get:
+ description: Get available LocalStack features and AWS services
+ operationId: get_features_and_services
+ tags: [localstack]
+ parameters:
+ - allowEmptyValue: true
+ in: query
+ name: reload
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ edition:
+ enum:
+ - community
+ - pro
+ - enterprise
+ - unknown
+ type: string
+ features:
+ type: object
+ services:
+ type: object
+ version:
+ type: string
+ required:
+ - edition
+ - services
+ - version
+ type: object
+ description: Available LocalStack features and AWS services
+ head:
+ tags: [localstack]
+ operationId: health
+ responses:
+ '200':
+ content:
+ text/plain: {}
+ description: ''
+ post:
+ description: Restart or terminate LocalStack session
+ operationId: manage_session
+ tags: [localstack]
+ requestBody:
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ action:
+ enum:
+ - restart
+ - kill
+ type: string
+ required:
+ - action
+ type: object
+ description: Action to perform
+ required: true
+ responses:
+ '200':
+ content:
+ text/plain: {}
+ description: Action was successful
+ '400':
+ content:
+ text/plain: {}
+ description: Bad request
+ put:
+ description: Store arbitrary data to in-memory state
+ operationId: store_data
+ tags: [localstack]
+ requestBody:
+ content:
+ application/json:
+ schema:
+ type: object
+ description: Data to save
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties: false
+ properties:
+ status:
+ type: string
+ required:
+ - status
+ type: object
+ description: Data was saved
+ /_localstack/info:
+ get:
+ description: Get information about the current LocalStack session
+ operationId: get_session_info
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SessionInfo'
+ description: Information about the current LocalStack session
+ /_localstack/init:
+ get:
+ description: Get information about init scripts
+ operationId: get_init_script_info
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/InitScripts'
+ description: Information about init scripts
+ /_localstack/init/{stage}:
+ get:
+ description: Get information about init scripts in a specific stage
+ operationId: get_init_script_info_stage
+ tags: [localstack]
+ parameters:
+ - in: path
+ name: stage
+ required: true
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/InitScriptsStage'
+ description: Information about init scripts in a specific stage
+ /_localstack/plugins:
+ get:
+ description: ''
+ operationId: get_plugins
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json: {}
+ description: ''
+ /_localstack/usage:
+ get:
+ description: ''
+ operationId: get_usage
+ tags: [localstack]
+ responses:
+ '200':
+ content:
+ application/json: {}
+ description: ''
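
The spec above maps directly to HTTP calls against the edge port. The following editorial sketch (not part of the patch) exercises two of the documented endpoints with the requests library; the queue URL is a hypothetical example value.

```python
# Editor's sketch, not part of the patch: calling two endpoints documented in the
# OpenAPI spec above against a LocalStack instance on the default edge port 4566.
import requests

# GET /_localstack/health -> edition, version, and per-service availability
health = requests.get("http://localhost:4566/_localstack/health").json()
print(health["edition"], health["version"])

# GET /_aws/sqs/messages -> peek at queue messages without side effects
# (the QueueUrl below is a hypothetical example value)
resp = requests.get(
    "http://localhost:4566/_aws/sqs/messages",
    params={"QueueUrl": "http://localhost:4566/000000000000/my-queue"},
    headers={"Accept": "application/json"},
)
print(resp.status_code)
```
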
diff --git a/localstack-core/localstack/packages/__init__.py b/localstack-core/localstack/packages/__init__.py
new file mode 100644
index 0000000000000..f4f7585cfbe95
--- /dev/null
+++ b/localstack-core/localstack/packages/__init__.py
@@ -0,0 +1,25 @@
+from .api import (
+ InstallTarget,
+ NoSuchVersionException,
+ Package,
+ PackageException,
+ PackageInstaller,
+ PackagesPlugin,
+ package,
+ packages,
+)
+from .core import DownloadInstaller, GitHubReleaseInstaller, SystemNotSupportedException
+
+__all__ = [
+ "Package",
+ "PackageInstaller",
+ "GitHubReleaseInstaller",
+ "DownloadInstaller",
+ "InstallTarget",
+ "PackageException",
+ "NoSuchVersionException",
+ "SystemNotSupportedException",
+ "PackagesPlugin",
+ "package",
+ "packages",
+]
diff --git a/localstack-core/localstack/packages/api.py b/localstack-core/localstack/packages/api.py
new file mode 100644
index 0000000000000..b3260e9c5b83f
--- /dev/null
+++ b/localstack-core/localstack/packages/api.py
@@ -0,0 +1,403 @@
+import abc
+import functools
+import logging
+import os
+from collections import defaultdict
+from enum import Enum
+from inspect import getmodule
+from threading import RLock
+from typing import Callable, List, Optional, Tuple
+
+from plux import Plugin, PluginManager, PluginSpec
+
+from localstack import config
+
+LOG = logging.getLogger(__name__)
+
+
+class PackageException(Exception):
+ """Basic exception indicating that a package-specific exception occurred."""
+
+ pass
+
+
+class NoSuchVersionException(PackageException):
+ """Exception indicating that a requested installer version is not available / supported."""
+
+ def __init__(self, package: str = None, version: str = None):
+ message = "Unable to find requested version"
+ if package and version:
+ message += f"Unable to find requested version '{version}' for package '{package}'"
+ super().__init__(message)
+
+
+class InstallTarget(Enum):
+ """
+ Different installation targets.
+ Attention:
+ - These targets are directly used in the LPM API and are therefore part of a public API!
+ - The order of the entries in the enum define the default lookup order when looking for package installations.
+
+ These targets refer to the directories in config#Directories.
+ - VAR_LIBS: Used for packages installed at runtime. They are installed in a host-mounted volume.
+ This directory / these installations persist across multiple containers.
+ - STATIC_LIBS: Used for packages installed at build time. They are installed in a non-host-mounted volume.
+ This directory is re-created whenever a container is recreated.
+ """
+
+ VAR_LIBS = config.dirs.var_libs
+ STATIC_LIBS = config.dirs.static_libs
+
+
+class PackageInstaller(abc.ABC):
+ """
+ Base class for a specific installer.
+ An instance of an installer manages the installation of a specific Package (in a specific version, if there are
+ multiple versions).
+ """
+
+ def __init__(self, name: str, version: str, install_lock: Optional[RLock] = None):
+ """
+ :param name: technical package name, f.e. "opensearch"
+ :param version: version of the package to install
+ :param install_lock: custom lock which should be used for this package installer instance for the
+ complete #install call. Defaults to a per-instance reentrant lock (RLock).
+ Package instances create one installer per version. Therefore, by default, the lock
+ ensures that package installations of the same package and version are mutually exclusive.
+ """
+ self.name = name
+ self.version = version
+ self.install_lock = install_lock or RLock()
+ self._setup_for_target: dict[InstallTarget, bool] = defaultdict(lambda: False)
+
+ def install(self, target: Optional[InstallTarget] = None) -> None:
+ """
+ Performs the package installation.
+
+ :param target: preferred installation target. Default is VAR_LIBS.
+ :return: None
+ :raises PackageException: if the installation fails
+ """
+ try:
+ if not target:
+ target = InstallTarget.VAR_LIBS
+ # We have to acquire the lock before checking if the package is installed, as the is_installed check
+ # is _only_ reliable if no other thread is currently actually installing
+ with self.install_lock:
+ # Skip the installation if it's already installed
+ if not self.is_installed():
+ LOG.debug("Starting installation of %s %s...", self.name, self.version)
+ self._prepare_installation(target)
+ self._install(target)
+ self._post_process(target)
+ LOG.debug("Installation of %s %s finished.", self.name, self.version)
+ else:
+ LOG.debug(
+ "Installation of %s %s skipped (already installed).",
+ self.name,
+ self.version,
+ )
+ if not self._setup_for_target[target]:
+ LOG.debug("Performing runtime setup for already installed package.")
+ self._setup_existing_installation(target)
+ # remember the setup for this target so it only runs once per session
+ self._setup_for_target[target] = True
+ except PackageException as e:
+ raise e
+ except Exception as e:
+ raise PackageException(f"Installation of {self.name} {self.version} failed.") from e
+
+ def is_installed(self) -> bool:
+ """
+ Checks if the package is already installed.
+
+ :return: True if the package is already installed (i.e. an installation is not necessary).
+ """
+ return self.get_installed_dir() is not None
+
+ def get_installed_dir(self) -> str | None:
+ """
+ Returns the directory of an existing installation. The directory can differ based on the installation target
+ and version.
+ :return: str representation of the installation directory path or None if the package is not installed anywhere
+ """
+ for target in InstallTarget:
+ directory = self._get_install_dir(target)
+ if directory and os.path.exists(self._get_install_marker_path(directory)):
+ return directory
+
+ def _get_install_dir(self, target: InstallTarget) -> str:
+ """
+ Builds the installation directory for a specific target.
+ :param target: to create the installation directory path for
+ :return: str representation of the installation directory for the given target
+ """
+ return os.path.join(target.value, self.name, self.version)
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ """
+ Builds the path for a specific "marker" whose presence indicates that the package has been installed
+ successfully in the given directory.
+
+ :param install_dir: base path for the check (f.e. /var/lib/localstack/lib/dynamodblocal/latest/)
+ :return: path which should be checked to indicate if the package has been installed successfully
+ (f.e. /var/lib/localstack/lib/dynamodblocal/latest/DynamoDBLocal.jar)
+ """
+ raise NotImplementedError()
+
+ def _setup_existing_installation(self, target: InstallTarget) -> None:
+ """
+ Internal function to perform the setup for an existing installation, f.e. adding a path to an environment.
+ This is only necessary for certain installers (like the PythonPackageInstaller).
+ This function will _always_ be executed _exactly_ once within a Python session for a specific installer
+ instance and target, if #install is called for the respective target.
+ :param target: of the installation
+ :return: None
+ """
+ pass
+
+ def _prepare_installation(self, target: InstallTarget) -> None:
+ """
+ Internal function to prepare an installation, f.e. by downloading some data or installing an OS package repo.
+ Can be implemented by specific installers.
+ :param target: of the installation
+ :return: None
+ """
+ pass
+
+ def _install(self, target: InstallTarget) -> None:
+ """
+ Internal function to perform the actual installation.
+ Must be implemented by specific installers.
+ :param target: of the installation
+ :return: None
+ """
+ raise NotImplementedError()
+
+ def _post_process(self, target: InstallTarget) -> None:
+ """
+ Internal function to perform some post-processing, f.e. patching an installation or creating symlinks.
+ :param target: of the installation
+ :return: None
+ """
+ pass
+
+
+class Package(abc.ABC):
+ """
+ A Package defines a specific kind of software, mostly used as backends or supporting system for service
+ implementations.
+ """
+
+ def __init__(self, name: str, default_version: str):
+ """
+ :param name: Human readable name of the package, f.e. "PostgreSQL"
+ :param default_version: Default version of the package which is used for installations if no version is defined
+ """
+ self.name = name
+ self.default_version = default_version
+
+ def get_installed_dir(self, version: str | None = None) -> str | None:
+ """
+ Finds a directory where the package (in the specific version) is installed.
+ :param version: of the package to look for. If None, the default version of the package is used.
+ :return: str representation of the path to the existing installation directory or None if the package in this
+ version is not yet installed.
+ """
+ return self.get_installer(version).get_installed_dir()
+
+ def install(self, version: str | None = None, target: Optional[InstallTarget] = None) -> None:
+ """
+ Installs the package in the given version in the preferred target location.
+ :param version: version of the package to install. If None, the default version of the package will be used.
+ :param target: preferred installation target. If None, the var_libs directory is used.
+ :raises NoSuchVersionException: If the given version is not supported.
+ """
+ self.get_installer(version).install(target)
+
+ @functools.lru_cache()
+ def get_installer(self, version: str | None = None) -> PackageInstaller:
+ """
+ Returns the installer instance for a specific version of the package.
+
+ It is important that this be LRU cached. Installers have a mutex lock to prevent races, and it is necessary
+ that this method returns the same installer instance for a given version.
+
+ :param version: version of the package to install. If None, the default version of the package will be used.
+ :return: PackageInstaller instance for the given version.
+ :raises NoSuchVersionException: If the given version is not supported.
+ """
+ if not version:
+ return self.get_installer(self.default_version)
+ if version not in self.get_versions():
+ raise NoSuchVersionException(package=self.name, version=version)
+ return self._get_installer(version)
+
+ def get_versions(self) -> List[str]:
+ """
+ :return: List of all versions available for this package.
+ """
+ raise NotImplementedError()
+
+ def _get_installer(self, version: str) -> PackageInstaller:
+ """
+ Internal lookup function which needs to be implemented by specific packages.
+ It creates PackageInstaller instances for the specific version.
+
+ :param version: to find the installer for
+ :return: PackageInstaller instance responsible for installing the given version of the package.
+ """
+ raise NotImplementedError()
+
+ def __str__(self):
+ return self.name
+
+
+class MultiPackageInstaller(PackageInstaller):
+ """
+ PackageInstaller implementation which composes of multiple package installers.
+ """
+
+ def __init__(self, name: str, version: str, package_installer: List[PackageInstaller]):
+ """
+ :param name: of the (multi-)package installer
+ :param version: of this (multi-)package installer
+ :param package_installer: List of installers this multi-package installer consists of
+ """
+ super().__init__(name=name, version=version)
+
+ assert isinstance(package_installer, list)
+ assert len(package_installer) > 0
+ self.package_installer = package_installer
+
+ def install(self, target: Optional[InstallTarget] = None) -> None:
+ """
+ Installs the different packages this installer is composed of.
+
+ :param target: which defines where to install the packages.
+ :return: None
+ """
+ for package_installer in self.package_installer:
+ package_installer.install(target=target)
+
+ def get_installed_dir(self) -> str | None:
+ # By default, use the installed-dir of the first package
+ return self.package_installer[0].get_installed_dir()
+
+ def _install(self, target: InstallTarget) -> None:
+ # This package installer actually only calls other installers, we pass here
+ pass
+
+ def _get_install_dir(self, target: InstallTarget) -> str:
+ # By default, use the install-dir of the first package
+ return self.package_installer[0]._get_install_dir(target)
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ # By default, use the install-marker-path of the first package
+ return self.package_installer[0]._get_install_marker_path(install_dir)
+
+
+PLUGIN_NAMESPACE = "localstack.packages"
+
+
+class PackagesPlugin(Plugin):
+ """
+ Plugin implementation for Package plugins.
+ A package plugin exposes a specific package instance.
+ """
+
+ api: str
+ name: str
+
+ def __init__(
+ self,
+ name: str,
+ scope: str,
+ get_package: Callable[[], Package | List[Package]],
+ should_load: Callable[[], bool] = None,
+ ) -> None:
+ super().__init__()
+ self.name = name
+ self.scope = scope
+ self._get_package = get_package
+ self._should_load = should_load
+
+ def should_load(self) -> bool:
+ if self._should_load:
+ return self._should_load()
+ return True
+
+ def get_package(self) -> Package:
+ """
+ :return: returns the package instance of this package plugin
+ """
+ return self._get_package()
+
+
+class NoSuchPackageException(PackageException):
+ """Exception raised by the PackagesPluginManager to indicate that a package / version is not available."""
+
+ pass
+
+
+class PackagesPluginManager(PluginManager[PackagesPlugin]):
+ """PluginManager which simplifies the loading / access of PackagesPlugins and their exposed package instances."""
+
+ def __init__(self):
+ super().__init__(PLUGIN_NAMESPACE)
+
+ def get_all_packages(self) -> List[Tuple[str, str, Package]]:
+ return sorted(
+ [(plugin.name, plugin.scope, plugin.get_package()) for plugin in self.load_all()]
+ )
+
+ def get_packages(
+ self, package_names: List[str], version: Optional[str] = None
+ ) -> List[Package]:
+ # Plugin names are unique, but there could be multiple packages with the same name in different scopes
+ plugin_specs_per_name = defaultdict(list)
+ # Plugin names have the format "<package-name>/<scope>"; build a dict of specs per package name for the lookup
+ for plugin_spec in self.list_plugin_specs():
+ (package_name, _, _) = plugin_spec.name.rpartition("/")
+ plugin_specs_per_name[package_name].append(plugin_spec)
+
+ package_instances: List[Package] = []
+ for package_name in package_names:
+ plugin_specs = plugin_specs_per_name.get(package_name)
+ if not plugin_specs:
+ raise NoSuchPackageException(
+ f"unable to locate installer for package {package_name}"
+ )
+ for plugin_spec in plugin_specs:
+ package_instance = self.load(plugin_spec.name).get_package()
+ package_instances.append(package_instance)
+ if version and version not in package_instance.get_versions():
+ raise NoSuchPackageException(
+ f"unable to locate installer for package {package_name} and version {version}"
+ )
+
+ return package_instances
+
+
+def package(
+ name: str = None, scope: str = "community", should_load: Optional[Callable[[], bool]] = None
+):
+ """
+ Decorator for marking methods that create Package instances as a PackagePlugin.
+ Methods marked with this decorator are discoverable as a PluginSpec within the namespace "localstack.packages",
+ with the name "<name>/<scope>". If the name is not explicitly specified, the name of the parent
+ module is used as the package name.
+ """
+
+ def wrapper(fn):
+ _name = name or getmodule(fn).__name__.split(".")[-2]
+
+ @functools.wraps(fn)
+ def factory() -> PackagesPlugin:
+ return PackagesPlugin(name=_name, scope=scope, get_package=fn, should_load=should_load)
+
+ return PluginSpec(PLUGIN_NAMESPACE, f"{_name}/{scope}", factory=factory)
+
+ return wrapper
+
+
+# TODO remove (only used for migrating to new #package decorator)
+packages = package
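
To illustrate how the API above fits together, here is a hypothetical package built on it (not part of the patch); the package name, marker file, and contents are made up for the example.

```python
# Editor's sketch, not part of the patch: a minimal Package/PackageInstaller pair and
# its registration via the @package decorator. All names here are hypothetical.
import os
from typing import List

from localstack.packages import InstallTarget, Package, PackageInstaller, package


class HelloWorldInstaller(PackageInstaller):
    def _get_install_marker_path(self, install_dir: str) -> str:
        # the presence of this file marks a successful installation
        return os.path.join(install_dir, "hello.txt")

    def _install(self, target: InstallTarget) -> None:
        install_dir = self._get_install_dir(target)
        os.makedirs(install_dir, exist_ok=True)
        with open(self._get_install_marker_path(install_dir), "w") as f:
            f.write("hello\n")


class HelloWorldPackage(Package):
    def __init__(self):
        super().__init__("Hello World", "1.0.0")

    def get_versions(self) -> List[str]:
        return ["1.0.0"]

    def _get_installer(self, version: str) -> PackageInstaller:
        return HelloWorldInstaller("hello-world", version)


hello_world_package = HelloWorldPackage()


@package(name="hello-world")
def hello_world() -> Package:
    return hello_world_package


# hello_world_package.install()                   # installs into var_libs by default
# print(hello_world_package.get_installed_dir())  # .../hello-world/1.0.0
```
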
diff --git a/localstack-core/localstack/packages/core.py b/localstack-core/localstack/packages/core.py
new file mode 100644
index 0000000000000..ae04a4b70f171
--- /dev/null
+++ b/localstack-core/localstack/packages/core.py
@@ -0,0 +1,373 @@
+import logging
+import os
+import re
+from abc import ABC
+from functools import lru_cache
+from sys import version_info
+from typing import Optional, Tuple
+
+import requests
+
+from localstack import config
+
+from ..constants import LOCALSTACK_VENV_FOLDER, MAVEN_REPO_URL
+from ..utils.archives import download_and_extract
+from ..utils.files import chmod_r, chown_r, mkdir, rm_rf
+from ..utils.http import download
+from ..utils.run import is_root, run
+from ..utils.venv import VirtualEnvironment
+from .api import InstallTarget, PackageException, PackageInstaller
+
+LOG = logging.getLogger(__name__)
+
+
+class SystemNotSupportedException(PackageException):
+ """Exception indicating that the current system is not allowed."""
+
+ pass
+
+
+class ExecutableInstaller(PackageInstaller, ABC):
+ """
+ This installer simply adds a clean interface for accessing a downloaded executable directly
+ """
+
+ def get_executable_path(self) -> str | None:
+ """
+ :return: the path to the downloaded binary or None if it's not yet downloaded / installed.
+ """
+ install_dir = self.get_installed_dir()
+ if install_dir:
+ return self._get_install_marker_path(install_dir)
+
+
+class DownloadInstaller(ExecutableInstaller):
+ def __init__(self, name: str, version: str):
+ super().__init__(name, version)
+
+ def _get_download_url(self) -> str:
+ raise NotImplementedError()
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ url = self._get_download_url()
+ binary_name = os.path.basename(url)
+ return os.path.join(install_dir, binary_name)
+
+ def _install(self, target: InstallTarget) -> None:
+ target_directory = self._get_install_dir(target)
+ mkdir(target_directory)
+ download_url = self._get_download_url()
+ target_path = self._get_install_marker_path(target_directory)
+ download(download_url, target_path)
+
+
+class ArchiveDownloadAndExtractInstaller(ExecutableInstaller):
+ def __init__(self, name: str, version: str, extract_single_directory: bool = False):
+ """
+ :param name: technical package name, f.e. "opensearch"
+ :param version: version of the package to install
+ :param extract_single_directory: whether to extract files from single root folder in the archive
+ """
+ super().__init__(name, version)
+ self.extract_single_directory = extract_single_directory
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ raise NotImplementedError()
+
+ def _get_download_url(self) -> str:
+ raise NotImplementedError()
+
+ def get_installed_dir(self) -> str | None:
+ installed_dir = super().get_installed_dir()
+ subdir = self._get_archive_subdir()
+
+ # If the specific installer defines a subdirectory, we return the subdirectory.
+ # f.e. /var/lib/localstack/lib/amazon-mq/5.16.5/apache-activemq-5.16.5/
+ if installed_dir and subdir:
+ return os.path.join(installed_dir, subdir)
+
+ return installed_dir
+
+ def _get_archive_subdir(self) -> str | None:
+ """
+ :return: name of the subdirectory contained in the archive or none if the package content is at the root level
+ of the archive
+ """
+ return None
+
+ def get_executable_path(self) -> str | None:
+ subdir = self._get_archive_subdir()
+ if subdir is None:
+ return super().get_executable_path()
+ else:
+ install_dir = self.get_installed_dir()
+ if install_dir:
+ install_dir = install_dir[: -len(subdir)]
+ return self._get_install_marker_path(install_dir)
+
+ def _install(self, target: InstallTarget) -> None:
+ target_directory = self._get_install_dir(target)
+ mkdir(target_directory)
+ download_url = self._get_download_url()
+ archive_name = os.path.basename(download_url)
+ archive_path = os.path.join(config.dirs.tmp, archive_name)
+ download_and_extract(
+ download_url,
+ retries=3,
+ tmp_archive=archive_path,
+ target_dir=target_directory,
+ )
+ rm_rf(archive_path)
+ if self.extract_single_directory:
+ dir_contents = os.listdir(target_directory)
+ if len(dir_contents) != 1:
+ return
+ target_subdir = os.path.join(target_directory, dir_contents[0])
+ if not os.path.isdir(target_subdir):
+ return
+ os.rename(target_subdir, f"{target_directory}.backup")
+ rm_rf(target_directory)
+ os.rename(f"{target_directory}.backup", target_directory)
+
+
+class PermissionDownloadInstaller(DownloadInstaller, ABC):
+ def _install(self, target: InstallTarget) -> None:
+ super()._install(target)
+ chmod_r(self.get_executable_path(), 0o777)
+
+
+class GitHubReleaseInstaller(PermissionDownloadInstaller):
+ """
+ Installer which downloads an asset from a GitHub project's tag.
+ """
+
+ def __init__(self, name: str, tag: str, github_slug: str):
+ super().__init__(name, tag)
+ self.github_tag_url = (
+ f"https://api.github.com/repos/{github_slug}/releases/tags/{self.version}"
+ )
+
+ @lru_cache()
+ def _get_download_url(self) -> str:
+ asset_name = self._get_github_asset_name()
+ # try to use a token when calling the GH API for increased API rate limits
+ headers = None
+ gh_token = os.environ.get("GITHUB_API_TOKEN")
+ if gh_token:
+ headers = {"authorization": f"Bearer {gh_token}"}
+ response = requests.get(self.github_tag_url, headers=headers)
+ if not response.ok:
+ raise PackageException(
+ f"Could not get list of releases from {self.github_tag_url}: {response.text}"
+ )
+ github_release = response.json()
+ download_url = None
+ for asset in github_release.get("assets", []):
+ # find the correct binary in the release
+ if asset["name"] == asset_name:
+ download_url = asset["browser_download_url"]
+ break
+ if download_url is None:
+ raise PackageException(
+ f"Could not find required binary {asset_name} in release {self.github_tag_url}"
+ )
+ return download_url
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ # Use the GitHub asset name instead of the download URL (since the download URL needs to be fetched online).
+ return os.path.join(install_dir, self._get_github_asset_name())
+
+ def _get_github_asset_name(self) -> str:
+ """
+ Determines the name of the asset to download.
+ The asset name must be determinable without having any online data (because it is used in offline scenarios to
+ determine if the package is already installed).
+
+ :return: name of the asset to download from the GitHub project's tag / version
+ """
+ raise NotImplementedError()
+
+
+class NodePackageInstaller(ExecutableInstaller):
+ """Package installer for Node / NPM packages."""
+
+ def __init__(
+ self,
+ package_name: str,
+ version: str,
+ package_spec: Optional[str] = None,
+ main_module: str = "main.js",
+ ):
+ """
+ Initializes the Node / NPM package installer.
+ :param package_name: npm package name
+ :param version: version of the package which should be installed
+ :param package_spec: optional package spec for the installation.
+ If not set, the package name and version will be used for the installation.
+ :param main_module: main module file of the package
+ """
+ super().__init__(package_name, version)
+ self.package_name = package_name
+ # If the package spec is not explicitly set (f.e. to a repo), we build it and pin the version
+ self.package_spec = package_spec or f"{self.package_name}@{version}"
+ self.main_module = main_module
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ return os.path.join(install_dir, "node_modules", self.package_name, self.main_module)
+
+ def _install(self, target: InstallTarget) -> None:
+ target_dir = self._get_install_dir(target)
+
+ run(
+ [
+ "npm",
+ "install",
+ "--prefix",
+ target_dir,
+ self.package_spec,
+ ]
+ )
+ # npm 9+ does _not_ set the ownership of files anymore if run as root
+ # - https://github.blog/changelog/2022-10-24-npm-v9-0-0-released/
+ # - https://github.com/npm/cli/pull/5704
+ # - https://github.com/localstack/localstack/issues/7620
+ if is_root():
+ # if the package was installed as root, set the ownership manually
+ LOG.debug("Setting ownership root:root on %s", target_dir)
+ chown_r(target_dir, "root")
+
+
+LOCALSTACK_VENV = VirtualEnvironment(LOCALSTACK_VENV_FOLDER)
+
+
+class PythonPackageInstaller(PackageInstaller):
+ """
+ Package installer which allows the runtime installation of additional Python packages used by certain services,
+ f.e. vosk as an offline speech recognition toolkit (which is ~7MB compressed and ~26MB uncompressed).
+ """
+
+ normalized_name: str
+ """Normalized package name according to PEP440."""
+
+ def __init__(self, name: str, version: str, *args, **kwargs):
+ super().__init__(name, version, *args, **kwargs)
+ self.normalized_name = self._normalize_package_name(name)
+
+ def _normalize_package_name(self, name: str):
+ """
+ Normalizes the Python package name according to PEP 503.
+ https://packaging.python.org/en/latest/specifications/name-normalization/#name-normalization
+ """
+ return re.sub(r"[-_.]+", "-", name).lower()
+
+ def _get_install_dir(self, target: InstallTarget) -> str:
+ # all python installers share a venv
+ return os.path.join(target.value, "python-packages")
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ python_subdir = f"python{version_info[0]}.{version_info[1]}"
+ dist_info_dir = f"{self.normalized_name}-{self.version}.dist-info"
+ # the METADATA file is mandatory, use it as install marker
+ return os.path.join(
+ install_dir, "lib", python_subdir, "site-packages", dist_info_dir, "METADATA"
+ )
+
+ def _get_venv(self, target: InstallTarget) -> VirtualEnvironment:
+ venv_dir = self._get_install_dir(target)
+ return VirtualEnvironment(venv_dir)
+
+ def _prepare_installation(self, target: InstallTarget) -> None:
+ # make sure the venv is properly set up before installing the package
+ venv = self._get_venv(target)
+ if not venv.exists:
+ LOG.info("creating virtual environment at %s", venv.venv_dir)
+ venv.create()
+ LOG.info("adding localstack venv path %s", venv.venv_dir)
+ venv.add_pth("localstack-venv", LOCALSTACK_VENV)
+ LOG.debug("injecting venv into path %s", venv.venv_dir)
+ venv.inject_to_sys_path()
+
+ def _install(self, target: InstallTarget) -> None:
+ venv = self._get_venv(target)
+ python_bin = os.path.join(venv.venv_dir, "bin/python")
+
+ # run pip via the python binary of the venv
+ run([python_bin, "-m", "pip", "install", f"{self.name}=={self.version}"], print_error=False)
+
+ def _setup_existing_installation(self, target: InstallTarget) -> None:
+ """If the venv is already present, it just needs to be initialized once."""
+ self._prepare_installation(target)
+
+
+class MavenDownloadInstaller(DownloadInstaller):
+ """The packageURL is easy copy/pastable from the Maven central repository and the first package URL
+ defines the package name and version.
+ Example package_url: pkg:maven/software.amazon.event.ruler/event-ruler@1.7.3
+ => name: event-ruler
+ => version: 1.7.3
+ """
+
+ # Example: software.amazon.event.ruler
+ group_id: str
+ # Example: event-ruler
+ artifact_id: str
+
+ # Custom installation directory
+ install_dir_suffix: str | None
+
+ def __init__(self, package_url: str, install_dir_suffix: str | None = None):
+ self.group_id, self.artifact_id, version = parse_maven_package_url(package_url)
+ super().__init__(self.artifact_id, version)
+ self.install_dir_suffix = install_dir_suffix
+
+ def _get_download_url(self) -> str:
+ group_id_path = self.group_id.replace(".", "/")
+ return f"{MAVEN_REPO_URL}/{group_id_path}/{self.artifact_id}/{self.version}/{self.artifact_id}-{self.version}.jar"
+
+ def _get_install_dir(self, target: InstallTarget) -> str:
+ """Allow to overwrite the default installation directory.
+ This enables downloading transitive dependencies into the same directory.
+ """
+ if self.install_dir_suffix:
+ return os.path.join(target.value, self.install_dir_suffix)
+ else:
+ return super()._get_install_dir(target)
+
+
+class MavenPackageInstaller(MavenDownloadInstaller):
+ """Package installer for downloading Maven JARs, including optional dependencies.
+ The first Maven package is used as the main LPM package, and the other dependencies are installed alongside it.
+ Follows the Maven naming conventions: https://maven.apache.org/guides/mini/guide-naming-conventions.html
+ """
+
+ # Installers for Maven dependencies
+ dependencies: list[MavenDownloadInstaller]
+
+ def __init__(self, *package_urls: str):
+ super().__init__(package_urls[0])
+ self.dependencies = []
+
+ # Create installers for dependencies
+ for package_url in package_urls[1:]:
+ install_dir_suffix = os.path.join(self.name, self.version)
+ self.dependencies.append(MavenDownloadInstaller(package_url, install_dir_suffix))
+
+ def _install(self, target: InstallTarget) -> None:
+ # Install all dependencies first
+ for dependency in self.dependencies:
+ dependency._install(target)
+ # Install the main Maven package once all dependencies are installed.
+ # This main package indicates whether all dependencies are installed.
+ super()._install(target)
+
+
+def parse_maven_package_url(package_url: str) -> Tuple[str, str, str]:
+ """Example: parse_maven_package_url("pkg:maven/software.amazon.event.ruler/event-ruler@1.7.3")
+ -> software.amazon.event.ruler, event-ruler, 1.7.3
+ """
+ parts = package_url.split("/")
+ group_id = parts[1]
+ sub_parts = parts[2].split("@")
+ artifact_id = sub_parts[0]
+ version = sub_parts[1]
+ return group_id, artifact_id, version
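
As a hypothetical consumer of the helpers above (not part of the patch), a simple single-binary tool could be packaged as follows; the tool name and download URL are placeholders.

```python
# Editor's sketch, not part of the patch: a single-binary package built on the
# DownloadInstaller above. The tool name and URL are hypothetical placeholders.
import os
from typing import List

from localstack.packages import Package, PackageInstaller
from localstack.packages.core import DownloadInstaller


class ExampleToolInstaller(DownloadInstaller):
    def _get_download_url(self) -> str:
        # the basename of this URL doubles as the install marker and executable name
        return f"https://example.com/downloads/example-tool-{self.version}-linux-amd64"


class ExampleToolPackage(Package):
    def __init__(self):
        super().__init__("Example Tool", "1.2.3")

    def get_versions(self) -> List[str]:
        return ["1.2.3"]

    def _get_installer(self, version: str) -> PackageInstaller:
        return ExampleToolInstaller("example-tool", version)


example_tool_package = ExampleToolPackage()
# example_tool_package.install()
# binary = example_tool_package.get_installer().get_executable_path()
```
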
diff --git a/localstack-core/localstack/packages/debugpy.py b/localstack-core/localstack/packages/debugpy.py
new file mode 100644
index 0000000000000..bd2a768b08cd7
--- /dev/null
+++ b/localstack-core/localstack/packages/debugpy.py
@@ -0,0 +1,42 @@
+from typing import List
+
+from localstack.packages import InstallTarget, Package, PackageInstaller
+from localstack.utils.run import run
+
+
+class DebugPyPackage(Package):
+ def __init__(self):
+ super().__init__("DebugPy", "latest")
+
+ def get_versions(self) -> List[str]:
+ return ["latest"]
+
+ def _get_installer(self, version: str) -> PackageInstaller:
+ return DebugPyPackageInstaller("debugpy", version)
+
+
+class DebugPyPackageInstaller(PackageInstaller):
+ # TODO: migrate this to the upcoming pip installer
+
+ def is_installed(self) -> bool:
+ try:
+ import debugpy # noqa: T100
+
+ assert debugpy
+ return True
+ except ModuleNotFoundError:
+ return False
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ # TODO: This method currently does not provide the actual install_marker.
+ # Since we overrode is_installed(), this installer does not install anything under
+ # var/static libs and does not need an executable, so the marker is not required to operate the installer.
+ # Fix with the migration to the pip installer.
+ return install_dir
+
+ def _install(self, target: InstallTarget) -> None:
+ cmd = "pip install debugpy"
+ run(cmd)
+
+
+debugpy_package = DebugPyPackage()
diff --git a/localstack-core/localstack/packages/ffmpeg.py b/localstack-core/localstack/packages/ffmpeg.py
new file mode 100644
index 0000000000000..096c4fae34a79
--- /dev/null
+++ b/localstack-core/localstack/packages/ffmpeg.py
@@ -0,0 +1,44 @@
+import os
+from typing import List
+
+from localstack.packages import Package, PackageInstaller
+from localstack.packages.core import ArchiveDownloadAndExtractInstaller
+from localstack.utils.platform import get_arch
+
+FFMPEG_STATIC_BIN_URL = (
+ "https://www.johnvansickle.com/ffmpeg/releases/ffmpeg-{version}-{arch}-static.tar.xz"
+)
+
+
+class FfmpegPackage(Package):
+ def __init__(self):
+ super().__init__(name="ffmpeg", default_version="7.0.1")
+
+ def _get_installer(self, version: str) -> PackageInstaller:
+ return FfmpegPackageInstaller(version)
+
+ def get_versions(self) -> List[str]:
+ return ["7.0.1"]
+
+
+class FfmpegPackageInstaller(ArchiveDownloadAndExtractInstaller):
+ def __init__(self, version: str):
+ super().__init__("ffmpeg", version)
+
+ def _get_download_url(self) -> str:
+ return FFMPEG_STATIC_BIN_URL.format(arch=get_arch(), version=self.version)
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ return os.path.join(install_dir, self._get_archive_subdir())
+
+ def _get_archive_subdir(self) -> str:
+ return f"ffmpeg-{self.version}-{get_arch()}-static"
+
+ def get_ffmpeg_path(self) -> str:
+ return os.path.join(self.get_installed_dir(), "ffmpeg")
+
+ def get_ffprobe_path(self) -> str:
+ return os.path.join(self.get_installed_dir(), "ffprobe")
+
+
+ffmpeg_package = FfmpegPackage()
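+
+# Illustrative usage (sketch, assuming the standard Package.install() / get_installer() API):
+#
+#   ffmpeg_package.install()
+#   ffmpeg_bin = ffmpeg_package.get_installer().get_ffmpeg_path()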
diff --git a/localstack-core/localstack/packages/java.py b/localstack-core/localstack/packages/java.py
new file mode 100644
index 0000000000000..c37792ffc011a
--- /dev/null
+++ b/localstack-core/localstack/packages/java.py
@@ -0,0 +1,202 @@
+import logging
+import os
+from typing import List
+
+import requests
+
+from localstack.constants import USER_AGENT_STRING
+from localstack.packages import InstallTarget, Package
+from localstack.packages.core import ArchiveDownloadAndExtractInstaller
+from localstack.utils.files import rm_rf
+from localstack.utils.platform import Arch, get_arch, is_linux, is_mac_os
+from localstack.utils.run import run
+
+LOG = logging.getLogger(__name__)
+
+# Default version if not specified
+DEFAULT_JAVA_VERSION = "11"
+
+# Supported Java LTS versions mapped with Eclipse Temurin build semvers
+JAVA_VERSIONS = {
+ "8": "8u432-b06",
+ "11": "11.0.25+9",
+ "17": "17.0.13+11",
+ "21": "21.0.5+11",
+}
+
+
+class JavaInstallerMixin:
+ """
+ Mixin class for packages that depend on Java. It introduces methods that install Java and help build the runtime environment.
+ """
+
+ def _prepare_installation(self, target: InstallTarget) -> None:
+ java_package.install(target=target)
+
+ def get_java_home(self) -> str | None:
+ """
+ Returns path to JRE installation.
+ """
+ return java_package.get_installer().get_java_home()
+
+ def get_java_lib_path(self) -> str | None:
+ """
+ Returns the path to the Java shared library.
+ """
+ if java_home := self.get_java_home():
+ if is_mac_os():
+ return os.path.join(java_home, "lib", "jli", "libjli.dylib")
+ return os.path.join(java_home, "lib", "server", "libjvm.so")
+
+ def get_java_env_vars(self, path: str = None, ld_library_path: str = None) -> dict[str, str]:
+ """
+ Returns environment variables pointing to the Java installation. This is useful to build the environment where
+ the application will run.
+
+ :param path: If not specified, the value of PATH will be obtained from the environment
+ :param ld_library_path: If not specified, the value of LD_LIBRARY_PATH will be obtained from the environment
+ :return: dict consisting of three items:
+ - JAVA_HOME: path to the JRE installation
+ - LD_LIBRARY_PATH: the library path variable updated with the JRE lib paths
+ - PATH: the env path variable updated with the JRE bin path
+ """
+ java_home = self.get_java_home()
+ java_bin = f"{java_home}/bin"
+
+ path = path or os.environ["PATH"]
+
+ ld_library_path = ld_library_path or os.environ.get("LD_LIBRARY_PATH")
+ # null paths (e.g. `:/foo`) have a special meaning according to the manpages
+ if ld_library_path is None:
+ ld_library_path = f"{java_home}/lib:{java_home}/lib/server"
+ else:
+ ld_library_path = f"{java_home}/lib:{java_home}/lib/server:{ld_library_path}"
+
+ return {
+ "JAVA_HOME": java_home,
+ "LD_LIBRARY_PATH": ld_library_path,
+ "PATH": f"{java_bin}:{path}",
+ }
+
+
+class JavaPackageInstaller(ArchiveDownloadAndExtractInstaller):
+ def __init__(self, version: str):
+ super().__init__("java", version, extract_single_directory=True)
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ if is_mac_os():
+ return os.path.join(install_dir, "Contents", "Home", "bin", "java")
+ return os.path.join(install_dir, "bin", "java")
+
+ def _get_download_url(self) -> str:
+ # Note: Eclipse Temurin does not provide Mac aarch64 Java 8 builds.
+ # See https://adoptium.net/en-GB/supported-platforms/
+ try:
+ LOG.debug("Determining the latest Java build version")
+ return self._download_url_latest_release()
+ except Exception as exc: # noqa
+ LOG.debug(
+ "Unable to determine the latest Java build version. Using pinned versions: %s", exc
+ )
+ return self._download_url_fallback()
+
+ def _post_process(self, target: InstallTarget) -> None:
+ target_directory = self._get_install_dir(target)
+ minimal_jre_path = os.path.join(target.value, self.name, f"{self.version}.minimal")
+ rm_rf(minimal_jre_path)
+
+ # If jlink is not available, use the environment as is
+ if not os.path.exists(os.path.join(target_directory, "bin", "jlink")):
+ LOG.warning("Skipping JRE optimisation because jlink is not available")
+ return
+
+ # Build a custom JRE with only the necessary bits to minimise disk footprint
+ LOG.debug("Optimising JRE installation")
+ cmd = (
+ "bin/jlink --add-modules "
+ # Required modules
+ "java.base,java.desktop,java.instrument,java.management,"
+ "java.naming,java.scripting,java.sql,java.xml,jdk.compiler,"
+ # jdk.unsupported contains sun.misc.Unsafe which is required by some dependencies
+ "jdk.unsupported,"
+ # Additional cipher suites
+ "jdk.crypto.cryptoki,"
+ # Archive support
+ "jdk.zipfs,"
+ # Required by MQ broker
+ "jdk.httpserver,jdk.management,jdk.management.agent,"
+ # Required by Spark and Hadoop
+ "java.security.jgss,jdk.security.auth,"
+ # Include required locales
+ "jdk.localedata --include-locales en "
+ # Supplementary args
+ "--compress 2 --strip-debug --no-header-files --no-man-pages "
+ # Output directory
+ "--output " + minimal_jre_path
+ )
+ run(cmd, cwd=target_directory)
+
+ rm_rf(target_directory)
+ os.rename(minimal_jre_path, target_directory)
+
+ def get_java_home(self) -> str | None:
+ """
+ Get JAVA_HOME for this installation of Java.
+ """
+ installed_dir = self.get_installed_dir()
+ if is_mac_os():
+ return os.path.join(installed_dir, "Contents", "Home")
+ return installed_dir
+
+ @property
+ def arch(self) -> str | None:
+ return (
+ "x64" if get_arch() == Arch.amd64 else "aarch64" if get_arch() == Arch.arm64 else None
+ )
+
+ @property
+ def os_name(self) -> str | None:
+ return "linux" if is_linux() else "mac" if is_mac_os() else None
+
+ def _download_url_latest_release(self) -> str:
+ """
+ Return the download URL for latest stable JDK build.
+ """
+ endpoint = (
+ f"https://api.adoptium.net/v3/assets/latest/{self.version}/hotspot?"
+ f"os={self.os_name}&architecture={self.arch}&image_type=jdk"
+ )
+ # Override user-agent because Adoptium API denies service to `requests` library
+ response = requests.get(endpoint, headers={"user-agent": USER_AGENT_STRING}).json()
+ return response[0]["binary"]["package"]["link"]
+
+ def _download_url_fallback(self) -> str:
+ """
+ Return the download URL for pinned JDK build.
+ """
+ semver = JAVA_VERSIONS[self.version]
+ tag_slug = f"jdk-{semver}"
+ semver_safe = semver.replace("+", "_")
+
+ # v8 uses a different tag and version scheme
+ if self.version == "8":
+ semver_safe = semver_safe.replace("-", "")
+ tag_slug = f"jdk{semver}"
+
+ return (
+ f"https://github.com/adoptium/temurin{self.version}-binaries/releases/download/{tag_slug}/"
+ f"OpenJDK{self.version}U-jdk_{self.arch}_{self.os_name}_hotspot_{semver_safe}.tar.gz"
+ )
+
+
+class JavaPackage(Package):
+ def __init__(self, default_version: str = DEFAULT_JAVA_VERSION):
+ super().__init__(name="Java", default_version=default_version)
+
+ def get_versions(self) -> List[str]:
+ return list(JAVA_VERSIONS.keys())
+
+ def _get_installer(self, version):
+ return JavaPackageInstaller(version)
+
+
+java_package = JavaPackage()
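+
+# Illustrative usage (sketch): a package installer that requires Java can combine the mixin with its
+# installer base class; the mixin installs Java in _prepare_installation() and exposes helpers to
+# build the child-process environment (the class and variable names below are hypothetical):
+#
+#   class MyJarInstaller(JavaInstallerMixin, ArchiveDownloadAndExtractInstaller):
+#       ...
+#
+#   env = installer.get_java_env_vars()  # {"JAVA_HOME": ..., "LD_LIBRARY_PATH": ..., "PATH": ...}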
diff --git a/localstack-core/localstack/packages/plugins.py b/localstack-core/localstack/packages/plugins.py
new file mode 100644
index 0000000000000..4b4b200af8e0c
--- /dev/null
+++ b/localstack-core/localstack/packages/plugins.py
@@ -0,0 +1,22 @@
+from localstack.packages.api import Package, package
+
+
+@package(name="terraform")
+def terraform_package() -> Package:
+ from .terraform import terraform_package
+
+ return terraform_package
+
+
+@package(name="ffmpeg")
+def ffmpeg_package() -> Package:
+ from localstack.packages.ffmpeg import ffmpeg_package
+
+ return ffmpeg_package
+
+
+@package(name="java")
+def java_package() -> Package:
+ from localstack.packages.java import java_package
+
+ return java_package
diff --git a/localstack-core/localstack/packages/terraform.py b/localstack-core/localstack/packages/terraform.py
new file mode 100644
index 0000000000000..703380c54c07e
--- /dev/null
+++ b/localstack-core/localstack/packages/terraform.py
@@ -0,0 +1,41 @@
+import os
+import platform
+from typing import List
+
+from localstack.packages import InstallTarget, Package, PackageInstaller
+from localstack.packages.core import ArchiveDownloadAndExtractInstaller
+from localstack.utils.files import chmod_r
+from localstack.utils.platform import get_arch
+
+TERRAFORM_VERSION = os.getenv("TERRAFORM_VERSION", "1.5.7")
+TERRAFORM_URL_TEMPLATE = (
+ "https://releases.hashicorp.com/terraform/{version}/terraform_{version}_{os}_{arch}.zip"
+)
+
+
+class TerraformPackage(Package):
+ def __init__(self):
+ super().__init__("Terraform", TERRAFORM_VERSION)
+
+ def get_versions(self) -> List[str]:
+ return [TERRAFORM_VERSION]
+
+ def _get_installer(self, version: str) -> PackageInstaller:
+ return TerraformPackageInstaller("terraform", version)
+
+
+class TerraformPackageInstaller(ArchiveDownloadAndExtractInstaller):
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ return os.path.join(install_dir, "terraform")
+
+ def _get_download_url(self) -> str:
+ system = platform.system().lower()
+ arch = get_arch()
+ return TERRAFORM_URL_TEMPLATE.format(version=TERRAFORM_VERSION, os=system, arch=arch)
+
+ def _install(self, target: InstallTarget) -> None:
+ super()._install(target)
+ chmod_r(self.get_executable_path(), 0o777)
+
+
+terraform_package = TerraformPackage()
diff --git a/localstack-core/localstack/plugins.py b/localstack-core/localstack/plugins.py
new file mode 100644
index 0000000000000..a313032547bba
--- /dev/null
+++ b/localstack-core/localstack/plugins.py
@@ -0,0 +1,76 @@
+import logging
+import os
+import sys
+from pathlib import Path
+
+import yaml
+from plux import Plugin
+
+from localstack import config
+from localstack.runtime import hooks
+from localstack.utils.files import rm_rf
+from localstack.utils.ssl import get_cert_pem_file_path
+
+LOG = logging.getLogger(__name__)
+
+
+@hooks.on_infra_start()
+def deprecation_warnings() -> None:
+ LOG.debug("Checking for the usage of deprecated community features and configs...")
+ from localstack.deprecations import log_deprecation_warnings
+
+ log_deprecation_warnings()
+
+
+@hooks.on_infra_start(should_load=lambda: config.REMOVE_SSL_CERT)
+def delete_cached_certificate():
+ LOG.debug("Removing the cached local SSL certificate")
+ target_file = get_cert_pem_file_path()
+ rm_rf(target_file)
+
+
+class OASPlugin(Plugin):
+ """
+ This plugin allows registering an arbitrary number of OpenAPI specs, e.g., the spec for the public endpoints
+ of localstack.core.
+ The OpenAPIValidator handler uses (as an opt-in) all the collected specs to validate the requests to and the
+ responses from these public endpoints.
+
+ An OAS plugin assumes the following directory layout.
+
+ my_package
+ ├── sub_package
+ │   ├── __init__.py
+ │   ├── openapi.yaml    <-- spec file
+ │   └── plugins.py      <-- plugins
+ ├── plugins.py          <-- plugins
+ └── openapi.yaml        <-- spec file
+
+ Each package can have its own OpenAPI yaml spec, which is loaded by the corresponding plugin in plugins.py.
+ You can simply create a plugin like the following:
+
+ class MyPackageOASPlugin(OASPlugin):
+ name = "my_package"
+
+ The only convention is that plugins.py and openapi.yaml are located in the same directory.
+ """
+
+ namespace = "localstack.openapi.spec"
+
+ def __init__(self) -> None:
+ # By convention, a plugins.py is at the same level (i.e., in the same directory) as the openapi.yaml file.
+ # importlib.resources would be a better approach but has issues with namespace packages in editable mode
+ _module = sys.modules[self.__module__]
+ self.spec_path = Path(
+ os.path.join(os.path.dirname(os.path.abspath(_module.__file__)), "openapi.yaml")
+ )
+ assert self.spec_path.exists()
+ self.spec = {}
+
+ def load(self):
+ with self.spec_path.open("r") as f:
+ self.spec = yaml.safe_load(f)
+
+
+class CoreOASPlugin(OASPlugin):
+ name = "localstack"
diff --git a/localstack-core/localstack/runtime/__init__.py b/localstack-core/localstack/runtime/__init__.py
new file mode 100644
index 0000000000000..99044a674080a
--- /dev/null
+++ b/localstack-core/localstack/runtime/__init__.py
@@ -0,0 +1,5 @@
+from .current import get_current_runtime
+
+__all__ = [
+ "get_current_runtime",
+]
diff --git a/localstack-core/localstack/runtime/analytics.py b/localstack-core/localstack/runtime/analytics.py
new file mode 100644
index 0000000000000..4ef2c4ba59ae0
--- /dev/null
+++ b/localstack-core/localstack/runtime/analytics.py
@@ -0,0 +1,133 @@
+import logging
+import os
+
+from localstack import config
+from localstack.runtime import hooks
+from localstack.utils.analytics import log
+
+LOG = logging.getLogger(__name__)
+
+TRACKED_ENV_VAR = [
+ "ALLOW_NONSTANDARD_REGIONS",
+ "BEDROCK_PREWARM",
+ "CONTAINER_RUNTIME",
+ "DEBUG",
+ "DEFAULT_REGION", # Not functional; deprecated in 0.12.7, removed in 3.0.0
+ "DEFAULT_BEDROCK_MODEL",
+ "DISABLE_CORS_CHECK",
+ "DISABLE_CORS_HEADERS",
+ "DMS_SERVERLESS_DEPROVISIONING_DELAY",
+ "DMS_SERVERLESS_STATUS_CHANGE_WAITING_TIME",
+ "DNS_ADDRESS",
+ "DYNAMODB_ERROR_PROBABILITY",
+ "DYNAMODB_IN_MEMORY",
+ "DYNAMODB_REMOVE_EXPIRED_ITEMS",
+ "EAGER_SERVICE_LOADING",
+ "EC2_VM_MANAGER",
+ "ECS_TASK_EXECUTOR",
+ "EDGE_PORT",
+ "ENABLE_REPLICATOR",
+ "ENFORCE_IAM",
+ "ES_CUSTOM_BACKEND", # deprecated in 0.14.0, removed in 3.0.0
+ "ES_MULTI_CLUSTER", # deprecated in 0.14.0, removed in 3.0.0
+ "ES_ENDPOINT_STRATEGY", # deprecated in 0.14.0, removed in 3.0.0
+ "EVENT_RULE_ENGINE",
+ "IAM_SOFT_MODE",
+ "KINESIS_PROVIDER", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "KINESIS_ERROR_PROBABILITY",
+ "KMS_PROVIDER", # defunct since 1.4.0
+ "LAMBDA_DEBUG_MODE",
+ "LAMBDA_DOWNLOAD_AWS_LAYERS",
+ "LAMBDA_EXECUTOR", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_STAY_OPEN_MODE", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_REMOTE_DOCKER", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_CODE_EXTRACT_TIME", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_CONTAINER_REGISTRY", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_FALLBACK_URL", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_FORWARD_URL", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_XRAY_INIT", # Not functional; deprecated in 2.0.0, removed in 3.0.0
+ "LAMBDA_PREBUILD_IMAGES",
+ "LAMBDA_RUNTIME_EXECUTOR",
+ "LEGACY_EDGE_PROXY", # Not functional; deprecated in 1.0.0, removed in 2.0.0
+ "LS_LOG",
+ "MOCK_UNIMPLEMENTED", # Not functional; deprecated in 1.3.0, removed in 3.0.0
+ "OPENSEARCH_ENDPOINT_STRATEGY",
+ "PERSISTENCE",
+ "PERSISTENCE_SINGLE_FILE",
+ "PERSIST_ALL", # defunct since 2.3.2
+ "PORT_WEB_UI",
+ "RDS_MYSQL_DOCKER",
+ "REQUIRE_PRO",
+ "SERVICES",
+ "STRICT_SERVICE_LOADING",
+ "SKIP_INFRA_DOWNLOADS",
+ "SQS_ENDPOINT_STRATEGY",
+ "USE_SINGLE_REGION", # Not functional; deprecated in 0.12.7, removed in 3.0.0
+ "USE_SSL",
+]
+
+PRESENCE_ENV_VAR = [
+ "DATA_DIR",
+ "EDGE_FORWARD_URL", # Not functional; deprecated in 1.4.0, removed in 3.0.0
+ "GATEWAY_LISTEN",
+ "HOSTNAME",
+ "HOSTNAME_EXTERNAL",
+ "HOSTNAME_FROM_LAMBDA",
+ "HOST_TMP_FOLDER", # Not functional; deprecated in 1.0.0, removed in 2.0.0
+ "INIT_SCRIPTS_PATH", # Not functional; deprecated in 1.1.0, removed in 2.0.0
+ "LAMBDA_DEBUG_MODE_CONFIG_PATH",
+ "LEGACY_DIRECTORIES", # Not functional; deprecated in 1.1.0, removed in 2.0.0
+ "LEGACY_INIT_DIR", # Not functional; deprecated in 1.1.0, removed in 2.0.0
+ "LOCALSTACK_HOST",
+ "LOCALSTACK_HOSTNAME",
+ "OUTBOUND_HTTP_PROXY",
+ "OUTBOUND_HTTPS_PROXY",
+ "S3_DIR",
+ "TMPDIR",
+]
+
+
+@hooks.on_infra_start()
+def _publish_config_as_analytics_event():
+ env_vars = list(TRACKED_ENV_VAR)
+
+ for key, value in os.environ.items():
+ if key.startswith("PROVIDER_OVERRIDE_"):
+ env_vars.append(key)
+ elif key.startswith("SYNCHRONOUS_") and key.endswith("_EVENTS"):
+ # these config variables have been removed with 3.0.0
+ env_vars.append(key)
+
+ env_vars = {key: os.getenv(key) for key in env_vars}
+ present_env_vars = {env_var: 1 for env_var in PRESENCE_ENV_VAR if os.getenv(env_var)}
+
+ log.event("config", env_vars=env_vars, set_vars=present_env_vars)
+
+
+class LocalstackContainerInfo:
+ def get_image_variant(self) -> str:
+ for f in os.listdir("/usr/lib/localstack"):
+ if f.startswith(".") and f.endswith("-version"):
+ return f[1:-8]
+ return "unknown"
+
+ def has_docker_socket(self) -> bool:
+ return os.path.exists("/run/docker.sock")
+
+ def to_dict(self):
+ return {
+ "variant": self.get_image_variant(),
+ "has_docker_socket": self.has_docker_socket(),
+ }
+
+
+@hooks.on_infra_start()
+def _publish_container_info():
+ if not config.is_in_docker:
+ return
+
+ try:
+ log.event("container_info", payload=LocalstackContainerInfo().to_dict())
+ except Exception as e:
+ if config.DEBUG_ANALYTICS:
+ LOG.debug("error gathering container information: %s", e)
diff --git a/localstack-core/localstack/runtime/components.py b/localstack-core/localstack/runtime/components.py
new file mode 100644
index 0000000000000..db9662b2e030b
--- /dev/null
+++ b/localstack-core/localstack/runtime/components.py
@@ -0,0 +1,56 @@
+"""
+This package contains code to define and manage the core components that make up a ``LocalstackRuntime``.
+These include:
+ - A ``Gateway``
+ - A ``RuntimeServer`` as the main control loop
+ - A ``ServiceManager`` to manage service plugins (TODO: once the Service concept has been generalized)
+ - ... ?
+
+Components can then be accessed via ``get_current_runtime()``.
+"""
+
+from functools import cached_property
+
+from plux import Plugin, PluginManager
+from rolo.gateway import Gateway
+
+from .server.core import RuntimeServer, RuntimeServerPlugin
+
+
+class Components(Plugin):
+ """
+ A Plugin that allows a specific localstack runtime implementation (aws, snowflake, ...) to expose its
+ own component factory.
+ """
+
+ namespace = "localstack.runtime.components"
+
+ @cached_property
+ def gateway(self) -> Gateway:
+ raise NotImplementedError
+
+ @cached_property
+ def runtime_server(self) -> RuntimeServer:
+ raise NotImplementedError
+
+
+class BaseComponents(Components):
+ """
+ A component base, which includes a ``RuntimeServer`` created from the config variable, and a default
+ ServicePluginManager as ServiceManager.
+ """
+
+ @cached_property
+ def runtime_server(self) -> RuntimeServer:
+ from localstack import config
+
+ # TODO: rename to RUNTIME_SERVER
+ server_type = config.GATEWAY_SERVER
+
+ plugins = PluginManager(RuntimeServerPlugin.namespace)
+
+ if not plugins.exists(server_type):
+ raise ValueError(f"Unknown gateway server type {server_type}")
+
+ plugins.load(server_type)
+ return plugins.get_container(server_type).load_value
diff --git a/localstack-core/localstack/runtime/current.py b/localstack-core/localstack/runtime/current.py
new file mode 100644
index 0000000000000..fa033c58844fa
--- /dev/null
+++ b/localstack-core/localstack/runtime/current.py
@@ -0,0 +1,40 @@
+"""This package gives access to the singleton ``LocalstackRuntime`` instance. This is the only global state
+that should exist within localstack, which contains the singleton ``LocalstackRuntime`` which is currently
+running."""
+
+import threading
+import typing
+
+if typing.TYPE_CHECKING:
+ # make sure we don't have any imports here at runtime, so it can be imported anywhere without conflicts
+ from .runtime import LocalstackRuntime
+
+_runtime: typing.Optional["LocalstackRuntime"] = None
+"""The singleton LocalStack Runtime"""
+_runtime_lock = threading.RLock()
+
+
+def get_current_runtime() -> "LocalstackRuntime":
+ with _runtime_lock:
+ if not _runtime:
+ raise ValueError("LocalStack runtime has not yet been set")
+ return _runtime
+
+
+def set_current_runtime(runtime: "LocalstackRuntime"):
+ with _runtime_lock:
+ global _runtime
+ _runtime = runtime
+
+
+def initialize_runtime() -> "LocalstackRuntime":
+ from localstack.runtime import runtime
+
+ with _runtime_lock:
+ try:
+ return get_current_runtime()
+ except ValueError:
+ pass
+ rt = runtime.create_from_environment()
+ set_current_runtime(rt)
+ return rt
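+
+# Illustrative usage (sketch): runtime components are accessed through the singleton accessor, e.g.
+#
+#   from localstack.runtime import get_current_runtime
+#   get_current_runtime().is_ready()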
diff --git a/localstack-core/localstack/runtime/events.py b/localstack-core/localstack/runtime/events.py
new file mode 100644
index 0000000000000..2382fab6a47a2
--- /dev/null
+++ b/localstack-core/localstack/runtime/events.py
@@ -0,0 +1,7 @@
+import threading
+
+# TODO: deprecate and replace access with ``get_current_runtime().starting``, ...
+infra_starting = threading.Event()
+infra_ready = threading.Event()
+infra_stopping = threading.Event()
+infra_stopped = threading.Event()
diff --git a/localstack-core/localstack/runtime/exceptions.py b/localstack-core/localstack/runtime/exceptions.py
new file mode 100644
index 0000000000000..b4a4f72e65066
--- /dev/null
+++ b/localstack-core/localstack/runtime/exceptions.py
@@ -0,0 +1,9 @@
+class LocalstackExit(Exception):
+ """
+ This exception can be raised during the startup procedure to terminate localstack with an exit code and
+ a reason.
+ """
+
+ def __init__(self, reason: str = None, code: int = 0):
+ super().__init__(reason)
+ self.code = code
diff --git a/localstack-core/localstack/runtime/hooks.py b/localstack-core/localstack/runtime/hooks.py
new file mode 100644
index 0000000000000..05161679cf54e
--- /dev/null
+++ b/localstack-core/localstack/runtime/hooks.py
@@ -0,0 +1,104 @@
+import functools
+
+from plux import PluginManager, plugin
+
+# plugin namespace constants
+HOOKS_CONFIGURE_LOCALSTACK_CONTAINER = "localstack.hooks.configure_localstack_container"
+HOOKS_ON_RUNTIME_CREATE = "localstack.hooks.on_runtime_create"
+HOOKS_ON_INFRA_READY = "localstack.hooks.on_infra_ready"
+HOOKS_ON_INFRA_START = "localstack.hooks.on_infra_start"
+HOOKS_ON_PRO_INFRA_START = "localstack.hooks.on_pro_infra_start"
+HOOKS_ON_INFRA_SHUTDOWN = "localstack.hooks.on_infra_shutdown"
+HOOKS_PREPARE_HOST = "localstack.hooks.prepare_host"
+
+
+def hook(namespace: str, priority: int = 0, **kwargs):
+ """
+ Decorator for creating functional plugins that have a hook_priority attribute. Hooks with a higher priority value
+ will be executed earlier.
+ """
+
+ def wrapper(fn):
+ fn.hook_priority = priority
+ return plugin(namespace=namespace, **kwargs)(fn)
+
+ return wrapper
+
+
+def hook_spec(namespace: str):
+ """
+ Creates a new hook decorator bound to a namespace.
+
+ on_infra_start = hook_spec("localstack.hooks.on_infra_start")
+
+ @on_infra_start()
+ def foo():
+ pass
+
+ # run all hooks in order
+ on_infra_start.run()
+ """
+ fn = functools.partial(hook, namespace=namespace)
+ # attach hook manager and run method to decorator for convenience calls
+ fn.manager = HookManager(namespace)
+ fn.run = fn.manager.run_in_order
+ return fn
+
+
+class HookManager(PluginManager):
+ def load_all_sorted(self, propagate_exceptions=False):
+ """
+ Loads all hook plugins and sorts them by their hook_priority attribute.
+ """
+ plugins = self.load_all(propagate_exceptions)
+ # the hook_priority attribute is part of the function wrapped in the FunctionPlugin
+ plugins.sort(
+ key=lambda _fn_plugin: getattr(_fn_plugin.fn, "hook_priority", 0), reverse=True
+ )
+ return plugins
+
+ def run_in_order(self, *args, **kwargs):
+ """
+ Loads all plugins and runs them in order with the given arguments.
+ """
+ for fn_plugin in self.load_all_sorted():
+ fn_plugin(*args, **kwargs)
+
+ def __str__(self):
+ return "HookManager(%s)" % self.namespace
+
+ def __repr__(self):
+ return self.__str__()
+
+
+configure_localstack_container = hook_spec(HOOKS_CONFIGURE_LOCALSTACK_CONTAINER)
+"""Hooks to configure the LocalStack container before it starts. Executed on the host when invoking the CLI."""
+
+prepare_host = hook_spec(HOOKS_PREPARE_HOST)
+"""Hooks to prepare the host that's starting LocalStack. Executed on the host when invoking the CLI."""
+
+on_infra_start = hook_spec(HOOKS_ON_INFRA_START)
+"""Hooks that are executed right before starting the LocalStack infrastructure."""
+
+on_runtime_create = hook_spec(HOOKS_ON_RUNTIME_CREATE)
+"""Hooks that are executed right before the LocalstackRuntime is created. These can be used to apply
+patches or otherwise configure the interpreter before any other code is imported."""
+
+on_runtime_start = on_infra_start
+"""Alias for on_infra_start. TODO: switch and deprecated `infra` naming."""
+
+on_pro_infra_start = hook_spec(HOOKS_ON_PRO_INFRA_START)
+"""Hooks that are executed after on_infra_start hooks, and only if LocalStack pro has been activated."""
+
+on_infra_ready = hook_spec(HOOKS_ON_INFRA_READY)
+"""Hooks that are execute after all startup hooks have been executed, and the LocalStack infrastructure has become
+available."""
+
+on_runtime_ready = on_infra_ready
+"""Alias for on_infra_ready. TODO: switch and deprecated `infra` naming."""
+
+on_infra_shutdown = hook_spec(HOOKS_ON_INFRA_SHUTDOWN)
+"""Hooks that are execute when localstack shuts down."""
+
+on_runtime_shutdown = on_infra_shutdown
+"""Alias for on_infra_shutdown. TODO: switch and deprecated `infra` naming."""
diff --git a/localstack-core/localstack/runtime/init.py b/localstack-core/localstack/runtime/init.py
new file mode 100644
index 0000000000000..cb71c9da5af1b
--- /dev/null
+++ b/localstack-core/localstack/runtime/init.py
@@ -0,0 +1,283 @@
+"""Module for initialization hooks https://docs.localstack.cloud/references/init-hooks/"""
+
+import dataclasses
+import logging
+import os.path
+import subprocess
+import time
+from enum import Enum
+from functools import cached_property
+from typing import Dict, List, Optional
+
+from plux import Plugin, PluginManager
+
+from localstack.runtime import hooks
+from localstack.utils.objects import singleton_factory
+
+LOG = logging.getLogger(__name__)
+
+
+class State(Enum):
+ UNKNOWN = "UNKNOWN"
+ RUNNING = "RUNNING"
+ SUCCESSFUL = "SUCCESSFUL"
+ ERROR = "ERROR"
+
+ def __str__(self):
+ return self.name
+
+ def __repr__(self):
+ return self.name
+
+
+class Stage(Enum):
+ BOOT = 0
+ START = 1
+ READY = 2
+ SHUTDOWN = 3
+
+ def __str__(self):
+ return self.name
+
+ def __repr__(self):
+ return self.name
+
+
+@dataclasses.dataclass
+class Script:
+ path: str
+ stage: Stage
+ state: State = State.UNKNOWN
+
+
+class ScriptRunner(Plugin):
+ """
+ Interface for running scripts.
+ """
+
+ namespace = "localstack.init.runner"
+ suffixes = []
+
+ def run(self, path: str) -> None:
+ """
+ Run the given script with the appropriate runtime.
+
+ :param path: the path to the script
+ """
+ raise NotImplementedError
+
+ def should_run(self, script_file: str) -> bool:
+ """
+ Checks whether the given file should be run with this script runner. If multiple runners
+ evaluate this condition to true for the same file (ideally this doesn't happen), the first one
+ loaded will be used, which is potentially non-deterministic.
+
+ :param script_file: the script file to run
+ :return: True if this runner should be used, False otherwise
+ """
+ for suffix in self.suffixes:
+ if script_file.endswith(suffix):
+ return True
+ return False
+
+
+class ShellScriptRunner(ScriptRunner):
+ """
+ Runner that interprets scripts as shell scripts and calls them directly.
+ """
+
+ name = "sh"
+ suffixes = [".sh"]
+
+ def run(self, path: str) -> None:
+ exit_code = subprocess.call(args=[], executable=path)
+ if exit_code != 0:
+ raise OSError("Script %s returned a non-zero exit code %s" % (path, exit_code))
+
+
+class PythonScriptRunner(ScriptRunner):
+ """
+ Runner that uses ``exec`` to run a python script.
+ """
+
+ name = "py"
+ suffixes = [".py"]
+
+ def run(self, path: str) -> None:
+ with open(path, "rb") as fd:
+ exec(fd.read(), {})
+
+
+class InitScriptManager:
+ _stage_directories: Dict[Stage, str] = {
+ Stage.BOOT: "boot.d",
+ Stage.START: "start.d",
+ Stage.READY: "ready.d",
+ Stage.SHUTDOWN: "shutdown.d",
+ }
+
+ script_root: str
+ stage_completed: Dict[Stage, bool]
+
+ def __init__(self, script_root: str):
+ self.script_root = script_root
+ self.stage_completed = {stage: False for stage in Stage}
+ self.runner_manager: PluginManager[ScriptRunner] = PluginManager(ScriptRunner.namespace)
+
+ @cached_property
+ def scripts(self) -> Dict[Stage, List[Script]]:
+ return self._find_scripts()
+
+ def get_script_runner(self, script_file: str) -> Optional[ScriptRunner]:
+ runners = self.runner_manager.load_all()
+ for runner in runners:
+ if runner.should_run(script_file):
+ return runner
+ return None
+
+ def has_script_runner(self, script_file: str) -> bool:
+ return self.get_script_runner(script_file) is not None
+
+ def run_stage(self, stage: Stage) -> List[Script]:
+ """
+ Runs all scripts in the given stage.
+
+ :param stage: the stage to run
+ :return: the scripts that were in the stage
+ """
+ scripts = self.scripts.get(stage, [])
+
+ if self.stage_completed[stage]:
+ LOG.debug("Stage %s already completed, skipping", stage)
+ return scripts
+
+ try:
+ for script in scripts:
+ LOG.debug("Running %s script %s", script.stage, script.path)
+
+ env_original = os.environ.copy()
+
+ try:
+ script.state = State.RUNNING
+ runner = self.get_script_runner(script.path)
+ runner.run(script.path)
+ except Exception as e:
+ script.state = State.ERROR
+ if LOG.isEnabledFor(logging.DEBUG):
+ LOG.exception("Error while running script %s", script)
+ else:
+ LOG.error("Error while running script %s: %s", script, e)
+ else:
+ script.state = State.SUCCESSFUL
+ finally:
+ # Discard env variables overridden in startup script that may cause side-effects
+ for env_var in (
+ "AWS_ACCESS_KEY_ID",
+ "AWS_SECRET_ACCESS_KEY",
+ "AWS_SESSION_TOKEN",
+ "AWS_DEFAULT_REGION",
+ "AWS_PROFILE",
+ "AWS_REGION",
+ ):
+ if env_var in env_original:
+ os.environ[env_var] = env_original[env_var]
+ else:
+ os.environ.pop(env_var, None)
+ finally:
+ self.stage_completed[stage] = True
+
+ return scripts
+
+ def _find_scripts(self) -> Dict[Stage, List[Script]]:
+ scripts = {}
+
+ if self.script_root is None:
+ LOG.debug("Unable to discover init scripts as script_root is None")
+ return {}
+
+ for stage in Stage:
+ scripts[stage] = []
+
+ stage_dir = self._stage_directories[stage]
+ if not stage_dir:
+ continue
+
+ stage_path = os.path.join(self.script_root, stage_dir)
+ if not os.path.isdir(stage_path):
+ continue
+
+ for root, dirs, files in os.walk(stage_path, topdown=True):
+ # from the docs: "When topdown is true, the caller can modify the dirnames list in-place"
+ dirs.sort()
+ files.sort()
+ for file in files:
+ script_path = os.path.abspath(os.path.join(root, file))
+ if not os.path.isfile(script_path):
+ continue
+
+ # only add the script if there's a runner for it
+ if not self.has_script_runner(script_path):
+ LOG.debug("No runner available for script %s", script_path)
+ continue
+
+ scripts[stage].append(Script(path=script_path, stage=stage))
+ LOG.debug("Init scripts discovered: %s", scripts)
+
+ return scripts
+
+
+# runtime integration
+
+
+@singleton_factory
+def init_script_manager() -> InitScriptManager:
+ from localstack import config
+
+ return InitScriptManager(script_root=config.dirs.init)
+
+
+@hooks.on_infra_start()
+def _run_init_scripts_on_start():
+ # this is a hack since we currently cannot know whether boot scripts have been executed or not
+ init_script_manager().stage_completed[Stage.BOOT] = True
+ _run_and_log(Stage.START)
+
+
+@hooks.on_infra_ready()
+def _run_init_scripts_on_ready():
+ _run_and_log(Stage.READY)
+
+
+@hooks.on_infra_shutdown()
+def _run_init_scripts_on_shutdown():
+ _run_and_log(Stage.SHUTDOWN)
+
+
+def _run_and_log(stage: Stage):
+ from localstack.utils.analytics import log
+
+ then = time.time()
+ scripts = init_script_manager().run_stage(stage)
+ took = (time.time() - then) * 1000
+
+ if scripts:
+ log.event("run_init", {"stage": stage.name, "scripts": len(scripts), "duration": took})
+
+
+def main():
+ """
+ Run the init scripts for a particular stage. For example, to run all boot scripts run::
+
+ python -m localstack.runtime.init BOOT
+
+ The __main__ entrypoint is currently mainly used for the docker-entrypoint.sh. Other stages
+ are executed from runtime hooks.
+ """
+ import sys
+
+ stage = Stage[sys.argv[1]]
+ init_script_manager().run_stage(stage)
+
+
+if __name__ == "__main__":
+ main()
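+
+# Illustrative init-script layout (sketch): config.dirs.init typically points to /etc/localstack/init
+# in the Docker image (assumption; see the init-hooks docs linked above). Scripts are discovered per
+# stage directory and executed in lexicographical order, e.g.:
+#
+#   /etc/localstack/init
+#   ├── boot.d/
+#   ├── start.d/
+#   │   └── 01-create-resources.sh
+#   ├── ready.d/
+#   │   └── seed_data.py
+#   └── shutdown.d/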
diff --git a/localstack-core/localstack/runtime/legacy.py b/localstack-core/localstack/runtime/legacy.py
new file mode 100644
index 0000000000000..2a2f54c562929
--- /dev/null
+++ b/localstack-core/localstack/runtime/legacy.py
@@ -0,0 +1,17 @@
+"""Adapter code for the legacy runtime to make sure the new runtime is compatible with the old one,
+and at the same time doesn't need ``localstack.services.infra``, which imports AWS-specific modules."""
+
+import logging
+import os
+import signal
+
+LOG = logging.getLogger(__name__)
+
+
+def signal_supervisor_restart():
+ # TODO: we should think about moving the localstack-supervisor into a script in the runtime,
+ # and make `signal_supervisor_restart` part of the supervisor code.
+ if pid := os.environ.get("SUPERVISOR_PID"):
+ os.kill(int(pid), signal.SIGUSR1)
+ else:
+ LOG.warning("could not signal supervisor to restart localstack")
diff --git a/localstack-core/localstack/runtime/main.py b/localstack-core/localstack/runtime/main.py
new file mode 100644
index 0000000000000..3a0357e230ad0
--- /dev/null
+++ b/localstack-core/localstack/runtime/main.py
@@ -0,0 +1,93 @@
+"""This is the entrypoint used to start the localstack runtime. It starts the infrastructure and also
+manages the interaction with the operating system - mostly signal handlers for now."""
+
+import signal
+import sys
+import traceback
+
+from localstack import config, constants
+from localstack.runtime.exceptions import LocalstackExit
+
+
+def print_runtime_information(in_docker: bool = False):
+ # FIXME: this is legacy code from the old CLI, reconcile with new CLI and runtime output
+ from localstack.utils.container_networking import get_main_container_name
+ from localstack.utils.container_utils.container_client import ContainerException
+ from localstack.utils.docker_utils import DOCKER_CLIENT
+
+ print()
+ print(f"LocalStack version: {constants.VERSION}")
+ if in_docker:
+ try:
+ container_name = get_main_container_name()
+ print("LocalStack Docker container name: %s" % container_name)
+ inspect_result = DOCKER_CLIENT.inspect_container(container_name)
+ container_id = inspect_result["Id"]
+ print("LocalStack Docker container id: %s" % container_id[:12])
+ image_details = DOCKER_CLIENT.inspect_image(inspect_result["Image"])
+ digests = image_details.get("RepoDigests") or ["Unavailable"]
+ print("LocalStack Docker image sha: %s" % digests[0])
+ except ContainerException:
+ print(
+ "LocalStack Docker container info: Failed to inspect the LocalStack docker container. "
+ "This is likely because the docker socket was not mounted into the container. "
+ "Without access to the docker socket, LocalStack will not function properly. Please "
+ "consult the LocalStack documentation on how to correctly start up LocalStack. ",
+ end="",
+ )
+ if config.DEBUG:
+ print("Docker debug information:")
+ traceback.print_exc()
+ else:
+ print(
+ "You can run LocalStack with `DEBUG=1` to get more information about the error."
+ )
+
+ if config.LOCALSTACK_BUILD_DATE:
+ print("LocalStack build date: %s" % config.LOCALSTACK_BUILD_DATE)
+
+ if config.LOCALSTACK_BUILD_GIT_HASH:
+ print("LocalStack build git hash: %s" % config.LOCALSTACK_BUILD_GIT_HASH)
+
+ print()
+
+
+def main():
+ from localstack.logging.setup import setup_logging_from_config
+ from localstack.runtime import current
+
+ try:
+ setup_logging_from_config()
+ runtime = current.initialize_runtime()
+ except Exception as e:
+ sys.stdout.write(f"ERROR: The LocalStack Runtime could not be initialized: {e}\n")
+ sys.stdout.flush()
+ raise
+
+ # TODO: where should this go?
+ print_runtime_information()
+
+ # signal handler to make sure SIGTERM properly shuts down localstack
+ def _terminate_localstack(sig: int, frame):
+ sys.stdout.write(f"Localstack runtime received signal {sig}\n")
+ sys.stdout.flush()
+ runtime.exit(0)
+
+ signal.signal(signal.SIGINT, _terminate_localstack)
+ signal.signal(signal.SIGTERM, _terminate_localstack)
+
+ try:
+ runtime.run()
+ except LocalstackExit as e:
+ sys.stdout.write(f"Localstack returning with exit code {e.code}. Reason: {e}")
+ sys.exit(e.code)
+ except Exception as e:
+ sys.stdout.write(f"ERROR: the LocalStack runtime exited unexpectedly: {e}\n")
+ sys.stdout.flush()
+ raise
+
+ sys.exit(runtime.exit_code)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/localstack-core/localstack/runtime/patches.py b/localstack-core/localstack/runtime/patches.py
new file mode 100644
index 0000000000000..4772a480bfee1
--- /dev/null
+++ b/localstack-core/localstack/runtime/patches.py
@@ -0,0 +1,70 @@
+"""
+System-wide patches that should be applied.
+"""
+
+from localstack.runtime import hooks
+from localstack.utils.patch import patch
+
+
+def patch_thread_pool():
+ """
+ This patch to ThreadPoolExecutor makes the executor remove the threads it creates from the global
+ ``_threads_queues`` of ``concurrent.futures.thread``, which joins all created threads at Python exit and
+ will block interpreter shutdown if any threads are still running, even if they are daemon threads.
+ """
+
+ import concurrent.futures.thread
+
+ @patch(concurrent.futures.thread.ThreadPoolExecutor._adjust_thread_count)
+ def _adjust_thread_count(fn, self) -> None:
+ fn(self)
+
+ for t in self._threads:
+ if not t.daemon:
+ continue
+ try:
+ del concurrent.futures.thread._threads_queues[t]
+ except KeyError:
+ pass
+
+
+def patch_urllib3_connection_pool(**constructor_kwargs):
+ """
+ Override the default parameters of HTTPConnectionPool, e.g., set the pool size via maxsize=16
+ """
+ try:
+ from urllib3 import connectionpool, poolmanager
+
+ class MyHTTPSConnectionPool(connectionpool.HTTPSConnectionPool):
+ def __init__(self, *args, **kwargs):
+ kwargs.update(constructor_kwargs)
+ super(MyHTTPSConnectionPool, self).__init__(*args, **kwargs)
+
+ poolmanager.pool_classes_by_scheme["https"] = MyHTTPSConnectionPool
+
+ class MyHTTPConnectionPool(connectionpool.HTTPConnectionPool):
+ def __init__(self, *args, **kwargs):
+ kwargs.update(constructor_kwargs)
+ super(MyHTTPConnectionPool, self).__init__(*args, **kwargs)
+
+ poolmanager.pool_classes_by_scheme["http"] = MyHTTPConnectionPool
+ except Exception:
+ pass
+
+
+_applied = False
+
+
+@hooks.on_runtime_start(priority=100) # apply patches earlier than other hooks
+def apply_runtime_patches():
+ # FIXME: find a better way to apply system-wide patches
+ global _applied
+ if _applied:
+ return
+ _applied = True
+
+ from localstack.http.duplex_socket import enable_duplex_socket
+
+ patch_urllib3_connection_pool(maxsize=128)
+ patch_thread_pool()
+ enable_duplex_socket()
diff --git a/localstack-core/localstack/runtime/runtime.py b/localstack-core/localstack/runtime/runtime.py
new file mode 100644
index 0000000000000..1e5d4e6ab5b21
--- /dev/null
+++ b/localstack-core/localstack/runtime/runtime.py
@@ -0,0 +1,203 @@
+import logging
+import os
+import threading
+
+from plux import PluginManager
+
+from localstack import config, constants
+from localstack.runtime import events, hooks
+from localstack.utils import files, functions, net, sync, threads
+
+from .components import Components
+
+LOG = logging.getLogger(__name__)
+
+
+class LocalstackRuntime:
+ """
+ The localstack runtime. It has the following responsibilities:
+
+ - Manage localstack filesystem directories
+ - Execute runtime lifecycle hook plugins from ``localstack.runtime.hooks``.
+ - Manage the localstack SSL certificate
+ - Serve the gateway (It uses a ``RuntimeServer`` to serve a ``Gateway`` instance coming from the
+ ``Components`` factory.)
+ """
+
+ def __init__(self, components: Components):
+ self.components = components
+
+ # at some point, far far in the future, we should no longer access a global config object, but rather
+ # the one from the current runtime. This will allow us to truly instantiate multiple localstack
+ # runtime instances in one process, which can be useful for many different things. But there is too
+ # much global state at the moment to think about this seriously. However, this assignment here can
+ # serve as a reminder to avoid global state in general.
+ self.config = config
+
+ # TODO: move away from `localstack.runtime.events` and instantiate new `threading.Event()` here
+ # instead
+ self.starting = events.infra_starting
+ self.ready = events.infra_ready
+ self.stopping = events.infra_stopping
+ self.stopped = events.infra_stopped
+ self.exit_code = 0
+ self._lifecycle_lock = threading.RLock()
+
+ def run(self):
+ """
+ Start the main control loop of the runtime and block the thread. This will initialize the
+ filesystem, run all lifecycle hooks, initialize the gateway server, and then serve the
+ ``RuntimeServer`` until ``shutdown()`` is called.
+ """
+ # indicates to the environment that this is an "infra process" (old terminology referring to the
+ # localstack runtime). this is necessary for disabling certain hooks that may run in the context of
+ # the CLI host mode. TODO: should not be needed over time.
+ os.environ[constants.LOCALSTACK_INFRA_PROCESS] = "1"
+
+ self._init_filesystem()
+ self._on_starting()
+ self._init_gateway_server()
+
+ # since we are blocking the main thread with the runtime server, we need to run the monitor that
+ # prints the ready marker asynchronously. this is different from how the runtime was started in the
+ # past, where the server was running in a thread.
+ # TODO: ideally we pass down a `shutdown` event that can be waited on so we can cancel the thread
+ # if the runtime shuts down beforehand
+ threading.Thread(target=self._run_ready_monitor, daemon=True).start()
+
+ # run the main control loop of the server and block execution
+ try:
+ self.components.runtime_server.run()
+ finally:
+ self._on_return()
+
+ def exit(self, code: int = 0):
+ """
+ Sets the exit code and runs ``shutdown``. It does not actually call ``sys.exit``; that is left to the
+ caller.
+
+ :param code: the exit code to be set
+ """
+ self.exit_code = code
+ # we don't know yet why, but shutdown does not work on the main thread
+ threading.Thread(target=self.shutdown, name="Runtime-Shutdown").start()
+
+ def shutdown(self):
+ """
+ Initiates an orderly shutdown of the runtime by stopping the main control loop of the
+ ``RuntimeServer``. The shutdown hooks are actually called by the main control loop (in the main
+ thread) after it returns.
+ """
+ with self._lifecycle_lock:
+ if self.stopping.is_set():
+ return
+ self.stopping.set()
+
+ LOG.debug("[shutdown] Running shutdown hooks ...")
+ functions.call_safe(
+ hooks.on_runtime_shutdown.run,
+ exception_message="[shutdown] error calling shutdown hook",
+ )
+ LOG.debug("[shutdown] Shutting down runtime server ...")
+ self.components.runtime_server.shutdown()
+
+ def is_ready(self) -> bool:
+ return self.ready.is_set()
+
+ def _init_filesystem(self):
+ self._clear_tmp_directory()
+ self.config.dirs.mkdirs()
+
+ def _init_gateway_server(self):
+ from localstack.utils.ssl import create_ssl_cert, install_predefined_cert_if_available
+
+ install_predefined_cert_if_available()
+ serial_number = self.config.GATEWAY_LISTEN[0].port
+ _, cert_file_name, key_file_name = create_ssl_cert(serial_number=serial_number)
+ ssl_creds = (cert_file_name, key_file_name)
+
+ self.components.runtime_server.register(
+ self.components.gateway, self.config.GATEWAY_LISTEN, ssl_creds
+ )
+
+ def _on_starting(self):
+ self.starting.set()
+ hooks.on_runtime_start.run()
+
+ def _on_ready(self):
+ hooks.on_runtime_ready.run()
+ print(constants.READY_MARKER_OUTPUT, flush=True)
+ self.ready.set()
+
+ def _on_return(self):
+ LOG.debug("[shutdown] Cleaning up resources ...")
+ self._cleanup_resources()
+ self.stopped.set()
+ LOG.debug("[shutdown] Completed, bye!")
+
+ def _run_ready_monitor(self):
+ self._wait_for_gateway()
+ self._on_ready()
+
+ def _wait_for_gateway(self):
+ host_and_port = self.config.GATEWAY_LISTEN[0]
+
+ if not sync.poll_condition(
+ lambda: net.is_port_open(host_and_port.port), timeout=15, interval=0.3
+ ):
+ if LOG.isEnabledFor(logging.DEBUG):
+ # make another call with quiet=False to print detailed error logs
+ net.is_port_open(host_and_port.port, quiet=False)
+ raise TimeoutError(f"gave up waiting for gateway server to start on {host_and_port}")
+
+ def _clear_tmp_directory(self):
+ if self.config.CLEAR_TMP_FOLDER:
+ # try to clear temp dir on startup
+ try:
+ files.rm_rf(self.config.dirs.tmp)
+ except PermissionError as e:
+ LOG.error(
+ "unable to delete temp folder %s: %s, please delete manually or you will "
+ "keep seeing these errors.",
+ self.config.dirs.tmp,
+ e,
+ )
+
+ def _cleanup_resources(self):
+ threads.cleanup_threads_and_processes()
+ self._clear_tmp_directory()
+
+
+def create_from_environment() -> LocalstackRuntime:
+ """
+ Creates a new runtime instance from the current environment. It uses a plugin manager to resolve the
+ necessary components from the ``localstack.runtime.components`` plugin namespace to start the runtime.
+
+ :return: a new LocalstackRuntime instance
+ """
+ hooks.on_runtime_create.run()
+
+ plugin_manager = PluginManager(Components.namespace)
+ if config.RUNTIME_COMPONENTS:
+ try:
+ component = plugin_manager.load(config.RUNTIME_COMPONENTS)
+ return LocalstackRuntime(component)
+ except Exception as e:
+ raise ValueError(
+ f"Could not load runtime components from config RUNTIME_COMPONENTS={config.RUNTIME_COMPONENTS}: {e}."
+ ) from e
+ components = plugin_manager.load_all()
+
+ if not components:
+ raise ValueError(
+ f"No component plugins found in namespace {Components.namespace}. Are entry points created "
+ f"correctly?"
+ )
+
+ if len(components) > 1:
+ LOG.warning(
+ "There are more than one component plugins, using the first one which is %s",
+ components[0].name,
+ )
+
+ return LocalstackRuntime(components[0])
diff --git a/localstack-core/localstack/runtime/server/__init__.py b/localstack-core/localstack/runtime/server/__init__.py
new file mode 100644
index 0000000000000..808f22795246a
--- /dev/null
+++ b/localstack-core/localstack/runtime/server/__init__.py
@@ -0,0 +1,5 @@
+from localstack.runtime.server.core import RuntimeServer
+
+__all__ = [
+ "RuntimeServer",
+]
diff --git a/localstack-core/localstack/runtime/server/core.py b/localstack-core/localstack/runtime/server/core.py
new file mode 100644
index 0000000000000..137f276f3d496
--- /dev/null
+++ b/localstack-core/localstack/runtime/server/core.py
@@ -0,0 +1,51 @@
+from plux import Plugin
+from rolo.gateway import Gateway
+
+from localstack import config
+
+
+class RuntimeServer:
+ """
+ The main network IO loop of LocalStack. This could be twisted, hypercorn, or any other server
+ implementation.
+ """
+
+ def register(
+ self,
+ gateway: Gateway,
+ listen: list[config.HostAndPort],
+ ssl_creds: tuple[str, str] | None = None,
+ ):
+ """
+ Registers the Gateway and the port configuration into the server. Some servers like ``twisted`` or
+ ``hypercorn`` support multiple calls to ``register``, allowing you to serve several Gateways
+ through a single event loop.
+
+ :param gateway: the gateway to serve
+ :param listen: the host and port configuration
+ :param ssl_creds: ssl credentials (certificate file path, key file path)
+ """
+ raise NotImplementedError
+
+ def run(self):
+ """
+ Run the server and block the thread.
+ """
+ raise NotImplementedError
+
+ def shutdown(self):
+ """
+ Shutdown the running server.
+ """
+ raise NotImplementedError
+
+
+class RuntimeServerPlugin(Plugin):
+ """
+ Plugin that serves as a factory for specific ``RuntimeServer`` implementations.
+ """
+
+ namespace = "localstack.runtime.server"
+
+ def load(self, *args, **kwargs) -> RuntimeServer:
+ raise NotImplementedError
diff --git a/localstack-core/localstack/runtime/server/hypercorn.py b/localstack-core/localstack/runtime/server/hypercorn.py
new file mode 100644
index 0000000000000..ce15ea3d043e0
--- /dev/null
+++ b/localstack-core/localstack/runtime/server/hypercorn.py
@@ -0,0 +1,68 @@
+import asyncio
+import threading
+
+from hypercorn import Config
+from hypercorn.asyncio import serve
+from rolo.gateway import Gateway
+from rolo.gateway.asgi import AsgiGateway
+
+from localstack import config
+from localstack.logging.setup import setup_hypercorn_logger
+
+from .core import RuntimeServer
+
+
+class HypercornRuntimeServer(RuntimeServer):
+ def __init__(self):
+ self.loop = asyncio.get_event_loop()
+
+ self._close = asyncio.Event()
+ self._closed = threading.Event()
+
+ self._futures = []
+
+ def register(
+ self,
+ gateway: Gateway,
+ listen: list[config.HostAndPort],
+ ssl_creds: tuple[str, str] | None = None,
+ ):
+ hypercorn_config = Config()
+ hypercorn_config.h11_pass_raw_headers = True
+ hypercorn_config.bind = [str(host_and_port) for host_and_port in listen]
+ # hypercorn_config.use_reloader = use_reloader
+
+ setup_hypercorn_logger(hypercorn_config)
+
+ if ssl_creds:
+ cert_file_name, key_file_name = ssl_creds
+ hypercorn_config.certfile = cert_file_name
+ hypercorn_config.keyfile = key_file_name
+
+ app = AsgiGateway(gateway, event_loop=self.loop)
+
+ future = asyncio.run_coroutine_threadsafe(
+ serve(app, hypercorn_config, shutdown_trigger=self._shutdown_trigger),
+ self.loop,
+ )
+ self._futures.append(future)
+
+ def run(self):
+ self.loop.run_forever()
+
+ def shutdown(self):
+ self._close.set()
+ asyncio.run_coroutine_threadsafe(self._set_closed(), self.loop)
+ # TODO: correctly wait for all hypercorn serve coroutines to finish
+ asyncio.run_coroutine_threadsafe(self.loop.shutdown_asyncgens(), self.loop)
+ self.loop.shutdown_default_executor()
+ self.loop.stop()
+
+ async def _wait_server_stopped(self):
+ self._closed.set()
+
+ async def _set_closed(self):
+ self._close.set()
+
+ async def _shutdown_trigger(self):
+ await self._close.wait()
diff --git a/localstack-core/localstack/runtime/server/plugins.py b/localstack-core/localstack/runtime/server/plugins.py
new file mode 100644
index 0000000000000..95746e110375d
--- /dev/null
+++ b/localstack-core/localstack/runtime/server/plugins.py
@@ -0,0 +1,19 @@
+from localstack.runtime.server.core import RuntimeServer, RuntimeServerPlugin
+
+
+class TwistedRuntimeServerPlugin(RuntimeServerPlugin):
+ name = "twisted"
+
+ def load(self, *args, **kwargs) -> RuntimeServer:
+ from .twisted import TwistedRuntimeServer
+
+ return TwistedRuntimeServer()
+
+
+class HypercornRuntimeServerPlugin(RuntimeServerPlugin):
+ name = "hypercorn"
+
+ def load(self, *args, **kwargs) -> RuntimeServer:
+ from .hypercorn import HypercornRuntimeServer
+
+ return HypercornRuntimeServer()
diff --git a/localstack-core/localstack/runtime/server/twisted.py b/localstack-core/localstack/runtime/server/twisted.py
new file mode 100644
index 0000000000000..eba02ae16422c
--- /dev/null
+++ b/localstack-core/localstack/runtime/server/twisted.py
@@ -0,0 +1,57 @@
+from rolo.gateway import Gateway
+from rolo.serving.twisted import TwistedGateway
+from twisted.internet import endpoints, reactor, ssl
+
+from localstack import config
+from localstack.aws.serving.twisted import TLSMultiplexerFactory, stop_thread_pool
+from localstack.utils import patch
+
+from .core import RuntimeServer
+
+
+class TwistedRuntimeServer(RuntimeServer):
+ def __init__(self):
+ self.thread_pool = None
+
+ def register(
+ self,
+ gateway: Gateway,
+ listen: list[config.HostAndPort],
+ ssl_creds: tuple[str, str] | None = None,
+ ):
+ # setup twisted webserver Site
+ site = TwistedGateway(gateway)
+
+ # configure ssl
+ if ssl_creds:
+ cert_file_name, key_file_name = ssl_creds
+ context_factory = ssl.DefaultOpenSSLContextFactory(key_file_name, cert_file_name)
+ context_factory.getContext().use_certificate_chain_file(cert_file_name)
+ protocol_factory = TLSMultiplexerFactory(context_factory, False, site)
+ else:
+ protocol_factory = site
+
+ # add endpoint for each host/port combination
+ for host_and_port in listen:
+ if config.is_ipv6_address(host_and_port.host):
+ endpoint = endpoints.TCP6ServerEndpoint(
+ reactor, host_and_port.port, interface=host_and_port.host
+ )
+ else:
+ # TODO: interface = host?
+ endpoint = endpoints.TCP4ServerEndpoint(reactor, host_and_port.port)
+ endpoint.listen(protocol_factory)
+
+ def run(self):
+ reactor.suggestThreadPoolSize(config.GATEWAY_WORKER_COUNT)
+ self.thread_pool = reactor.getThreadPool()
+ patch.patch(self.thread_pool.stop)(stop_thread_pool)
+
+ # we don't need signal handlers, since all they do is call ``reactor`` stop, which we expect the
+ # caller to do via ``shutdown``.
+ return reactor.run(installSignalHandlers=False)
+
+ def shutdown(self):
+ if self.thread_pool:
+ self.thread_pool.stop(timeout=10)
+ reactor.stop()
diff --git a/localstack-core/localstack/runtime/shutdown.py b/localstack-core/localstack/runtime/shutdown.py
new file mode 100644
index 0000000000000..a64dab86ef930
--- /dev/null
+++ b/localstack-core/localstack/runtime/shutdown.py
@@ -0,0 +1,73 @@
+import logging
+from typing import Any, Callable
+
+from localstack.runtime import hooks
+from localstack.utils.functions import call_safe
+
+LOG = logging.getLogger(__name__)
+
+SERVICE_SHUTDOWN_PRIORITY = -10
+"""Shutdown hook priority for shutting down service plugins."""
+
+
+class ShutdownHandlers:
+ """
+ Register / unregister shutdown handlers. All registered shutdown handlers should execute as fast as possible.
+ Blocking shutdown handlers will block infra shutdown.
+ """
+
+ def __init__(self):
+ self._callbacks = []
+
+ def register(self, shutdown_handler: Callable[[], Any]) -> None:
+ """
+ Register a shutdown handler. The handler should not block or take more than a couple of seconds.
+
+ :param shutdown_handler: Callable without parameters
+ """
+ self._callbacks.append(shutdown_handler)
+
+ def unregister(self, shutdown_handler: Callable[[], Any]) -> None:
+ """
+ Unregister a handler. Idempotent operation.
+
+ :param shutdown_handler: Shutdown handler which was previously registered
+ """
+ try:
+ self._callbacks.remove(shutdown_handler)
+ except ValueError:
+ pass
+
+ def run(self) -> None:
+ """
+ Execute shutdown handlers in reverse order of registration.
+ Should only be called once, on shutdown.
+ """
+ for callback in reversed(list(self._callbacks)):
+ call_safe(callback)
+
+
+SHUTDOWN_HANDLERS = ShutdownHandlers()
+"""Shutdown handlers run with default priority in an on_infra_shutdown hook."""
+
+ON_AFTER_SERVICE_SHUTDOWN_HANDLERS = ShutdownHandlers()
+"""Shutdown handlers that are executed after all services have been shut down."""
+
+
+@hooks.on_infra_shutdown()
+def run_shutdown_handlers():
+ SHUTDOWN_HANDLERS.run()
+
+
+@hooks.on_infra_shutdown(priority=SERVICE_SHUTDOWN_PRIORITY)
+def shutdown_services():
+ # TODO: this belongs into the shutdown procedure of a `Platform` or `RuntimeContainer` class.
+ from localstack.services.plugins import SERVICE_PLUGINS
+
+ LOG.info("[shutdown] Stopping all services")
+ SERVICE_PLUGINS.stop_all_services()
+
+
+@hooks.on_infra_shutdown(priority=SERVICE_SHUTDOWN_PRIORITY - 10)
+def run_on_after_service_shutdown_handlers():
+ ON_AFTER_SERVICE_SHUTDOWN_HANDLERS.run()
diff --git a/localstack/services/sns/__init__.py b/localstack-core/localstack/services/__init__.py
similarity index 100%
rename from localstack/services/sns/__init__.py
rename to localstack-core/localstack/services/__init__.py
diff --git a/localstack/services/sqs/__init__.py b/localstack-core/localstack/services/acm/__init__.py
similarity index 100%
rename from localstack/services/sqs/__init__.py
rename to localstack-core/localstack/services/acm/__init__.py
diff --git a/localstack-core/localstack/services/acm/provider.py b/localstack-core/localstack/services/acm/provider.py
new file mode 100644
index 0000000000000..7425b88832e6b
--- /dev/null
+++ b/localstack-core/localstack/services/acm/provider.py
@@ -0,0 +1,136 @@
+from moto import settings as moto_settings
+from moto.acm import models as acm_models
+
+from localstack.aws.api import RequestContext, handler
+from localstack.aws.api.acm import (
+ AcmApi,
+ ListCertificatesRequest,
+ ListCertificatesResponse,
+ RequestCertificateRequest,
+ RequestCertificateResponse,
+)
+from localstack.services import moto
+from localstack.utils.patch import patch
+
+# reduce the validation wait time from 60 (default) to 10 seconds
+moto_settings.ACM_VALIDATION_WAIT = min(10, moto_settings.ACM_VALIDATION_WAIT)
+
+
+@patch(acm_models.CertBundle.describe)
+def describe(describe_orig, self):
+ # TODO fix! Terrible hack (for parity). Moto adds certain required fields only if status is PENDING_VALIDATION.
+ cert_status = self.status
+ self.status = "PENDING_VALIDATION"
+ try:
+ result = describe_orig(self)
+ finally:
+ self.status = cert_status
+
+ cert = result.get("Certificate", {})
+ cert["Status"] = cert_status
+ sans = cert.setdefault("SubjectAlternativeNames", [])
+ sans_summaries = cert.setdefault("SubjectAlternativeNameSummaries", sans)
+
+ # add missing attributes in ACM certs that cause Terraform to fail
+ addenda = {
+ "RenewalEligibility": "INELIGIBLE",
+ "KeyUsages": [{"Name": "DIGITAL_SIGNATURE"}, {"Name": "KEY_ENCIPHERMENT"}],
+ "ExtendedKeyUsages": [],
+ "Options": {"CertificateTransparencyLoggingPreference": "ENABLED"},
+ }
+ addenda["DomainValidationOptions"] = options = cert.get("DomainValidationOptions")
+ if not options:
+ options = addenda["DomainValidationOptions"] = [
+ {"ValidationMethod": cert.get("ValidationMethod")}
+ ]
+
+ for option in options:
+ option["DomainName"] = domain_name = option.get("DomainName") or cert.get("DomainName")
+ validation_domain = option.get("ValidationDomain") or f"test.{domain_name.lstrip('*.')}"
+ option["ValidationDomain"] = validation_domain
+ option["ValidationMethod"] = option.get("ValidationMethod") or "DNS"
+ status = option.get("ValidationStatus")
+ option["ValidationStatus"] = (
+ "SUCCESS" if (status is None or cert_status == "ISSUED") else status
+ )
+ if option["ValidationMethod"] == "EMAIL":
+ option["ValidationEmails"] = option.get("ValidationEmails") or [
+ f"admin@{self.common_name}"
+ ]
+ test_record = {
+ "Name": validation_domain,
+ "Type": "CNAME",
+ "Value": "test123",
+ }
+ option["ResourceRecord"] = option.get("ResourceRecord") or test_record
+ option["ResourceRecord"]["Name"] = option["ResourceRecord"]["Name"].replace(".*.", ".")
+
+ for key, value in addenda.items():
+ if not cert.get(key):
+ cert[key] = value
+ cert["Serial"] = str(cert.get("Serial") or "")
+
+ if cert.get("KeyAlgorithm") in ["RSA_1024", "RSA_2048"]:
+ cert["KeyAlgorithm"] = cert["KeyAlgorithm"].replace("RSA_", "RSA-")
+
+ # add subject alternative names
+ if cert["DomainName"] not in sans:
+ sans.append(cert["DomainName"])
+ if cert["DomainName"] not in sans_summaries:
+ sans_summaries.append(cert["DomainName"])
+
+ if "HasAdditionalSubjectAlternativeNames" not in cert:
+ cert["HasAdditionalSubjectAlternativeNames"] = False
+
+ if not cert.get("ExtendedKeyUsages"):
+ cert["ExtendedKeyUsages"] = [
+ {"Name": "TLS_WEB_SERVER_AUTHENTICATION", "OID": "1.3.6.1.0.1.2.3.0"},
+ {"Name": "TLS_WEB_CLIENT_AUTHENTICATION", "OID": "1.3.6.1.0.1.2.3.4"},
+ ]
+
+ # remove attributes prior to validation
+ if not cert.get("Status") == "ISSUED":
+ attrs = ["CertificateAuthorityArn", "IssuedAt", "NotAfter", "NotBefore", "Serial"]
+ for attr in attrs:
+ cert.pop(attr, None)
+ cert["KeyUsages"] = []
+ cert["ExtendedKeyUsages"] = []
+
+ return result
+
+
+class AcmProvider(AcmApi):
+ @handler("RequestCertificate", expand=False)
+ def request_certificate(
+ self,
+ context: RequestContext,
+ request: RequestCertificateRequest,
+ ) -> RequestCertificateResponse:
+ response: RequestCertificateResponse = moto.call_moto(context)
+
+ cert_arn = response["CertificateArn"]
+ backend = acm_models.acm_backends[context.account_id][context.region]
+ cert = backend._certificates[cert_arn]
+ if not hasattr(cert, "domain_validation_options"):
+ cert.domain_validation_options = request.get("DomainValidationOptions")
+
+ return response
+
+ @handler("ListCertificates", expand=False)
+ def list_certificates(
+ self,
+ context: RequestContext,
+ request: ListCertificatesRequest,
+ ) -> ListCertificatesResponse:
+ response = moto.call_moto(context)
+ summaries = response.get("CertificateSummaryList") or []
+ for summary in summaries:
+ if "KeyUsages" in summary:
+ summary["KeyUsages"] = [
+ k["Name"] if isinstance(k, dict) else k for k in summary["KeyUsages"]
+ ]
+ if "ExtendedKeyUsages" in summary:
+ summary["ExtendedKeyUsages"] = [
+ k["Name"] if isinstance(k, dict) else k for k in summary["ExtendedKeyUsages"]
+ ]
+ return response
diff --git a/localstack/utils/__init__.py b/localstack-core/localstack/services/apigateway/__init__.py
similarity index 100%
rename from localstack/utils/__init__.py
rename to localstack-core/localstack/services/apigateway/__init__.py
diff --git a/localstack-core/localstack/services/apigateway/exporter.py b/localstack-core/localstack/services/apigateway/exporter.py
new file mode 100644
index 0000000000000..42614ab4def8f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/exporter.py
@@ -0,0 +1,325 @@
+import abc
+import json
+from typing import Type
+
+from apispec import APISpec
+
+from localstack.aws.api.apigateway import ListOfModel
+from localstack.aws.connect import connect_to
+from localstack.utils.time import TIMESTAMP_FORMAT_TZ, timestamp
+
+from .helpers import OpenAPIExt
+
+# TODO:
+# - handle more extensions
+# see the list in OpenAPIExt
+# currently handled:
+# - x-amazon-apigateway-integration
+#
+
+
+class _BaseOpenApiExporter(abc.ABC):
+ VERSION = None
+
+ def __init__(self):
+ self.export_formats = {"application/json": "to_dict", "application/yaml": "to_yaml"}
+
+ def _add_models(self, spec: APISpec, models: ListOfModel, base_path: str):
+ for model in models:
+ model_def = json.loads(model["schema"])
+ self._resolve_refs(model_def, base_path)
+ spec.components.schema(
+ component_id=model["name"],
+ component=model_def,
+ )
+
+ def _resolve_refs(self, schema: dict, base_path: str):
+ if "$ref" in schema:
+ schema["$ref"] = f"{base_path}/{schema['$ref'].rsplit('/', maxsplit=1)[-1]}"
+ for value in schema.values():
+ if isinstance(value, dict):
+ self._resolve_refs(value, base_path)
+
+ @staticmethod
+ def _get_integration(method_integration: dict) -> dict:
+ fields = {
+ "type",
+ "passthroughBehavior",
+ "requestParameters",
+ "requestTemplates",
+ "httpMethod",
+ "uri",
+ }
+ integration = {k: v for k, v in method_integration.items() if k in fields}
+ integration["type"] = integration["type"].lower()
+ integration["passthroughBehavior"] = integration["passthroughBehavior"].lower()
+ if responses := method_integration.get("integrationResponses"):
+ integration["responses"] = {"default": responses.get("200")}
+ return integration
+
+ @abc.abstractmethod
+ def export(
+ self,
+ api_id: str,
+ stage: str,
+ export_format: str,
+ with_extension: bool,
+ account_id: str,
+ region_name: str,
+ ) -> str | dict: ...
+
+ @abc.abstractmethod
+ def _add_paths(self, spec: APISpec, resources: dict, with_extension: bool):
+ """
+ This method iterates over the different REST resources and their methods to add the APISpec paths using the
+ `apispec` module.
+ The path format is different between Swagger (OpenAPI 2.0) and OpenAPI 3.0
+ :param spec: an APISpec object representing the exported API Gateway REST API
+ :param resources: the API Gateway REST API resources (methods, methods integrations, responses...)
+ :param with_extension: flag to add the custom API Gateway OpenAPI extensions, allowing, for example,
+ integrations or authorizers to be properly imported (all the `x-amazon` fields contained in `OpenAPIExt`).
+ :return: None
+ """
+ ...
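+ # Note (illustrative): the Swagger 2.0 exporter below emits "parameters"/"produces"/"consumes",
+ # while the OpenAPI 3.0 exporter emits "requestBody" and per-response "content" maps.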
+
+
+class _OpenApiSwaggerExporter(_BaseOpenApiExporter):
+ VERSION = "2.0"
+
+ def _add_paths(self, spec, resources, with_extension):
+ for item in resources.get("items"):
+ path = item.get("path")
+ for method, method_config in item.get("resourceMethods", {}).items():
+ method = method.lower()
+
+ method_integration = method_config.get("methodIntegration", {})
+ integration_responses = method_integration.get("integrationResponses", {})
+ method_responses = method_config.get("methodResponses")
+ responses = {}
+ produces = set()
+ for status_code, values in method_responses.items():
+ response = {"description": f"{status_code} response"}
+ if response_parameters := values.get("responseParameters"):
+ headers = {}
+ for parameter in response_parameters:
+ in_, name = parameter.removeprefix("method.response.").split(".")
+ # TODO: other type?
+ if in_ == "header":
+ headers[name] = {"type": "string"}
+
+ if headers:
+ response["headers"] = headers
+ if response_models := values.get("responseModels"):
+ for content_type, model_name in response_models.items():
+ produces.add(content_type)
+ response["schema"] = model_name
+ if integration_response := integration_responses.get(status_code, {}):
+ produces.update(integration_response.get("responseTemplates", {}).keys())
+
+ responses[status_code] = response
+
+ request_parameters = method_config.get("requestParameters", {})
+ parameters = []
+ for parameter, required in request_parameters.items():
+ in_, name = parameter.removeprefix("method.request.").split(".")
+ in_ = in_ if in_ != "querystring" else "query"
+ parameters.append(
+ {"name": name, "in": in_, "required": required, "type": "string"}
+ )
+
+ request_models = method_config.get("requestModels", {})
+ for model_name in request_models.values():
+ parameter = {
+ "in": "body",
+ "name": model_name,
+ "required": True,
+ "schema": {"$ref": f"#/definitions/{model_name}"},
+ }
+ parameters.append(parameter)
+
+ method_operations = {"responses": responses}
+ if parameters:
+ method_operations["parameters"] = parameters
+ if produces:
+ method_operations["produces"] = list(produces)
+ if content_types := request_models | method_integration.get("requestTemplates", {}):
+ method_operations["consumes"] = list(content_types.keys())
+ if operation_name := method_config.get("operationName"):
+ method_operations["operationId"] = operation_name
+ if with_extension and method_integration:
+ method_operations[OpenAPIExt.INTEGRATION] = self._get_integration(
+ method_integration
+ )
+
+ spec.path(path=path, operations={method: method_operations})
+
+ def export(
+ self,
+ api_id: str,
+ stage: str,
+ export_format: str,
+ with_extension: bool,
+ account_id: str,
+ region_name: str,
+ ) -> str:
+ """
+ https://github.com/OAI/OpenAPI-Specification/blob/main/versions/2.0.md
+ """
+ apigateway_client = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).apigateway
+
+ rest_api = apigateway_client.get_rest_api(restApiId=api_id)
+ resources = apigateway_client.get_resources(restApiId=api_id)
+ models = apigateway_client.get_models(restApiId=api_id)
+
+ info = {}
+ if (description := rest_api.get("description")) is not None:
+ info["description"] = description
+
+ spec = APISpec(
+ title=rest_api.get("name"),
+ version=rest_api.get("version")
+ or timestamp(rest_api.get("createdDate"), format=TIMESTAMP_FORMAT_TZ),
+ info=info,
+ openapi_version=self.VERSION,
+ basePath=f"/{stage}",
+ schemes=["https"],
+ )
+
+ self._add_paths(spec, resources, with_extension)
+ self._add_models(spec, models["items"], "#/definitions")
+
+ return getattr(spec, self.export_formats.get(export_format))()
+
+
+class _OpenApiOAS30Exporter(_BaseOpenApiExporter):
+ VERSION = "3.0.1"
+
+ def _add_paths(self, spec, resources, with_extension):
+ for item in resources.get("items"):
+ path = item.get("path")
+ for method, method_config in item.get("resourceMethods", {}).items():
+ method = method.lower()
+
+ method_integration = method_config.get("methodIntegration", {})
+ integration_responses = method_integration.get("integrationResponses", {})
+ method_responses = method_config.get("methodResponses")
+ responses = {}
+ produces = set()
+ for status_code, values in method_responses.items():
+ response = {"description": f"{status_code} response"}
+ content = {}
+ if response_parameters := values.get("responseParameters"):
+ headers = {}
+ for parameter in response_parameters:
+ in_, name = parameter.removeprefix("method.response.").split(".")
+ # TODO: other type? query?
+ if in_ == "header":
+ headers[name] = {"schema": {"type": "string"}}
+
+ if headers:
+ response["headers"] = headers
+ if response_models := values.get("responseModels"):
+ for content_type, model_name in response_models.items():
+ content[content_type] = {
+ "schema": {"$ref": f"#/components/schemas/{model_name}"}
+ }
+ if integration_response := integration_responses.get(status_code, {}):
+ produces.update(integration_response.get("responseTemplates", {}).keys())
+
+ response["content"] = content
+ responses[status_code] = response
+
+ request_parameters = method_config.get("requestParameters", {})
+ parameters = []
+ for parameter, required in request_parameters.items():
+ in_, name = parameter.removeprefix("method.request.").split(".")
+ in_ = in_ if in_ != "querystring" else "query"
+ parameters.append({"name": name, "in": in_, "schema": {"type": "string"}})
+
+ request_body = {"content": {}}
+ request_models = method_config.get("requestModels", {})
+ for content_type, model_name in request_models.items():
+ request_body["content"][content_type] = {
+ "schema": {"$ref": f"#/components/schemas/{model_name}"},
+ }
+ request_body["required"] = True
+
+ method_operations = {"responses": responses}
+ if parameters:
+ method_operations["parameters"] = parameters
+ if request_body["content"]:
+ method_operations["requestBody"] = request_body
+ if operation_name := method_config.get("operationName"):
+ method_operations["operationId"] = operation_name
+ if with_extension and method_integration:
+ method_operations[OpenAPIExt.INTEGRATION] = self._get_integration(
+ method_integration
+ )
+
+ spec.path(path=path, operations={method: method_operations})
+
+ def export(
+ self,
+ api_id: str,
+ stage: str,
+ export_format: str,
+ with_extension: bool,
+ account_id: str,
+ region_name: str,
+ ) -> str:
+ """
+ https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md
+ """
+ apigateway_client = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).apigateway
+
+ rest_api = apigateway_client.get_rest_api(restApiId=api_id)
+ resources = apigateway_client.get_resources(restApiId=api_id)
+ models = apigateway_client.get_models(restApiId=api_id)
+
+ info = {}
+
+ if (description := rest_api.get("description")) is not None:
+ info["description"] = description
+
+ spec = APISpec(
+ title=rest_api.get("name"),
+ version=rest_api.get("version")
+ or timestamp(rest_api.get("createdDate"), format=TIMESTAMP_FORMAT_TZ),
+ info=info,
+ openapi_version=self.VERSION,
+ servers=[{"variables": {"basePath": {"default": stage}}}],
+ )
+
+ self._add_paths(spec, resources, with_extension)
+ self._add_models(spec, models["items"], "#/components/schemas")
+
+ response = getattr(spec, self.export_formats.get(export_format))()
+ if isinstance(response, dict) and "components" not in response:
+ response["components"] = {}
+ return response
+
+
+class OpenApiExporter:
+ exporters: dict[str, Type[_BaseOpenApiExporter]]
+
+ def __init__(self):
+ self.exporters = {"swagger": _OpenApiSwaggerExporter, "oas30": _OpenApiOAS30Exporter}
+
+ def export_api(
+ self,
+ api_id: str,
+ stage: str,
+ export_type: str,
+ account_id: str,
+ region_name: str,
+ export_format: str = "application/json",
+ with_extension=False,
+ ) -> str:
+ exporter = self.exporters.get(export_type)()
+ return exporter.export(
+ api_id, stage, export_format, with_extension, account_id, region_name
+ )
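+
+
+# Illustrative usage (hypothetical identifiers):
+#   OpenApiExporter().export_api("abc123", "dev", "oas30", "000000000000", "us-east-1", with_extension=True)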
diff --git a/localstack-core/localstack/services/apigateway/helpers.py b/localstack-core/localstack/services/apigateway/helpers.py
new file mode 100644
index 0000000000000..cde25c4bdaba2
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/helpers.py
@@ -0,0 +1,1032 @@
+import contextlib
+import copy
+import hashlib
+import json
+import logging
+from datetime import datetime
+from typing import List, Optional, TypedDict, Union
+from urllib import parse as urlparse
+
+from jsonpatch import apply_patch
+from jsonpointer import JsonPointerException
+from moto.apigateway import models as apigw_models
+from moto.apigateway.models import APIGatewayBackend, Integration, Resource
+from moto.apigateway.models import RestAPI as MotoRestAPI
+from moto.apigateway.utils import ApigwAuthorizerIdentifier, ApigwResourceIdentifier
+
+from localstack import config
+from localstack.aws.api import RequestContext
+from localstack.aws.api.apigateway import (
+ Authorizer,
+ ConnectionType,
+ DocumentationPart,
+ DocumentationPartLocation,
+ IntegrationType,
+ Model,
+ NotFoundException,
+ PutRestApiRequest,
+ RequestValidator,
+)
+from localstack.constants import (
+ APPLICATION_JSON,
+ AWS_REGION_US_EAST_1,
+ DEFAULT_AWS_ACCOUNT_ID,
+ PATH_USER_REQUEST,
+)
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.services.apigateway.models import (
+ ApiGatewayStore,
+ RestApiContainer,
+ apigateway_stores,
+)
+from localstack.utils import common
+from localstack.utils.json import parse_json_or_yaml
+from localstack.utils.strings import short_uid, to_bytes, to_str
+from localstack.utils.urls import localstack_host
+
+LOG = logging.getLogger(__name__)
+
+REQUEST_TIME_DATE_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
+
+INVOKE_TEST_LOG_TEMPLATE = """Execution log for request {request_id}
+ {formatted_date} : Starting execution for request: {request_id}
+ {formatted_date} : HTTP Method: {http_method}, Resource Path: {resource_path}
+ {formatted_date} : Method request path: {request_path}
+ {formatted_date} : Method request query string: {query_string}
+ {formatted_date} : Method request headers: {request_headers}
+ {formatted_date} : Method request body before transformations: {request_body}
+ {formatted_date} : Method response body after transformations: {response_body}
+ {formatted_date} : Method response headers: {response_headers}
+ {formatted_date} : Successfully completed execution
+ {formatted_date} : Method completed with status: {status_code}
+ """
+
+
+EMPTY_MODEL = "Empty"
+ERROR_MODEL = "Error"
+
+
+# TODO: we could actually parse the schema to get TypedDicts with the proper schema/types for each property
+class OpenAPIExt:
+ """
+ Represents the specific OpenAPI extensions for API Gateway
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-swagger-extensions.html
+ """
+
+ ANY_METHOD = "x-amazon-apigateway-any-method"
+ CORS = "x-amazon-apigateway-cors"
+ API_KEY_SOURCE = "x-amazon-apigateway-api-key-source"
+ AUTH = "x-amazon-apigateway-auth"
+ AUTHORIZER = "x-amazon-apigateway-authorizer"
+ AUTHTYPE = "x-amazon-apigateway-authtype"
+ BINARY_MEDIA_TYPES = "x-amazon-apigateway-binary-media-types"
+ DOCUMENTATION = "x-amazon-apigateway-documentation"
+ ENDPOINT_CONFIGURATION = "x-amazon-apigateway-endpoint-configuration"
+ GATEWAY_RESPONSES = "x-amazon-apigateway-gateway-responses"
+ IMPORTEXPORT_VERSION = "x-amazon-apigateway-importexport-version"
+ INTEGRATION = "x-amazon-apigateway-integration"
+ INTEGRATIONS = "x-amazon-apigateway-integrations" # used in components
+ MINIMUM_COMPRESSION_SIZE = "x-amazon-apigateway-minimum-compression-size"
+ POLICY = "x-amazon-apigateway-policy"
+ REQUEST_VALIDATOR = "x-amazon-apigateway-request-validator"
+ REQUEST_VALIDATORS = "x-amazon-apigateway-request-validators"
+ TAG_VALUE = "x-amazon-apigateway-tag-value"
+
+
+class AuthorizerConfig(TypedDict):
+ authorizer: Authorizer
+ authorization_scopes: Optional[list[str]]
+
+
+# TODO: make the CRUD operations in this file generic for the different model types (authorizes, validators, ...)
+
+
+def get_apigateway_store(context: RequestContext) -> ApiGatewayStore:
+ return apigateway_stores[context.account_id][context.region]
+
+
+def get_apigateway_store_for_invocation(context: ApiInvocationContext) -> ApiGatewayStore:
+ account_id = context.account_id or DEFAULT_AWS_ACCOUNT_ID
+ region_name = context.region_name or AWS_REGION_US_EAST_1
+ return apigateway_stores[account_id][region_name]
+
+
+def get_moto_backend(account_id: str, region: str) -> APIGatewayBackend:
+ return apigw_models.apigateway_backends[account_id][region]
+
+
+def get_moto_rest_api(context: RequestContext, rest_api_id: str) -> MotoRestAPI:
+ moto_backend = apigw_models.apigateway_backends[context.account_id][context.region]
+ if rest_api := moto_backend.apis.get(rest_api_id):
+ return rest_api
+ else:
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+
+def get_rest_api_container(context: RequestContext, rest_api_id: str) -> RestApiContainer:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+ return rest_api_container
+
+
+class OpenAPISpecificationResolver:
+ def __init__(self, document: dict, rest_api_id: str, allow_recursive=True):
+ self.document = document
+ self.allow_recursive = allow_recursive
+ # cache which maps known refs to part of the document
+ self._cache = {}
+ self._refpaths = ["#"]
+ host_definition = localstack_host()
+ self._base_url = f"{config.get_protocol()}://apigateway.{host_definition.host_and_port()}/restapis/{rest_api_id}/models/"
+
+ def _is_ref(self, item) -> bool:
+ return isinstance(item, dict) and "$ref" in item
+
+ def _is_internal_ref(self, refpath) -> bool:
+ return str(refpath).startswith("#/")
+
+ @property
+ def current_path(self):
+ return self._refpaths[-1]
+
+ @contextlib.contextmanager
+ def _pathctx(self, refpath: str):
+ if not self._is_internal_ref(refpath):
+ refpath = "/".join((self.current_path, refpath))
+
+ self._refpaths.append(refpath)
+ yield
+ self._refpaths.pop()
+
+ def _resolve_refpath(self, refpath: str) -> dict:
+ if refpath in self._refpaths and not self.allow_recursive:
+ raise Exception("recursion detected with allow_recursive=False")
+
+ # We don't resolve the Model definition; like AWS, we return an absolute reference to the model instead.
+ # When validating the schema, we will need to resolve the $ref there, because resolving every $ref in the
+ # schema here could lead to circular references in complex schemas
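+ # e.g. {"$ref": "#/definitions/Pet"} is returned as {"$ref": "<model base URL>Pet"}, where the base URL
+ # points to this API's /models endpoint (illustrative)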
+ if self.current_path.startswith("#/definitions") or self.current_path.startswith(
+ "#/components/schemas"
+ ):
+ return {"$ref": f"{self._base_url}{refpath.rsplit('/', maxsplit=1)[-1]}"}
+
+ # We should not resolve the Model either, because we need its name to set the Request/ResponseModels;
+ # resolving it would only make it harder to retrieve the Model name.
+ # We still need to verify that the ref exists
+ is_schema = self.current_path.endswith("schema")
+
+ if refpath in self._cache and not is_schema:
+ return self._cache.get(refpath)
+
+ with self._pathctx(refpath):
+ if self._is_internal_ref(self.current_path):
+ cur = self.document
+ else:
+ raise NotImplementedError("External references not yet supported.")
+
+ for step in self.current_path.split("/")[1:]:
+ cur = cur.get(step)
+
+ self._cache[self.current_path] = cur
+
+ if is_schema:
+ # If the $ref doesn't exist in our schema, return None, otherwise return the ref
+ return {"$ref": refpath} if cur else None
+
+ return cur
+
+ def _namespaced_resolution(self, namespace: str, data: Union[dict, list]) -> Union[dict, list]:
+ with self._pathctx(namespace):
+ return self._resolve_references(data)
+
+ def _resolve_references(self, data) -> Union[dict, list]:
+ if self._is_ref(data):
+ return self._resolve_refpath(data["$ref"])
+
+ if isinstance(data, dict):
+ for k, v in data.items():
+ data[k] = self._namespaced_resolution(k, v)
+ elif isinstance(data, list):
+ for i, v in enumerate(data):
+ data[i] = self._namespaced_resolution(str(i), v)
+
+ return data
+
+ def resolve_references(self) -> dict:
+ return self._resolve_references(self.document)
+
+
+class ModelResolver:
+ """
+ This class allows a Model to use recursive and circular references to other Models.
+ To be able to JSON dump Models, AWS will not resolve Models but will use their absolute $ref instead.
+ When validating, we need to resolve those references, using JSON schema tricks to allow recursion.
+ See: https://json-schema.org/understanding-json-schema/structuring.html#recursion
+
+ To allow a simpler structure, we're not directly replacing the reference with the schema, but instead creating
+ a map of all used schemas in $defs, as advised by JSON Schema:
+ See: https://json-schema.org/understanding-json-schema/structuring.html#defs
+
+ This allows us to not render every sub-schema/model, but instead keep a clean map of used schemas.
+ """
+
+ def __init__(self, rest_api_container: RestApiContainer, model_name: str):
+ self.rest_api_container = rest_api_container
+ self.model_name = model_name
+ self._deps = {}
+ self._current_resolving_name = None
+
+ @contextlib.contextmanager
+ def _resolving_ctx(self, current_resolving_name: str):
+ self._current_resolving_name = current_resolving_name
+ yield
+ self._current_resolving_name = None
+
+ def resolve_model(self, model: dict) -> dict | None:
+ resolved_model = copy.deepcopy(model)
+ model_names = set()
+
+ def _look_for_ref(sub_model):
+ for key, value in sub_model.items():
+ if key == "$ref":
+ ref_name = value.rsplit("/", maxsplit=1)[-1]
+ if ref_name == self.model_name:
+ # if we reference our main Model, use the # for recursive access
+ sub_model[key] = "#"
+ continue
+ # otherwise, this Model will be available in $defs
+ sub_model[key] = f"#/$defs/{ref_name}"
+
+ if ref_name != self._current_resolving_name:
+ # add the ref to the next ref to resolve and to $deps
+ model_names.add(ref_name)
+
+ elif isinstance(value, dict):
+ _look_for_ref(value)
+ elif isinstance(value, list):
+ for val in value:
+ if isinstance(val, dict):
+ _look_for_ref(val)
+
+ if isinstance(resolved_model, dict):
+ _look_for_ref(resolved_model)
+
+ if model_names:
+ for ref_model_name in model_names:
+ if ref_model_name in self._deps:
+ continue
+
+ def_resolved, was_resolved = self._get_resolved_submodel(model_name=ref_model_name)
+
+ if not def_resolved:
+ LOG.debug(
+ "Failed to resolve submodel %s for model %s",
+ ref_model_name,
+ self._current_resolving_name,
+ )
+ return
+ # if the ref was already resolved, we copy the result to not alter the already resolved schema
+ if was_resolved:
+ def_resolved = copy.deepcopy(def_resolved)
+
+ self._remove_self_ref(def_resolved)
+
+ if "$deps" in def_resolved:
+ # this will happen only if the schema was already resolved, otherwise the deps would be in _deps
+ # remove own definition in case of recursive / circular Models
+ def_resolved["$defs"].pop(self.model_name, None)
+ # remove the $defs from the schema, we don't want nested $defs
+ def_resolved_defs = def_resolved.pop("$defs")
+ # merge the resolved sub model $defs to the main schema
+ self._deps.update(def_resolved_defs)
+
+ # add the dependencies to the global $deps
+ self._deps[ref_model_name] = def_resolved
+
+ return resolved_model
+
+ def _remove_self_ref(self, resolved_schema: dict):
+ for key, value in resolved_schema.items():
+ if key == "$ref":
+ ref_name = value.rsplit("/", maxsplit=1)[-1]
+ if ref_name == self.model_name:
+ resolved_schema[key] = "#"
+
+ elif isinstance(value, dict):
+ self._remove_self_ref(value)
+
+ def get_resolved_model(self) -> dict | None:
+ if not (resolved_model := self.rest_api_container.resolved_models.get(self.model_name)):
+ model = self.rest_api_container.models.get(self.model_name)
+ if not model:
+ return None
+ schema = json.loads(model["schema"])
+ resolved_model = self.resolve_model(schema)
+ if not resolved_model:
+ return None
+ # attach the resolved dependencies of the schema
+ if self._deps:
+ resolved_model["$defs"] = self._deps
+ self.rest_api_container.resolved_models[self.model_name] = resolved_model
+
+ return resolved_model
+
+ def _get_resolved_submodel(self, model_name: str) -> tuple[dict | None, bool | None]:
+ was_resolved = True
+ if not (resolved_model := self.rest_api_container.resolved_models.get(model_name)):
+ was_resolved = False
+ model = self.rest_api_container.models.get(model_name)
+ if not model:
+ LOG.warning(
+ "Error while validating the request body, could not the find the Model: '%s'",
+ model_name,
+ )
+ return None, was_resolved
+ schema = json.loads(model["schema"])
+
+ with self._resolving_ctx(model_name):
+ resolved_model = self.resolve_model(schema)
+
+ return resolved_model, was_resolved
+
+
+def resolve_references(data: dict, rest_api_id, allow_recursive=True) -> dict:
+ resolver = OpenAPISpecificationResolver(
+ data, allow_recursive=allow_recursive, rest_api_id=rest_api_id
+ )
+ return resolver.resolve_references()
+
+
+# ---------------
+# UTIL FUNCTIONS
+# ---------------
+
+
+def path_based_url(api_id: str, stage_name: str, path: str) -> str:
+ """Return URL for inbound API gateway for given API ID, stage name, and path"""
+ pattern = "%s/restapis/{api_id}/{stage_name}/%s{path}" % (
+ config.external_service_url(),
+ PATH_USER_REQUEST,
+ )
+ return pattern.format(api_id=api_id, stage_name=stage_name, path=path)
+
+
+def localstack_path_based_url(api_id: str, stage_name: str, path: str) -> str:
+ """Return URL for inbound API gateway for given API ID, stage name, and path on the _aws namespace"""
+ return f"{config.external_service_url()}/_aws/execute-api/{api_id}/{stage_name}{path}"
+
+
+def host_based_url(rest_api_id: str, path: str, stage_name: str = None):
+ """Return URL for inbound API gateway for given API ID, stage name, and path with custom dns
+ format"""
+ pattern = "{endpoint}{stage}{path}"
+ stage = stage_name and f"/{stage_name}" or ""
+ return pattern.format(endpoint=get_execute_api_endpoint(rest_api_id), stage=stage, path=path)
+
+
+def get_execute_api_endpoint(api_id: str, protocol: str | None = None) -> str:
+ host = localstack_host()
+ protocol = protocol or config.get_protocol()
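+ # e.g. "http://abc123.execute-api.localhost.localstack.cloud:4566" with default host settings (illustrative)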
+ return f"{protocol}://{api_id}.execute-api.{host.host_and_port()}"
+
+
+def apply_json_patch_safe(subject, patch_operations, in_place=True, return_list=False):
+ """Apply JSONPatch operations, using some customizations for compatibility with API GW
+ resources."""
+
+ results = []
+ patch_operations = (
+ [patch_operations] if isinstance(patch_operations, dict) else patch_operations
+ )
+ for operation in patch_operations:
+ try:
+ # special case: for "replace" operations, assume "" as the default value
+ if operation["op"] == "replace" and operation.get("value") is None:
+ operation["value"] = ""
+
+ if operation["op"] != "remove" and operation.get("value") is None:
+ LOG.info('Missing "value" in JSONPatch operation for %s: %s', subject, operation)
+ continue
+
+ if operation["op"] == "add":
+ path = operation["path"]
+ target = subject.get(path.strip("/"))
+ target = target or common.extract_from_jsonpointer_path(subject, path)
+ if not isinstance(target, list):
+ # for `add` operation, if the target does not exist, set it to an empty dict (default behaviour)
+ # previous behaviour was an empty list. Revisit this if issues arise.
+ # TODO: we are assigning a value, even if not `in_place=True`
+ common.assign_to_path(subject, path, value={}, delimiter="/")
+
+ target = common.extract_from_jsonpointer_path(subject, path)
+ if isinstance(target, list) and not path.endswith("/-"):
+ # if "path" is an attribute name pointing to an array in "subject", and we're running
+ # an "add" operation, then we should use the standard-compliant notation "/path/-"
+ operation["path"] = f"{path}/-"
+
+ if operation["op"] == "remove":
+ path = operation["path"]
+ common.assign_to_path(subject, path, value={}, delimiter="/")
+
+ result = apply_patch(subject, [operation], in_place=in_place)
+ if not in_place:
+ subject = result
+ results.append(result)
+ except JsonPointerException:
+ pass # path cannot be found - ignore
+ except Exception as e:
+ if "non-existent object" in str(e):
+ if operation["op"] == "replace":
+ # fall back to an ADD operation if the REPLACE fails
+ operation["op"] = "add"
+ result = apply_patch(subject, [operation], in_place=in_place)
+ results.append(result)
+ continue
+ if operation["op"] == "remove" and isinstance(subject, dict):
+ result = subject.pop(operation["path"], None)
+ results.append(result)
+ continue
+ raise
+ if return_list:
+ return results
+ return (results or [subject])[-1]
+
+
+def add_documentation_parts(rest_api_container, documentation):
+ for doc_part in documentation.get("documentationParts", []):
+ entity_id = short_uid()[:6]
+ location = doc_part["location"]
+ rest_api_container.documentation_parts[entity_id] = DocumentationPart(
+ id=entity_id,
+ location=DocumentationPartLocation(
+ type=location.get("type"),
+ path=location.get("path", "/")
+ if location.get("type") not in ["API", "MODEL"]
+ else None,
+ method=location.get("method"),
+ statusCode=location.get("statusCode"),
+ name=location.get("name"),
+ ),
+ properties=doc_part["properties"],
+ )
+
+
+def import_api_from_openapi_spec(
+ rest_api: MotoRestAPI, context: RequestContext, request: PutRestApiRequest
+) -> tuple[MotoRestAPI, list[str]]:
+ """Import an API from an OpenAPI spec document"""
+ body = parse_json_or_yaml(to_str(request["body"].read()))
+
+ warnings = []
+
+ # TODO There is an issue with the botocore specs, so the parameters don't get populated as they should
+ # Once this is fixed we can uncomment the code below instead of taking the parameters from the context request
+ # query_params = request.get("parameters") or {}
+ query_params: dict = context.request.values.to_dict()
+
+ resolved_schema = resolve_references(copy.deepcopy(body), rest_api_id=rest_api.id)
+ account_id = context.account_id
+ region_name = context.region
+
+ # TODO:
+ # 1. validate the "mode" property of the spec document, "merge" or "overwrite"
+ # 2. validate the document type, "swagger" or "openapi"
+
+ rest_api.version = (
+ str(version) if (version := resolved_schema.get("info", {}).get("version")) else None
+ )
+ # XXX for some reason this makes cf tests fail that's why is commented.
+ # test_cfn_handle_serverless_api_resource
+ # rest_api.name = resolved_schema.get("info", {}).get("title")
+ rest_api.description = resolved_schema.get("info", {}).get("description")
+
+ # authorizers map to avoid duplication
+ authorizers = {}
+
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis[rest_api.id]
+
+ def is_api_key_required(path_payload: dict) -> bool:
+ # TODO: consolidate and refactor with `create_authorizer`, duplicate logic for now
+ if not (security_schemes := path_payload.get("security")):
+ return False
+
+ for security_scheme in security_schemes:
+ for security_scheme_name in security_scheme.keys():
+ # $.securityDefinitions is Swagger 2.0
+ # $.components.SecuritySchemes is OpenAPI 3.0
+ security_definitions = resolved_schema.get(
+ "securityDefinitions"
+ ) or resolved_schema.get("components", {}).get("securitySchemes", {})
+ if security_scheme_name in security_definitions:
+ security_config = security_definitions.get(security_scheme_name)
+ if (
+ OpenAPIExt.AUTHORIZER not in security_config
+ and security_config.get("type") == "apiKey"
+ and security_config.get("name", "").lower() == "x-api-key"
+ ):
+ return True
+ return False
+
+ def create_authorizers(security_schemes: dict) -> None:
+ for security_scheme_name, security_config in security_schemes.items():
+ aws_apigateway_authorizer = security_config.get(OpenAPIExt.AUTHORIZER, {})
+ if not aws_apigateway_authorizer:
+ continue
+
+ if security_scheme_name in authorizers:
+ continue
+
+ authorizer_type = aws_apigateway_authorizer.get("type", "").upper()
+ # TODO: do we need validation of resources here?
+ authorizer = Authorizer(
+ id=ApigwAuthorizerIdentifier(
+ account_id, region_name, security_scheme_name
+ ).generate(),
+ name=security_scheme_name,
+ type=authorizer_type,
+ authorizerResultTtlInSeconds=aws_apigateway_authorizer.get(
+ "authorizerResultTtlInSeconds", None
+ ),
+ )
+ if provider_arns := aws_apigateway_authorizer.get("providerARNs"):
+ authorizer["providerARNs"] = provider_arns
+ if auth_type := security_config.get(OpenAPIExt.AUTHTYPE):
+ authorizer["authType"] = auth_type
+ if authorizer_uri := aws_apigateway_authorizer.get("authorizerUri"):
+ authorizer["authorizerUri"] = authorizer_uri
+ if authorizer_credentials := aws_apigateway_authorizer.get("authorizerCredentials"):
+ authorizer["authorizerCredentials"] = authorizer_credentials
+ if authorizer_type in ("TOKEN", "COGNITO_USER_POOLS"):
+ header_name = security_config.get("name")
+ authorizer["identitySource"] = f"method.request.header.{header_name}"
+ elif identity_source := aws_apigateway_authorizer.get("identitySource"):
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-swagger-extensions-authorizer.html
+ # Applicable for the authorizer of the request and jwt type only
+ authorizer["identitySource"] = identity_source
+ if identity_validation_expression := aws_apigateway_authorizer.get(
+ "identityValidationExpression"
+ ):
+ authorizer["identityValidationExpression"] = identity_validation_expression
+
+ rest_api_container.authorizers[authorizer["id"]] = authorizer
+
+ authorizers[security_scheme_name] = authorizer
+
+ def get_authorizer(path_payload: dict) -> Optional[AuthorizerConfig]:
+ if not (security_schemes := path_payload.get("security")):
+ return None
+
+ for security_scheme in security_schemes:
+ for security_scheme_name, scopes in security_scheme.items():
+ if authorizer := authorizers.get(security_scheme_name):
+ return AuthorizerConfig(authorizer=authorizer, authorization_scopes=scopes)
+
+ def get_or_create_path(abs_path: str, base_path: str):
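+ # recursively ensure that all parent resources exist before creating (or reusing) the resource for the
+ # last path part, e.g. "/pets/{petId}" first resolves "/pets", then "{petId}" (illustrative)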
+ parts = abs_path.rstrip("/").replace("//", "/").split("/")
+ parent_id = ""
+ if len(parts) > 1:
+ parent_path = "/".join(parts[:-1])
+ parent = get_or_create_path(parent_path, base_path=base_path)
+ parent_id = parent.id
+ if existing := [
+ r
+ for r in rest_api.resources.values()
+ if r.path_part == (parts[-1] or "/") and (r.parent_id or "") == (parent_id or "")
+ ]:
+ return existing[0]
+
+ # construct relative path (without base path), then add field resources for this path
+ rel_path = abs_path.removeprefix(base_path)
+ return add_path_methods(rel_path, parts, parent_id=parent_id)
+
+ def add_path_methods(rel_path: str, parts: List[str], parent_id=""):
+ rel_path = rel_path or "/"
+ child_id = ApigwResourceIdentifier(account_id, region_name, parent_id, rel_path).generate()
+
+ # Create a `Resource` for the passed `rel_path`
+ resource = Resource(
+ account_id=rest_api.account_id,
+ resource_id=child_id,
+ region_name=rest_api.region_name,
+ api_id=rest_api.id,
+ path_part=parts[-1] or "/",
+ parent_id=parent_id,
+ )
+
+ paths_dict = resolved_schema["paths"]
+ method_paths = paths_dict.get(rel_path, {})
+ # Iterate over each field of the `path` to try to find the methods defined
+ for field, field_schema in method_paths.items():
+ if field in [
+ "parameters",
+ "servers",
+ "description",
+ "summary",
+ "$ref",
+ ] or not isinstance(field_schema, dict):
+ LOG.warning("Ignoring unsupported field %s in path %s", field, rel_path)
+ # TODO: check if we should skip parameters; those are global parameters applied to every route but
+ # can be overridden at the operation level
+ continue
+
+ method_name = field.upper()
+ if method_name == OpenAPIExt.ANY_METHOD.upper():
+ method_name = "ANY"
+
+ # Create the `Method` resource for each method path
+ method_resource = create_method_resource(resource, method_name, field_schema)
+
+ # Get the `Method` requestParameters and requestModels
+ request_parameters_schema = field_schema.get("parameters", [])
+ request_parameters = {}
+ request_models = {}
+ if request_parameters_schema:
+ for req_param_data in request_parameters_schema:
+ # For Swagger 2.0, possible values for `in` from the specs are "query", "header", "path",
+ # "formData" or "body".
+ # For OpenAPI 3.0, values are "query", "header", "path" or "cookie".
+ # Only "path", "header" and "query" are supported in API Gateway for requestParameters
+ # "body" is mapped to a requestModel
+ param_location = req_param_data.get("in")
+ param_name = req_param_data.get("name")
+ param_required = req_param_data.get("required", False)
+ if param_location in ("query", "header", "path"):
+ if param_location == "query":
+ param_location = "querystring"
+
+ request_parameters[f"method.request.{param_location}.{param_name}"] = (
+ param_required
+ )
+
+ elif param_location == "body":
+ request_models = {APPLICATION_JSON: param_name}
+
+ else:
+ LOG.warning(
+ "Ignoring unsupported requestParameters/requestModels location value for %s: %s",
+ param_name,
+ param_location,
+ )
+ continue
+
+ # this replaces 'body' in Parameters for OpenAPI 3.0, a requestBody Object
+ # https://swagger.io/specification/v3/#request-body-object
+ if request_models_schema := field_schema.get("requestBody"):
+ model_ref = None
+ for content_type, media_type in request_models_schema.get("content", {}).items():
+ # we're iterating over the Media Type object:
+ # https://swagger.io/specification/v3/#media-type-object
+ if content_type == APPLICATION_JSON:
+ model_ref = media_type.get("schema", {}).get("$ref")
+ continue
+ LOG.warning(
+ "Found '%s' content-type for the MethodResponse model for path '%s' and method '%s', not adding the model as currently not supported",
+ content_type,
+ rel_path,
+ method_name,
+ )
+ if model_ref:
+ model_schema = model_ref.rsplit("/", maxsplit=1)[-1]
+ request_models = {APPLICATION_JSON: model_schema}
+
+ method_resource.request_models = request_models or None
+
+ # check if there's a request validator set in the method
+ request_validator_name = field_schema.get(
+ OpenAPIExt.REQUEST_VALIDATOR, default_req_validator_name
+ )
+ if request_validator_name:
+ if not (
+ req_validator_id := request_validator_name_id_map.get(request_validator_name)
+ ):
+ # Might raise an exception here if we properly validate the template
+ LOG.warning(
+ "A validator ('%s') was referenced for %s.(%s), but is not defined",
+ request_validator_name,
+ rel_path,
+ method_name,
+ )
+ method_resource.request_validator_id = req_validator_id
+
+ # we check if there's a path parameter, AWS adds the requestParameter automatically
+ resource_path_part = parts[-1].strip("/")
+ if is_variable_path(resource_path_part) and not is_greedy_path(resource_path_part):
+ path_parameter = resource_path_part[1:-1] # remove the curly braces
+ request_parameters[f"method.request.path.{path_parameter}"] = True
+
+ method_resource.request_parameters = request_parameters or None
+
+ # Create the `MethodResponse` for the previously created `Method`
+ method_responses = field_schema.get("responses", {})
+ for method_status_code, method_response in method_responses.items():
+ method_status_code = str(method_status_code)
+ method_response_model = None
+ model_ref = None
+ # separating the two different versions, Swagger (2.0) and OpenAPI 3.0
+ if "schema" in method_response: # this is Swagger
+ model_ref = method_response["schema"].get("$ref")
+ elif "content" in method_response: # this is OpenAPI 3.0
+ for content_type, media_type in method_response["content"].items():
+ # we're iterating over the Media Type object:
+ # https://swagger.io/specification/v3/#media-type-object
+ if content_type == APPLICATION_JSON:
+ model_ref = media_type.get("schema", {}).get("$ref")
+ continue
+ LOG.warning(
+ "Found '%s' content-type for the MethodResponse model for path '%s' and method '', not adding the model as currently not supported",
+ content_type,
+ rel_path,
+ method_name,
+ )
+
+ if model_ref:
+ model_schema = model_ref.rsplit("/", maxsplit=1)[-1]
+
+ method_response_model = {APPLICATION_JSON: model_schema}
+
+ method_response_parameters = {}
+ if response_param_headers := method_response.get("headers"):
+ for header, header_info in response_param_headers.items():
+ # TODO: make use of `header_info`
+ method_response_parameters[f"method.response.header.{header}"] = False
+
+ method_resource.create_response(
+ method_status_code,
+ method_response_model,
+ method_response_parameters or None,
+ )
+
+ # Create the `Integration` for the previously created `Method`
+ method_integration = field_schema.get(OpenAPIExt.INTEGRATION, {})
+
+ integration_type = (
+ i_type.upper() if (i_type := method_integration.get("type")) else None
+ )
+
+ match integration_type:
+ case "AWS_PROXY":
+ # if the integration is AWS_PROXY with lambda, the only accepted integration method is POST
+ integration_method = "POST"
+ case _:
+ integration_method = (
+ method_integration.get("httpMethod") or method_name
+ ).upper()
+
+ connection_type = (
+ ConnectionType.INTERNET
+ if integration_type in (IntegrationType.HTTP, IntegrationType.HTTP_PROXY)
+ else None
+ )
+
+ if integration_request_parameters := method_integration.get("requestParameters"):
+ validated_parameters = {}
+ for k, v in integration_request_parameters.items():
+ if isinstance(v, str):
+ validated_parameters[k] = v
+ else:
+ # TODO This fixes boolean serialization. We should validate how other types behave
+ value = str(v).lower()
+ warnings.append(
+ "Invalid format for 'requestParameters'. Expected type string for property "
+ f"'{k}' of resource '{resource.get_path()}' and method '{method_name}' but got '{value}'"
+ )
+
+ integration_request_parameters = validated_parameters
+
+ integration = Integration(
+ http_method=integration_method,
+ uri=method_integration.get("uri"),
+ integration_type=integration_type,
+ passthrough_behavior=method_integration.get(
+ "passthroughBehavior", "WHEN_NO_MATCH"
+ ).upper(),
+ request_templates=method_integration.get("requestTemplates"),
+ request_parameters=integration_request_parameters,
+ cache_namespace=resource.id,
+ timeout_in_millis=method_integration.get("timeoutInMillis") or "29000",
+ content_handling=method_integration.get("contentHandling"),
+ connection_type=connection_type,
+ )
+
+ # Create the `IntegrationResponse` for the previously created `Integration`
+ if method_integration_responses := method_integration.get("responses"):
+ for pattern, integration_responses in method_integration_responses.items():
+ integration_response_templates = integration_responses.get("responseTemplates")
+ integration_response_parameters = integration_responses.get(
+ "responseParameters"
+ )
+
+ integration_response = integration.create_integration_response(
+ status_code=str(integration_responses.get("statusCode", 200)),
+ selection_pattern=pattern if pattern != "default" else None,
+ response_templates=integration_response_templates,
+ response_parameters=integration_response_parameters,
+ content_handling=None,
+ )
+ # moto set the responseTemplates to an empty dict when it should be None if not defined
+ if integration_response_templates is None:
+ integration_response.response_templates = None
+
+ resource.resource_methods[method_name].method_integration = integration
+
+ rest_api.resources[child_id] = resource
+ rest_api_container.resource_children.setdefault(parent_id, []).append(child_id)
+ return resource
+
+ def create_method_resource(child, method, method_schema):
+ authorization_type = "NONE"
+ api_key_required = is_api_key_required(method_schema)
+ kwargs = {}
+
+ if authorizer := get_authorizer(method_schema) or default_authorizer:
+ method_authorizer = authorizer["authorizer"]
+ # override the authorizer_type if it's a TOKEN or REQUEST to CUSTOM
+ if (authorizer_type := method_authorizer["type"]) in ("TOKEN", "REQUEST"):
+ authorization_type = "CUSTOM"
+ else:
+ authorization_type = authorizer_type
+
+ kwargs["authorizer_id"] = method_authorizer["id"]
+
+ if authorization_scopes := authorizer.get("authorization_scopes"):
+ kwargs["authorization_scopes"] = authorization_scopes
+
+ return child.add_method(
+ method,
+ api_key_required=api_key_required,
+ authorization_type=authorization_type,
+ operation_name=method_schema.get("operationId"),
+ **kwargs,
+ )
+
+ models = resolved_schema.get("definitions") or resolved_schema.get("components", {}).get(
+ "schemas", {}
+ )
+ for name, model_data in models.items():
+ model_id = short_uid()[:6] # length 6 to make TF tests pass
+ model = Model(
+ id=model_id,
+ name=name,
+ contentType=APPLICATION_JSON,
+ description=model_data.get("description"),
+ schema=json.dumps(model_data),
+ )
+ store.rest_apis[rest_api.id].models[name] = model
+
+ # create the RequestValidators defined at the top-level field `x-amazon-apigateway-request-validators`
+ request_validators = resolved_schema.get(OpenAPIExt.REQUEST_VALIDATORS, {})
+ request_validator_name_id_map = {}
+ for validator_name, validator_schema in request_validators.items():
+ validator_id = short_uid()[:6]
+
+ validator = RequestValidator(
+ id=validator_id,
+ name=validator_name,
+ validateRequestBody=validator_schema.get("validateRequestBody") or False,
+ validateRequestParameters=validator_schema.get("validateRequestParameters") or False,
+ )
+
+ store.rest_apis[rest_api.id].validators[validator_id] = validator
+ request_validator_name_id_map[validator_name] = validator_id
+
+ # get default requestValidator if present
+ default_req_validator_name = resolved_schema.get(OpenAPIExt.REQUEST_VALIDATOR)
+
+ # $.securityDefinitions is Swagger 2.0
+ # $.components.SecuritySchemes is OpenAPI 3.0
+ security_data = resolved_schema.get("securityDefinitions") or resolved_schema.get(
+ "components", {}
+ ).get("securitySchemes", {})
+ # create the defined authorizers, even if they're not used by any routes
+ if security_data:
+ create_authorizers(security_data)
+
+ # create default authorizer if present
+ default_authorizer = get_authorizer(resolved_schema)
+
+ # determine base path
+ # default basepath mode is "ignore"
+ # see https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-import-api-basePath.html
+ basepath_mode = query_params.get("basepath") or "ignore"
+ base_path = ""
+
+ if basepath_mode != "ignore":
+ # in Swagger 2.0, the basePath is a top-level property
+ if "basePath" in resolved_schema:
+ base_path = resolved_schema["basePath"]
+
+ # in OpenAPI 3.0, the basePath is contained in the server object
+ elif "servers" in resolved_schema:
+ servers_property = resolved_schema.get("servers", [])
+ for server in servers_property:
+ # first, we check if there is a basePath variable (1st choice)
+ if "basePath" in server.get("variables", {}):
+ base_path = server["variables"]["basePath"].get("default", "")
+ break
+ # TODO: this allows both absolute and relative parts, but AWS might not handle relative ones
+ url_path = urlparse.urlparse(server.get("url", "")).path
+ if url_path:
+ base_path = url_path if url_path != "/" else ""
+ break
+
+ if basepath_mode == "split":
+ base_path = base_path.strip("/").partition("/")[-1]
+ base_path = f"/{base_path}" if base_path else ""
+
+ api_paths = resolved_schema.get("paths", {})
+ if api_paths:
+ # Remove default root, then add paths from API spec
+ # TODO: the default mode is now `merge`, not `overwrite` if using `PutRestApi`
+ # TODO: quick hack for now, but do not remove the rootResource if the OpenAPI file is empty
+ rest_api.resources = {}
+
+ for path in api_paths:
+ get_or_create_path(base_path + path, base_path=base_path)
+
+ # binary types
+ rest_api.binaryMediaTypes = resolved_schema.get(OpenAPIExt.BINARY_MEDIA_TYPES, [])
+
+ policy = resolved_schema.get(OpenAPIExt.POLICY)
+ if policy:
+ policy = json.dumps(policy) if isinstance(policy, dict) else str(policy)
+ rest_api.policy = policy
+ minimum_compression_size = resolved_schema.get(OpenAPIExt.MINIMUM_COMPRESSION_SIZE)
+ if minimum_compression_size is not None:
+ rest_api.minimum_compression_size = int(minimum_compression_size)
+ endpoint_config = resolved_schema.get(OpenAPIExt.ENDPOINT_CONFIGURATION)
+ if endpoint_config:
+ if endpoint_config.get("vpcEndpointIds"):
+ endpoint_config.setdefault("types", ["PRIVATE"])
+ rest_api.endpoint_configuration = endpoint_config
+
+ api_key_source = resolved_schema.get(OpenAPIExt.API_KEY_SOURCE)
+ if api_key_source is not None:
+ rest_api.api_key_source = api_key_source.upper()
+
+ documentation = resolved_schema.get(OpenAPIExt.DOCUMENTATION)
+ if documentation:
+ add_documentation_parts(rest_api_container, documentation)
+
+ return rest_api, warnings
+
+
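+# e.g. is_greedy_path("{proxy+}") -> True; is_variable_path("{petId}") -> True (illustrative)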
+def is_greedy_path(path_part: str) -> bool:
+ return path_part.startswith("{") and path_part.endswith("+}")
+
+
+def is_variable_path(path_part: str) -> bool:
+ return path_part.startswith("{") and path_part.endswith("}")
+
+
+def log_template(
+ request_id: str,
+ date: datetime,
+ http_method: str,
+ resource_path: str,
+ request_path: str,
+ query_string: str,
+ request_headers: str,
+ request_body: str,
+ response_body: str,
+ response_headers: str,
+ status_code: str,
+):
+ formatted_date = date.strftime("%a %b %d %H:%M:%S %Z %Y")
+ return INVOKE_TEST_LOG_TEMPLATE.format(
+ request_id=request_id,
+ formatted_date=formatted_date,
+ http_method=http_method,
+ resource_path=resource_path,
+ request_path=request_path,
+ query_string=query_string,
+ request_headers=request_headers,
+ request_body=request_body,
+ response_body=response_body,
+ response_headers=response_headers,
+ status_code=status_code,
+ )
+
+
+def get_domain_name_hash(domain_name: str) -> str:
+ """
+ Return a hash of the given domain name, which helps construct regional domain names for APIs.
+ TODO: use this in the future to dispatch API Gateway API invocations made to the regional domain name
+ """
+ return hashlib.shake_128(to_bytes(domain_name)).hexdigest(4)
+
+
+def get_regional_domain_name(domain_name: str) -> str:
+ """
+ Return the regional domain name for the given domain name.
+ In real AWS, this would look something like: "d-oplm2qchq0.execute-api.us-east-1.amazonaws.com"
+ In LocalStack, we're returning this format: "d-.execute-api.localhost.localstack.cloud"
+ """
+ domain_name_hash = get_domain_name_hash(domain_name)
+ host = localstack_host().host
+ return f"d-{domain_name_hash}.execute-api.{host}"
diff --git a/localstack/utils/analytics/__init__.py b/localstack-core/localstack/services/apigateway/legacy/__init__.py
similarity index 100%
rename from localstack/utils/analytics/__init__.py
rename to localstack-core/localstack/services/apigateway/legacy/__init__.py
diff --git a/localstack-core/localstack/services/apigateway/legacy/context.py b/localstack-core/localstack/services/apigateway/legacy/context.py
new file mode 100644
index 0000000000000..37b9725f3feb8
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/context.py
@@ -0,0 +1,201 @@
+import base64
+import json
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from responses import Response
+
+from localstack.constants import HEADER_LOCALSTACK_EDGE_URL
+from localstack.utils.aws.aws_responses import parse_query_string
+from localstack.utils.strings import short_uid, to_str
+
+# type definition for data parameters (i.e., invocation payloads)
+InvocationPayload = Union[Dict, str, bytes]
+
+
+class ApiGatewayVersion(Enum):
+ V1 = "v1"
+ V2 = "v2"
+
+
+class ApiInvocationContext:
+ """Represents the context for an incoming API Gateway invocation."""
+
+ # basic (raw) HTTP invocation details (method, path, data, headers)
+ method: str
+ path: str
+ data: InvocationPayload
+ headers: Dict[str, str]
+
+ # raw URI (including query string) retrieved from the werkzeug "RAW_URI" environ variable
+ raw_uri: str
+
+ # invocation context
+ context: Dict[str, Any]
+ # authentication info for this invocation
+ auth_context: Dict[str, Any]
+
+ # target API/resource details extracted from the invocation
+ apigw_version: ApiGatewayVersion
+ api_id: str
+ stage: str
+ account_id: str
+ region_name: str
+ # resource path, including any path parameter placeholders (e.g., "/my/path/{id}")
+ resource_path: str
+ integration: Dict
+ resource: Dict
+ # Invocation path with query string, e.g., "/my/path?test". Defaults to "path", can be used
+ # to overwrite the actual API path, in case the path format "../_user_request_/.." is used.
+ _path_with_query_string: str
+
+ # response templates to be applied to the invocation result
+ response_templates: Dict
+
+ route: Dict
+ connection_id: str
+ path_params: Dict
+
+ # response object
+ response: Response
+
+ # dict of stage variables (mapping names to values)
+ stage_variables: Dict[str, str]
+
+ # websockets route selection
+ ws_route: str
+
+ def __init__(
+ self,
+ method: str,
+ path: str,
+ data: Union[str, bytes],
+ headers: Dict[str, str],
+ api_id: str = None,
+ stage: str = None,
+ context: Dict[str, Any] = None,
+ auth_context: Dict[str, Any] = None,
+ ):
+ self.method = method
+ self._path = path
+ self.data = data
+ self.headers = headers
+ self.context = {"requestId": short_uid()} if context is None else context
+ self.auth_context = {} if auth_context is None else auth_context
+ self.apigw_version = None
+ self.api_id = api_id
+ self.stage = stage
+ self.region_name = None
+ self.account_id = None
+ self.integration = None
+ self.resource = None
+ self.resource_path = None
+ self.path_with_query_string = None
+ self.response_templates = {}
+ self.stage_variables = {}
+ self.path_params = {}
+ self.route = None
+ self.ws_route = None
+ self.response = None
+
+ @property
+ def path(self) -> str:
+ return self._path
+
+ @path.setter
+ def path(self, new_path: str):
+ if isinstance(new_path, str):
+ new_path = "/" + new_path.lstrip("/")
+ self._path = new_path
+
+ @property
+ def resource_id(self) -> Optional[str]:
+ return (self.resource or {}).get("id")
+
+ @property
+ def invocation_path(self) -> str:
+ """Return the plain invocation path, without query parameters."""
+ path = self.path_with_query_string or self.path
+ return path.split("?")[0]
+
+ @property
+ def path_with_query_string(self) -> str:
+ """Return invocation path with query string - defaults to the value of 'path', unless customized."""
+ return self._path_with_query_string or self.path
+
+ @path_with_query_string.setter
+ def path_with_query_string(self, new_path: str):
+ """Set a custom invocation path with query string (used to handle "../_user_request_/.." paths)."""
+ if isinstance(new_path, str):
+ new_path = "/" + new_path.lstrip("/")
+ self._path_with_query_string = new_path
+
+ def query_params(self) -> Dict[str, str]:
+ """Extract the query parameters from the target URL or path in this request context."""
+ query_string = self.path_with_query_string.partition("?")[2]
+ return parse_query_string(query_string)
+
+ @property
+ def integration_uri(self) -> Optional[str]:
+ integration = self.integration or {}
+ return integration.get("uri") or integration.get("integrationUri")
+
+ @property
+ def auth_identity(self) -> Optional[Dict]:
+ if isinstance(self.auth_context, dict):
+ if self.auth_context.get("identity") is None:
+ self.auth_context["identity"] = {}
+ return self.auth_context["identity"]
+
+ @property
+ def authorizer_type(self) -> str:
+ if isinstance(self.auth_context, dict):
+ return self.auth_context.get("authorizer_type") if self.auth_context else None
+
+ @property
+ def authorizer_result(self) -> Dict[str, Any]:
+ if isinstance(self.auth_context, dict):
+ return self.auth_context.get("authorizer") if self.auth_context else {}
+
+ def is_websocket_request(self) -> bool:
+ upgrade_header = str(self.headers.get("upgrade") or "")
+ return upgrade_header.lower() == "websocket"
+
+ def is_v1(self) -> bool:
+ """Whether this is an API Gateway v1 request"""
+ return self.apigw_version == ApiGatewayVersion.V1
+
+ def cookies(self) -> Optional[List[str]]:
+ if cookies := self.headers.get("cookie") or "":
+ return list(cookies.split(";"))
+ return None
+
+ @property
+ def is_data_base64_encoded(self) -> bool:
+ try:
+ json.dumps(self.data) if isinstance(self.data, (dict, list)) else to_str(self.data)
+ return False
+ except UnicodeDecodeError:
+ return True
+
+ def data_as_string(self) -> str:
+ try:
+ return (
+ json.dumps(self.data) if isinstance(self.data, (dict, list)) else to_str(self.data)
+ )
+ except UnicodeDecodeError:
+ # we string encode our base64 as string as well
+ return to_str(base64.b64encode(self.data))
+
+ def _extract_host_from_header(self) -> str:
+ host = self.headers.get(HEADER_LOCALSTACK_EDGE_URL) or self.headers.get("host", "")
+ return host.split("://")[-1].split("/")[0].split(":")[0]
+
+ @property
+ def domain_name(self) -> str:
+ return self._extract_host_from_header()
+
+ @property
+ def domain_prefix(self) -> str:
+ host = self._extract_host_from_header()
+ return host.split(".")[0]
diff --git a/localstack-core/localstack/services/apigateway/legacy/helpers.py b/localstack-core/localstack/services/apigateway/legacy/helpers.py
new file mode 100644
index 0000000000000..62a91a32e78b0
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/helpers.py
@@ -0,0 +1,711 @@
+import json
+import logging
+import re
+import time
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
+from urllib import parse as urlparse
+
+from botocore.utils import InvalidArnException
+from moto.apigateway.models import apigateway_backends
+from requests.models import Response
+
+from localstack.aws.connect import connect_to
+from localstack.constants import (
+ APPLICATION_JSON,
+ DEFAULT_AWS_ACCOUNT_ID,
+ HEADER_LOCALSTACK_EDGE_URL,
+ PATH_USER_REQUEST,
+)
+from localstack.services.apigateway.helpers import REQUEST_TIME_DATE_FORMAT
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.utils import common
+from localstack.utils.aws import resources as resource_utils
+from localstack.utils.aws.arns import get_partition, parse_arn
+from localstack.utils.aws.aws_responses import requests_error_response_json, requests_response
+from localstack.utils.json import try_json
+from localstack.utils.numbers import is_number
+from localstack.utils.strings import canonicalize_bool_to_str, long_uid, to_str
+
+LOG = logging.getLogger(__name__)
+
+# regex path patterns
+PATH_REGEX_MAIN = r"^/restapis/([A-Za-z0-9_\-]+)/[a-z]+(\?.*)?"
+PATH_REGEX_SUB = r"^/restapis/([A-Za-z0-9_\-]+)/[a-z]+/([A-Za-z0-9_\-]+)/.*"
+PATH_REGEX_TEST_INVOKE_API = r"^\/restapis\/([A-Za-z0-9_\-]+)\/resources\/([A-Za-z0-9_\-]+)\/methods\/([A-Za-z0-9_\-]+)/?(\?.*)?"
+
+# regex path pattern for user requests, handles stages like $default
+PATH_REGEX_USER_REQUEST = (
+ r"^/restapis/([A-Za-z0-9_\\-]+)(?:/([A-Za-z0-9\_($|%%24)\\-]+))?/%s/(.*)$" % PATH_USER_REQUEST
+)
+# URL pattern for invocations
+HOST_REGEX_EXECUTE_API = r"(?:.*://)?([a-zA-Z0-9]+)(?:(-vpce-[^.]+))?\.execute-api\.(.*)"
+
+# template for SQS inbound data
+APIGATEWAY_SQS_DATA_INBOUND_TEMPLATE = (
+ "Action=SendMessage&MessageBody=$util.base64Encode($input.json('$'))"
+)
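+
+# Illustrative examples: HOST_REGEX_EXECUTE_API extracts the API ID from execute-api style host names,
+# e.g. group(1) is "abc123" for "abc123.execute-api.localhost.localstack.cloud:4566";
+# PATH_REGEX_USER_REQUEST matches paths such as "/restapis/abc123/dev/_user_request_/orders/42",
+# capturing the API ID ("abc123"), the stage ("dev") and the remaining path ("orders/42").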
+
+
+class ApiGatewayIntegrationError(Exception):
+ """
+ Base class for all ApiGateway Integration errors.
+ Can be used as is or extended for common error types.
+ These exceptions should be handled in one place, and bubble up from all others.
+ """
+
+ message: str
+ status_code: int
+
+ def __init__(self, message: str, status_code: int):
+ super().__init__(message)
+ self.message = message
+ self.status_code = status_code
+
+ def to_response(self):
+ return requests_response({"message": self.message}, status_code=self.status_code)
+
+
+class IntegrationParameters(TypedDict):
+ path: dict[str, str]
+ querystring: dict[str, str]
+ headers: dict[str, str]
+
+
+class RequestParametersResolver:
+ """
+ Integration request data mapping expressions
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html
+
+ Note: Use on REST APIs only
+ """
+
+ def resolve(self, context: ApiInvocationContext) -> IntegrationParameters:
+ """
+ Resolve method request parameters into integration request parameters.
+ Integration request parameters, in the form of path variables, query strings
+ or headers, can be mapped from any defined method request parameters
+ and the payload.
+
+ :return: IntegrationParameters
+ """
+ method_request_params: Dict[str, Any] = self.method_request_dict(context)
+
+ # requestParameters: {
+ # "integration.request.path.pathParam": "method.request.header.Content-Type"
+ # "integration.request.querystring.who": "method.request.querystring.who",
+ # "integration.request.header.Content-Type": "'application/json'",
+ # }
+ request_params = context.integration.get("requestParameters", {})
+
+ # resolve all integration request parameters with the already resolved method request parameters
+ integrations_parameters = {}
+ for k, v in request_params.items():
+ if v.lower() in method_request_params:
+ integrations_parameters[k] = method_request_params[v.lower()]
+ else:
+ # static values
+ integrations_parameters[k] = v.replace("'", "")
+
+ # build the integration parameters
+ result: IntegrationParameters = IntegrationParameters(path={}, querystring={}, headers={})
+ for k, v in integrations_parameters.items():
+ # headers
+ if k.startswith("integration.request.header."):
+ header_name = k.split(".")[-1]
+ result["headers"].update({header_name: v})
+
+ # querystring
+ if k.startswith("integration.request.querystring."):
+ param_name = k.split(".")[-1]
+ result["querystring"].update({param_name: v})
+
+ # path
+ if k.startswith("integration.request.path."):
+ path_name = k.split(".")[-1]
+ result["path"].update({path_name: v})
+
+ return result
+
+ def method_request_dict(self, context: ApiInvocationContext) -> Dict[str, Any]:
+ """
+ Build a dict with all method request parameters and their values.
+ :return: dict with all method request parameters and their values,
+ and all keys in lowercase
+ """
+ params: Dict[str, str] = {}
+
+ # TODO: add support for multi-values headers and multi-values querystring
+
+ for k, v in context.query_params().items():
+ params[f"method.request.querystring.{k}"] = v
+
+ for k, v in context.headers.items():
+ params[f"method.request.header.{k}"] = v
+
+ for k, v in context.path_params.items():
+ params[f"method.request.path.{k}"] = v
+
+ for k, v in context.stage_variables.items():
+ params[f"stagevariables.{k}"] = v
+
+ # TODO: add support for missing context variables, use `context.context` which contains most of the variables
+ # see https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html#context-variable-reference
+ # - all `context.identity` fields
+ # - protocol
+ # - requestId, extendedRequestId
+ # - all requestOverride, responseOverride
+ # - requestTime, requestTimeEpoch
+ # - resourcePath
+ # - wafResponseCode, webaclArn
+ params["context.accountId"] = context.account_id
+ params["context.apiId"] = context.api_id
+ params["context.domainName"] = context.domain_name
+ params["context.httpMethod"] = context.method
+ params["context.path"] = context.path
+ params["context.resourceId"] = context.resource_id
+ params["context.stage"] = context.stage
+
+ auth_context_authorizer = context.auth_context.get("authorizer") or {}
+ for k, v in auth_context_authorizer.items():
+ if isinstance(v, bool):
+ v = canonicalize_bool_to_str(v)
+ elif is_number(v):
+ v = str(v)
+
+ params[f"context.authorizer.{k.lower()}"] = v
+
+ if context.data:
+ params["method.request.body"] = context.data
+
+ return {key.lower(): val for key, val in params.items()}
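+
+ # Illustrative example: for requestParameters like
+ # {"integration.request.querystring.who": "method.request.querystring.who",
+ # "integration.request.header.Content-Type": "'application/json'"}
+ # and an invocation of "GET /greet?who=world", resolve() returns
+ # {"path": {}, "querystring": {"who": "world"}, "headers": {"Content-Type": "application/json"}}.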
+
+
+class ResponseParametersResolver:
+ def resolve(self, context: ApiInvocationContext) -> Dict[str, str]:
+ """
+ Resolve integration response parameters into method response parameters.
+ Integration response parameters can map header, body,
+ or static values to the header type of the method response.
+
+ :return: dict with all method response parameters and their values
+ """
+ integration_request_params: Dict[str, Any] = self.integration_request_dict(context)
+
+ # "responseParameters" : {
+ # "method.response.header.Location" : "integration.response.body.redirect.url",
+ # "method.response.header.x-user-id" : "integration.response.header.x-userid"
+ # }
+ integration_responses = context.integration.get("integrationResponses", {})
+ # XXX: fix for other status codes. context.response contains a response status code, but the response
+ # can be a LambdaResponse or a Response object whose status field is named differently; normalize it or use introspection
+ response_params = integration_responses.get("200", {}).get("responseParameters", {})
+
+ # resolve all integration request parameters with the already resolved method
+ # request parameters
+ method_parameters = {}
+ for k, v in response_params.items():
+ if v.lower() in integration_request_params:
+ method_parameters[k] = integration_request_params[v.lower()]
+ else:
+ # static values
+ method_parameters[k] = v.replace("'", "")
+
+ # build the integration parameters
+ result: Dict[str, str] = {}
+ for k, v in method_parameters.items():
+ # headers
+ if k.startswith("method.response.header."):
+ header_name = k.split(".")[-1]
+ result[header_name] = v
+
+ return result
+
+ def integration_request_dict(self, context: ApiInvocationContext) -> Dict[str, Any]:
+ params: Dict[str, str] = {}
+
+ for k, v in context.headers.items():
+ params[f"integration.request.header.{k}"] = v
+
+ if context.data:
+ params["integration.request.body"] = try_json(context.data)
+
+ return {key.lower(): val for key, val in params.items()}
+
+
+def make_json_response(message):
+ return requests_response(json.dumps(message), headers={"Content-Type": APPLICATION_JSON})
+
+
+def make_error_response(message, code=400, error_type=None):
+ if code == 404 and not error_type:
+ error_type = "NotFoundException"
+ error_type = error_type or "InvalidRequest"
+ return requests_error_response_json(message, code=code, error_type=error_type)
+
+
+def select_integration_response(matched_part: str, invocation_context: ApiInvocationContext):
+ int_responses = invocation_context.integration.get("integrationResponses") or {}
+ if select_by_pattern := [
+ response
+ for response in int_responses.values()
+ if response.get("selectionPattern")
+ and re.match(response.get("selectionPattern"), matched_part)
+ ]:
+ selected_response = select_by_pattern[0]
+ if len(select_by_pattern) > 1:
+ LOG.warning(
+ "Multiple integration responses matching '%s' statuscode. Choosing '%s' (first).",
+ matched_part,
+ selected_response["statusCode"],
+ )
+ else:
+ # choose default return code
+ default_responses = [
+ response for response in int_responses.values() if not response.get("selectionPattern")
+ ]
+ if not default_responses:
+ raise ApiGatewayIntegrationError("Internal server error", 500)
+
+ selected_response = default_responses[0]
+ if len(default_responses) > 1:
+ LOG.warning(
+ "Multiple default integration responses. Choosing %s (first).",
+ selected_response["statusCode"],
+ )
+ return selected_response
+
+
+def make_accepted_response():
+ response = Response()
+ response.status_code = 202
+ return response
+
+
+def get_api_id_from_path(path):
+ if match := re.match(PATH_REGEX_SUB, path):
+ return match.group(1)
+ return re.match(PATH_REGEX_MAIN, path).group(1)
+
+
+def is_test_invoke_method(method, path):
+ return method == "POST" and bool(re.match(PATH_REGEX_TEST_INVOKE_API, path))
+
+
+def get_stage_variables(context: ApiInvocationContext) -> Optional[Dict[str, str]]:
+ if is_test_invoke_method(context.method, context.path):
+ return None
+
+ if not context.stage:
+ return {}
+
+ account_id, region_name = get_api_account_id_and_region(context.api_id)
+ api_gateway_client = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).apigateway
+ try:
+ response = api_gateway_client.get_stage(restApiId=context.api_id, stageName=context.stage)
+ return response.get("variables", {})
+ except Exception:
+ LOG.info("Failed to get stage %s for API id %s", context.stage, context.api_id)
+ return {}
+
+
+def tokenize_path(path):
+ return path.lstrip("/").split("/")
+
+
+def extract_path_params(path: str, extracted_path: str) -> Dict[str, str]:
+ tokenized_extracted_path = tokenize_path(extracted_path)
+ # Looks for '{' in the tokenized extracted path
+ path_params_list = [(i, v) for i, v in enumerate(tokenized_extracted_path) if "{" in v]
+ tokenized_path = tokenize_path(path)
+ path_params = {}
+ for param in path_params_list:
+ path_param_name = param[1][1:-1]
+ path_param_position = param[0]
+ if path_param_name.endswith("+"):
+ path_params[path_param_name.rstrip("+")] = "/".join(
+ tokenized_path[path_param_position:]
+ )
+ else:
+ path_params[path_param_name] = tokenized_path[path_param_position]
+ path_params = common.json_safe(path_params)
+ return path_params
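+
+# For example, extract_path_params(path="/orders/42/items/1/2", extracted_path="/orders/{id}/items/{proxy+}")
+# returns {"id": "42", "proxy": "1/2"}: greedy "+" parameters capture all remaining path segments.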
+
+
+def extract_query_string_params(path: str) -> Tuple[str, Dict[str, str]]:
+ parsed_path = urlparse.urlparse(path)
+ if not path.startswith("//"):
+ path = parsed_path.path
+ parsed_query_string_params = urlparse.parse_qs(parsed_path.query)
+
+ query_string_params = {}
+ for query_param_name, query_param_values in parsed_query_string_params.items():
+ if len(query_param_values) == 1:
+ query_string_params[query_param_name] = query_param_values[0]
+ else:
+ query_string_params[query_param_name] = query_param_values
+
+ path = path or "/"
+ return path, query_string_params
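+
+# For example, extract_query_string_params("/orders?limit=5&tag=a&tag=b") returns
+# ("/orders", {"limit": "5", "tag": ["a", "b"]}): single-value parameters are flattened,
+# multi-value parameters are kept as lists.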
+
+
+def get_cors_response(headers):
+ # TODO: for now we simply return "allow-all" CORS headers, but in the future
+ # we should implement custom headers for CORS rules, as supported by API Gateway:
+ # http://docs.aws.amazon.com/apigateway/latest/developerguide/how-to-cors.html
+ response = Response()
+ response.status_code = 200
+ response.headers["Access-Control-Allow-Origin"] = "*"
+ response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, PATCH"
+ response.headers["Access-Control-Allow-Headers"] = "*"
+ response._content = ""
+ return response
+
+
+def get_apigateway_path_for_resource(
+ api_id, resource_id, path_suffix="", resources=None, region_name=None
+):
+ if resources is None:
+ apigateway = connect_to(region_name=region_name).apigateway
+ resources = apigateway.get_resources(restApiId=api_id, limit=100)["items"]
+ target_resource = list(filter(lambda res: res["id"] == resource_id, resources))[0]
+ path_part = target_resource.get("pathPart", "")
+ if path_suffix:
+ if path_part:
+ path_suffix = "%s/%s" % (path_part, path_suffix)
+ else:
+ path_suffix = path_part
+ parent_id = target_resource.get("parentId")
+ if not parent_id:
+ return "/%s" % path_suffix
+ return get_apigateway_path_for_resource(
+ api_id,
+ parent_id,
+ path_suffix=path_suffix,
+ resources=resources,
+ region_name=region_name,
+ )
+
+
+def get_rest_api_paths(account_id: str, region_name: str, rest_api_id: str):
+ apigateway = connect_to(aws_access_key_id=account_id, region_name=region_name).apigateway
+ resources = apigateway.get_resources(restApiId=rest_api_id, limit=100)
+ resource_map = {}
+ for resource in resources["items"]:
+ path = resource.get("path")
+ # TODO: check if this is still required in the general case (can we rely on "path" being
+ # present?)
+ path = path or get_apigateway_path_for_resource(
+ rest_api_id, resource["id"], region_name=region_name
+ )
+ resource_map[path] = resource
+ return resource_map
+
+
+# TODO: Extract this to a set of rules that have precedence and are easy to test individually.
+#
+# https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-method-settings
+# -method-request.html
+# https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-routes.html
+def get_resource_for_path(
+ path: str, method: str, path_map: Dict[str, Dict]
+) -> tuple[Optional[str], Optional[dict]]:
+ matches = []
+ # creates a regex from the input path if there are parameters, e.g. /foo/{bar}/baz -> /foo/[^/]+/baz,
+ # otherwise it is a direct match.
+ for api_path, details in path_map.items():
+ api_path_regex = re.sub(r"{[^+]+\+}", r"[^\?#]+", api_path)
+ api_path_regex = re.sub(r"{[^}]+}", r"[^/]+", api_path_regex)
+ if re.match(r"^%s$" % api_path_regex, path):
+ matches.append((api_path, details))
+
+ # if there are no matches, it's not worth proceeding, bail out here!
+ if not matches:
+ LOG.debug("No match found for path: '%s' and method: '%s'", path, method)
+ return None, None
+
+ if len(matches) == 1:
+ LOG.debug("Match found for path: '%s' and method: '%s'", path, method)
+ return matches[0]
+
+ # so we have more than one match
+ # /{proxy+} and /api/{proxy+} for inputs like /api/foo/bar
+ # /foo/{param1}/baz and /foo/{param1}/{param2} for inputs like /foo/bar/baz
+ proxy_matches = []
+ param_matches = []
+ for match in matches:
+ match_methods = list(match[1].get("resourceMethods", {}).keys())
+ # only look for path matches if the request method is in the resource
+ if method.upper() in match_methods or "ANY" in match_methods:
+ # check if we have an exact match (exact matches take precedence) if the method is the same
+ if match[0] == path:
+ return match
+
+ elif path_matches_pattern(path, match[0]):
+ # parameters can fit in
+ param_matches.append(match)
+ continue
+
+ proxy_matches.append(match)
+
+ if param_matches:
+ # count the number of parameters and return the one with the fewest, which is the most precise match
+ sorted_matches = sorted(param_matches, key=lambda x: x[0].count("{"))
+ LOG.debug("Match found for path: '%s' and method: '%s'", path, method)
+ return sorted_matches[0]
+
+ if proxy_matches:
+ # at this stage we still have more than one match, but they are greedy patterns like
+ # /{proxy+} or /api/{proxy+}; pick the most specific one by sorting by path length (longest first),
+ # considering only matches whose methods could handle the request
+ sorted_matches = sorted(proxy_matches, key=lambda x: len(x[0]), reverse=True)
+ LOG.debug("Match found for path: '%s' and method: '%s'", path, method)
+ return sorted_matches[0]
+
+ # if there are no matches with a method that would match, return
+ LOG.debug("No match found for method: '%s' for matched path: %s", method, path)
+ return None, None
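+
+# Illustrative example: given resources "/orders/{id}", "/orders/{id}/items" and "/{proxy+}", each with a
+# GET method, get_resource_for_path("/orders/42", "GET", path_map) returns the "/orders/{id}" entry:
+# exact matches win over parameterized matches, and parameterized matches (fewest parameters first) win
+# over greedy "{proxy+}" matches.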
+
+
+def path_matches_pattern(path, api_path):
+ api_paths = api_path.split("/")
+ paths = path.split("/")
+ reg_check = re.compile(r"{(.*)}")
+ if len(api_paths) != len(paths):
+ return False
+ results = [
+ part == paths[indx]
+ for indx, part in enumerate(api_paths)
+ if reg_check.match(part) is None and part
+ ]
+
+ return len(results) > 0 and all(results)
+
+
+def connect_api_gateway_to_sqs(gateway_name, stage_name, queue_arn, path, account_id, region_name):
+ resources = {}
+ template = APIGATEWAY_SQS_DATA_INBOUND_TEMPLATE
+ resource_path = path.replace("/", "")
+
+ try:
+ arn = parse_arn(queue_arn)
+ queue_name = arn["resource"]
+ sqs_account = arn["account"]
+ sqs_region = arn["region"]
+ except InvalidArnException:
+ queue_name = queue_arn
+ sqs_account = account_id
+ sqs_region = region_name
+
+ partition = get_partition(region_name)
+ resources[resource_path] = [
+ {
+ "httpMethod": "POST",
+ "authorizationType": "NONE",
+ "integrations": [
+ {
+ "type": "AWS",
+ "uri": "arn:%s:apigateway:%s:sqs:path/%s/%s"
+ % (partition, sqs_region, sqs_account, queue_name),
+ "requestTemplates": {"application/json": template},
+ "requestParameters": {
+ "integration.request.header.Content-Type": "'application/x-www-form-urlencoded'"
+ },
+ }
+ ],
+ }
+ ]
+ return resource_utils.create_api_gateway(
+ name=gateway_name,
+ resources=resources,
+ stage_name=stage_name,
+ client=connect_to(aws_access_key_id=sqs_account, region_name=sqs_region).apigateway,
+ )
+
+
+def get_target_resource_details(
+ invocation_context: ApiInvocationContext,
+) -> Tuple[Optional[str], Optional[dict]]:
+ """Look up and return the API GW resource (path pattern + resource dict) for the given invocation context."""
+ path_map = get_rest_api_paths(
+ account_id=invocation_context.account_id,
+ region_name=invocation_context.region_name,
+ rest_api_id=invocation_context.api_id,
+ )
+ relative_path = invocation_context.invocation_path.rstrip("/") or "/"
+ try:
+ extracted_path, resource = get_resource_for_path(
+ path=relative_path, method=invocation_context.method, path_map=path_map
+ )
+ if not extracted_path:
+ return None, None
+ invocation_context.resource = resource
+ invocation_context.resource_path = extracted_path
+ try:
+ invocation_context.path_params = extract_path_params(
+ path=relative_path, extracted_path=extracted_path
+ )
+ except Exception:
+ invocation_context.path_params = {}
+
+ return extracted_path, resource
+
+ except Exception:
+ return None, None
+
+
+def get_target_resource_method(invocation_context: ApiInvocationContext) -> Optional[Dict]:
+ """Look up and return the API GW resource method for the given invocation context."""
+ _, resource = get_target_resource_details(invocation_context)
+ if not resource:
+ return None
+ methods = resource.get("resourceMethods") or {}
+ return methods.get(invocation_context.method.upper()) or methods.get("ANY")
+
+
+def event_type_from_route_key(invocation_context):
+ action = invocation_context.route["RouteKey"]
+ return (
+ "CONNECT"
+ if action == "$connect"
+ else "DISCONNECT"
+ if action == "$disconnect"
+ else "MESSAGE"
+ )
+
+
+def get_event_request_context(invocation_context: ApiInvocationContext):
+ method = invocation_context.method
+ path = invocation_context.path
+ headers = invocation_context.headers
+ integration_uri = invocation_context.integration_uri
+ resource_path = invocation_context.resource_path
+ resource_id = invocation_context.resource_id
+
+ set_api_id_stage_invocation_path(invocation_context)
+ api_id = invocation_context.api_id
+ stage = invocation_context.stage
+
+ if "_user_request_" in invocation_context.raw_uri:
+ full_path = invocation_context.raw_uri.partition("_user_request_")[2]
+ else:
+ full_path = invocation_context.raw_uri.removeprefix(f"/{stage}")
+ relative_path, query_string_params = extract_query_string_params(path=full_path)
+
+ source_ip = invocation_context.auth_identity.get("sourceIp")
+ integration_uri = integration_uri or ""
+ account_id = integration_uri.split(":lambda:path")[-1].split(":function:")[0].split(":")[-1]
+ account_id = account_id or DEFAULT_AWS_ACCOUNT_ID
+ request_context = {
+ "accountId": account_id,
+ "apiId": api_id,
+ "resourcePath": resource_path or relative_path,
+ "domainPrefix": invocation_context.domain_prefix,
+ "domainName": invocation_context.domain_name,
+ "resourceId": resource_id,
+ "requestId": long_uid(),
+ "identity": {
+ "accountId": account_id,
+ "sourceIp": source_ip,
+ "userAgent": headers.get("User-Agent"),
+ },
+ "httpMethod": method,
+ "protocol": "HTTP/1.1",
+ "requestTime": datetime.now(timezone.utc).strftime(REQUEST_TIME_DATE_FORMAT),
+ "requestTimeEpoch": int(time.time() * 1000),
+ "authorizer": {},
+ }
+
+ if invocation_context.is_websocket_request():
+ request_context["connectionId"] = invocation_context.connection_id
+
+ # set "authorizer" and "identity" event attributes from request context
+ authorizer_result = invocation_context.authorizer_result
+ if authorizer_result:
+ request_context["authorizer"] = authorizer_result
+ request_context["identity"].update(invocation_context.auth_identity or {})
+
+ if not is_test_invoke_method(method, path):
+ request_context["path"] = (f"/{stage}" if stage else "") + relative_path
+ request_context["stage"] = stage
+ return request_context
+
+
+def set_api_id_stage_invocation_path(
+ invocation_context: ApiInvocationContext,
+) -> ApiInvocationContext:
+ # skip if all details are already available
+ values = (
+ invocation_context.api_id,
+ invocation_context.stage,
+ invocation_context.path_with_query_string,
+ )
+ if all(values):
+ return invocation_context
+
+ # skip if this is a websocket request
+ if invocation_context.is_websocket_request():
+ return invocation_context
+
+ path = invocation_context.path
+ headers = invocation_context.headers
+
+ path_match = re.search(PATH_REGEX_USER_REQUEST, path)
+ host_header = headers.get(HEADER_LOCALSTACK_EDGE_URL, "") or headers.get("Host") or ""
+ host_match = re.search(HOST_REGEX_EXECUTE_API, host_header)
+ test_invoke_match = re.search(PATH_REGEX_TEST_INVOKE_API, path)
+ if path_match:
+ api_id = path_match.group(1)
+ stage = path_match.group(2)
+ relative_path_w_query_params = "/%s" % path_match.group(3)
+ elif host_match:
+ api_id = extract_api_id_from_hostname_in_url(host_header)
+ stage = path.strip("/").split("/")[0]
+ relative_path_w_query_params = "/%s" % path.lstrip("/").partition("/")[2]
+ elif test_invoke_match:
+ stage = invocation_context.stage
+ api_id = invocation_context.api_id
+ relative_path_w_query_params = invocation_context.path_with_query_string
+ else:
+ raise Exception(
+ f"Unable to extract API Gateway details from request: {path} {dict(headers)}"
+ )
+
+ # set details in invocation context
+ invocation_context.api_id = api_id
+ invocation_context.stage = stage
+ invocation_context.path_with_query_string = relative_path_w_query_params
+ return invocation_context
+
+
+def get_api_account_id_and_region(api_id: str) -> Tuple[Optional[str], Optional[str]]:
+ """Return the region name for the given REST API ID"""
+ for account_id, account in apigateway_backends.items():
+ for region_name, region in account.items():
+ # compare lower-case keys to avoid case-sensitivity issues
+ for key in region.apis.keys():
+ if key.lower() == api_id.lower():
+ return account_id, region_name
+ return None, None
+
+
+def extract_api_id_from_hostname_in_url(hostname: str) -> str:
+ """Extract API ID 'id123' from URLs like https://id123.execute-api.localhost.localstack.cloud:4566"""
+ match = re.match(HOST_REGEX_EXECUTE_API, hostname)
+ return match.group(1)
+
+
+def multi_value_dict_for_list(elements: Union[List, Dict]) -> Dict:
+ temp_mv_dict = defaultdict(list)
+ for key in elements:
+ if isinstance(key, (list, tuple)):
+ key, value = key
+ else:
+ value = elements[key]
+
+ key = to_str(key)
+ temp_mv_dict[key].append(value)
+ return {k: tuple(v) for k, v in temp_mv_dict.items()}
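+
+
+# For example, multi_value_dict_for_list({"Accept": "text/html"}) returns {"Accept": ("text/html",)},
+# and multi_value_dict_for_list([("a", 1), ("a", 2)]) returns {"a": (1, 2)}.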
diff --git a/localstack-core/localstack/services/apigateway/legacy/integration.py b/localstack-core/localstack/services/apigateway/legacy/integration.py
new file mode 100644
index 0000000000000..12852fff266af
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/integration.py
@@ -0,0 +1,1119 @@
+import base64
+import json
+import logging
+import re
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from http import HTTPMethod, HTTPStatus
+from typing import Any, Dict
+from urllib.parse import urljoin
+
+import requests
+from botocore.exceptions import ClientError
+from moto.apigatewayv2.exceptions import BadRequestException
+from requests import Response
+
+from localstack import config
+from localstack.aws.connect import (
+ INTERNAL_REQUEST_PARAMS_HEADER,
+ InternalRequestParameters,
+ connect_to,
+ dump_dto,
+)
+from localstack.constants import APPLICATION_JSON, HEADER_CONTENT_TYPE
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.services.apigateway.legacy.helpers import (
+ ApiGatewayIntegrationError,
+ IntegrationParameters,
+ RequestParametersResolver,
+ ResponseParametersResolver,
+ extract_path_params,
+ extract_query_string_params,
+ get_event_request_context,
+ get_stage_variables,
+ make_error_response,
+ multi_value_dict_for_list,
+)
+from localstack.services.apigateway.legacy.templates import (
+ MappingTemplates,
+ RequestTemplates,
+ ResponseTemplates,
+)
+from localstack.services.stepfunctions.stepfunctions_utils import await_sfn_execution_result
+from localstack.utils import common
+from localstack.utils.aws.arns import ARN_PARTITION_REGEX, extract_region_from_arn, get_partition
+from localstack.utils.aws.aws_responses import (
+ LambdaResponse,
+ request_response_stream,
+ requests_response,
+)
+from localstack.utils.aws.client_types import ServicePrincipal
+from localstack.utils.aws.request_context import mock_aws_request_headers
+from localstack.utils.aws.templating import VtlTemplate
+from localstack.utils.collections import dict_multi_values, remove_attributes
+from localstack.utils.common import make_http_request, to_str
+from localstack.utils.http import add_query_params_to_url, canonicalize_headers, parse_request_data
+from localstack.utils.json import json_safe, try_json
+from localstack.utils.strings import camel_to_snake_case, to_bytes
+
+LOG = logging.getLogger(__name__)
+
+
+class IntegrationAccessError(ApiGatewayIntegrationError):
+ """
+ Error message when an integration cannot be accessed.
+ """
+
+ def __init__(self):
+ super().__init__("Internal server error", 500)
+
+
+class BackendIntegration(ABC):
+ """Abstract base class representing a backend integration"""
+
+ def __init__(self):
+ self.request_templates = RequestTemplates()
+ self.response_templates = ResponseTemplates()
+ self.request_params_resolver = RequestParametersResolver()
+ self.response_params_resolver = ResponseParametersResolver()
+
+ @abstractmethod
+ def invoke(self, invocation_context: ApiInvocationContext):
+ pass
+
+ @classmethod
+ def _create_response(cls, status_code, headers, data=""):
+ response = Response()
+ response.status_code = status_code
+ response.headers = headers
+ response._content = data
+ return response
+
+ @classmethod
+ def apply_request_parameters(
+ cls, integration_params: IntegrationParameters, headers: Dict[str, Any]
+ ):
+ for k, v in integration_params.get("headers").items():
+ headers.update({k: v})
+
+ @classmethod
+ def apply_response_parameters(
+ cls, invocation_context: ApiInvocationContext, response: Response
+ ):
+ integration = invocation_context.integration
+ integration_responses = integration.get("integrationResponses") or {}
+ if not integration_responses:
+ return response
+ entries = list(integration_responses.keys())
+ return_code = str(response.status_code)
+ if return_code not in entries:
+ if len(entries) > 1:
+ LOG.info("Found multiple integration response status codes: %s", entries)
+ return response
+ return_code = entries[0]
+ response_params = integration_responses[return_code].get("responseParameters", {})
+ for key, value in response_params.items():
+ # TODO: add support for method.response.body, etc ...
+ if str(key).lower().startswith("method.response.header."):
+ header_name = key[len("method.response.header.") :]
+ response.headers[header_name] = value.strip("'")
+ return response
+
+ @classmethod
+ def render_template_selection_expression(cls, invocation_context: ApiInvocationContext):
+ integration = invocation_context.integration
+ template_selection_expression = integration.get("templateSelectionExpression")
+
+ # AWS selects a template either by the request content type (for input templates and output
+ # mappings) or via a template selection expression; both mechanisms fall back to the
+ # $default template if no matching template is found.
+ if not template_selection_expression:
+ content_type = invocation_context.headers.get(HEADER_CONTENT_TYPE, APPLICATION_JSON)
+ if integration.get("RequestTemplates", {}).get(content_type):
+ return content_type
+ return "$default"
+
+ data = try_json(invocation_context.data)
+ variables = {
+ "request": {
+ "header": invocation_context.headers,
+ "querystring": invocation_context.query_params(),
+ "body": data,
+ "context": invocation_context.context or {},
+ "stage_variables": invocation_context.stage_variables or {},
+ }
+ }
+ return VtlTemplate().render_vtl(template_selection_expression, variables) or "$default"
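+
+ # Illustrative example (WebSocket APIs only): with templateSelectionExpression set to
+ # "$request.body.action" and a JSON body of {"action": "sendMessage"}, the expression is expected
+ # to render to "sendMessage"; without a selection expression, the method falls back to the request
+ # Content-Type (if a matching request template exists) or to "$default".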
+
+
+@lru_cache(maxsize=64)
+def get_service_factory(region_name: str, role_arn: str):
+ if role_arn:
+ return connect_to.with_assumed_role(
+ role_arn=role_arn,
+ region_name=region_name,
+ service_principal=ServicePrincipal.apigateway,
+ session_name="BackplaneAssumeRoleSession",
+ )
+ else:
+ return connect_to(region_name=region_name)
+
+
+@lru_cache(maxsize=64)
+def get_internal_mocked_headers(
+ service_name: str,
+ region_name: str,
+ source_arn: str,
+ role_arn: str | None,
+) -> dict[str, str]:
+ if role_arn:
+ access_key_id = (
+ connect_to(region_name=region_name)
+ .sts.request_metadata(service_principal=ServicePrincipal.apigateway)
+ .assume_role(RoleArn=role_arn, RoleSessionName="BackplaneAssumeRoleSession")[
+ "Credentials"
+ ]["AccessKeyId"]
+ )
+ else:
+ access_key_id = None
+ headers = mock_aws_request_headers(
+ service=service_name, aws_access_key_id=access_key_id, region_name=region_name
+ )
+
+ dto = InternalRequestParameters(
+ service_principal=ServicePrincipal.apigateway, source_arn=source_arn
+ )
+ headers[INTERNAL_REQUEST_PARAMS_HEADER] = dump_dto(dto)
+ return headers
+
+
+def get_source_arn(invocation_context: ApiInvocationContext):
+ return f"arn:{get_partition(invocation_context.region_name)}:execute-api:{invocation_context.region_name}:{invocation_context.account_id}:{invocation_context.api_id}/{invocation_context.stage}/{invocation_context.method}{invocation_context.path}"
+
+
+def call_lambda(
+ function_arn: str, event: bytes, asynchronous: bool, invocation_context: ApiInvocationContext
+) -> str:
+ clients = get_service_factory(
+ region_name=extract_region_from_arn(function_arn),
+ role_arn=invocation_context.integration.get("credentials"),
+ )
+ inv_result = clients.lambda_.request_metadata(
+ service_principal=ServicePrincipal.apigateway, source_arn=get_source_arn(invocation_context)
+ ).invoke(
+ FunctionName=function_arn,
+ Payload=event,
+ InvocationType="Event" if asynchronous else "RequestResponse",
+ )
+ if payload := inv_result.get("Payload"):
+ payload = to_str(payload.read())
+ return payload
+ return ""
+
+
+class LambdaProxyIntegration(BackendIntegration):
+ @classmethod
+ def update_content_length(cls, response: Response):
+ if response and response.content is not None:
+ response.headers["Content-Length"] = str(len(response.content))
+
+ @classmethod
+ def lambda_result_to_response(cls, result) -> LambdaResponse:
+ response = LambdaResponse()
+ response.headers.update({"content-type": "application/json"})
+ parsed_result = result if isinstance(result, dict) else json.loads(str(result or "{}"))
+ parsed_result = common.json_safe(parsed_result)
+ parsed_result = {} if parsed_result is None else parsed_result
+
+ if set(parsed_result) - {
+ "body",
+ "statusCode",
+ "headers",
+ "isBase64Encoded",
+ "multiValueHeaders",
+ }:
+ LOG.warning(
+ 'Lambda output should follow this JSON format: { "isBase64Encoded": true|false, "statusCode": httpStatusCode, "headers": { "headerName": "headerValue", ... },"body": "..."}\n Lambda output: %s',
+ parsed_result,
+ )
+ response.status_code = 502
+ response._content = json.dumps({"message": "Internal server error"})
+ return response
+
+ response.status_code = int(parsed_result.get("statusCode", 200))
+ parsed_headers = parsed_result.get("headers", {})
+ if parsed_headers is not None:
+ response.headers.update(parsed_headers)
+ try:
+ result_body = parsed_result.get("body")
+ if isinstance(result_body, dict):
+ response._content = json.dumps(result_body)
+ else:
+ body_bytes = to_bytes(to_str(result_body or ""))
+ if parsed_result.get("isBase64Encoded", False):
+ body_bytes = base64.b64decode(body_bytes)
+ response._content = body_bytes
+ except Exception as e:
+ LOG.warning("Couldn't set Lambda response content: %s", e)
+ response._content = "{}"
+ response.multi_value_headers = parsed_result.get("multiValueHeaders") or {}
+ return response
+
+ @staticmethod
+ def fix_proxy_path_params(path_params):
+ proxy_path_param_value = path_params.get("proxy+")
+ if not proxy_path_param_value:
+ return
+ del path_params["proxy+"]
+ path_params["proxy"] = proxy_path_param_value
+
+ @staticmethod
+ def validate_integration_method(invocation_context: ApiInvocationContext):
+ if invocation_context.integration["httpMethod"] != HTTPMethod.POST:
+ raise ApiGatewayIntegrationError("Internal server error", status_code=500)
+
+ @classmethod
+ def construct_invocation_event(
+ cls, method, path, headers, data, query_string_params=None, is_base64_encoded=False
+ ):
+ query_string_params = query_string_params or parse_request_data(method, path, "")
+
+ single_value_query_string_params = {
+ k: v[-1] if isinstance(v, list) else v for k, v in query_string_params.items()
+ }
+ # Some headers get capitalized like in CloudFront, see
+ # https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/add-origin-custom-headers.html#add-origin-custom-headers-forward-authorization
+ # It seems AWS_PROXY Lambda integrations are served behind CloudFront, as seen by the headers returned by AWS
+ to_capitalize: list[str] = ["authorization"] # some headers get capitalized
+ headers = {
+ k.capitalize() if k.lower() in to_capitalize else k: v for k, v in headers.items()
+ }
+
+ # canonicalize header names as AWS does, converting them to lower-case
+ headers = canonicalize_headers(headers)
+
+ return {
+ "path": "/" + path.lstrip("/"),
+ "headers": headers,
+ "multiValueHeaders": multi_value_dict_for_list(headers),
+ "body": data,
+ "isBase64Encoded": is_base64_encoded,
+ "httpMethod": method,
+ "queryStringParameters": single_value_query_string_params or None,
+ "multiValueQueryStringParameters": dict_multi_values(query_string_params) or None,
+ }
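+
+ # Illustrative example: construct_invocation_event("GET", "/orders", {"Accept": "*/*"}, "",
+ # query_string_params={"limit": ["5"]}) yields an event with "path": "/orders",
+ # "queryStringParameters": {"limit": "5"}, "multiValueQueryStringParameters": {"limit": ["5"]},
+ # "httpMethod": "GET", plus canonicalized (lower-cased) headers and their multi-value counterparts.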
+
+ @classmethod
+ def process_apigateway_invocation(
+ cls,
+ func_arn,
+ path,
+ payload,
+ invocation_context: ApiInvocationContext,
+ query_string_params=None,
+ ) -> str:
+ if (path_params := invocation_context.path_params) is None:
+ path_params = {}
+ if (request_context := invocation_context.context) is None:
+ request_context = {}
+ try:
+ resource_path = invocation_context.resource_path or path
+ event = cls.construct_invocation_event(
+ invocation_context.method,
+ path,
+ invocation_context.headers,
+ payload,
+ query_string_params,
+ invocation_context.is_data_base64_encoded,
+ )
+ path_params = dict(path_params)
+ cls.fix_proxy_path_params(path_params)
+ event["pathParameters"] = path_params
+ event["resource"] = resource_path
+ event["requestContext"] = request_context
+ event["stageVariables"] = invocation_context.stage_variables
+ LOG.debug(
+ "Running Lambda function %s from API Gateway invocation: %s %s",
+ func_arn,
+ invocation_context.method or "GET",
+ path,
+ )
+ asynchronous = invocation_context.headers.get("X-Amz-Invocation-Type") == "'Event'"
+ return call_lambda(
+ function_arn=func_arn,
+ event=to_bytes(json.dumps(event)),
+ asynchronous=asynchronous,
+ invocation_context=invocation_context,
+ )
+ except ClientError as e:
+ raise IntegrationAccessError() from e
+ except Exception as e:
+ LOG.warning(
+ "Unable to run Lambda function on API Gateway message: %s",
+ e,
+ )
+
+ def invoke(self, invocation_context: ApiInvocationContext):
+ self.validate_integration_method(invocation_context)
+ uri = (
+ invocation_context.integration.get("uri")
+ or invocation_context.integration.get("integrationUri")
+ or ""
+ )
+ invocation_context.context = get_event_request_context(invocation_context)
+ relative_path, query_string_params = extract_query_string_params(
+ path=invocation_context.path_with_query_string
+ )
+ try:
+ path_params = extract_path_params(
+ path=relative_path, extracted_path=invocation_context.resource_path
+ )
+ invocation_context.path_params = path_params
+ except Exception:
+ pass
+
+ func_arn = uri
+ if ":lambda:path" in uri:
+ func_arn = uri.split(":lambda:path")[1].split("functions/")[1].split("/invocations")[0]
+
+ if invocation_context.authorizer_type:
+ invocation_context.context["authorizer"] = invocation_context.authorizer_result
+
+ payload = self.request_templates.render(invocation_context)
+
+ result = self.process_apigateway_invocation(
+ func_arn=func_arn,
+ path=relative_path,
+ payload=payload,
+ invocation_context=invocation_context,
+ query_string_params=query_string_params,
+ )
+
+ response = LambdaResponse()
+ response.headers.update({"content-type": "application/json"})
+ parsed_result = json.loads(str(result or "{}"))
+ parsed_result = common.json_safe(parsed_result)
+ parsed_result = {} if parsed_result is None else parsed_result
+
+ if set(parsed_result) - {
+ "body",
+ "statusCode",
+ "headers",
+ "isBase64Encoded",
+ "multiValueHeaders",
+ }:
+ LOG.warning(
+ 'Lambda output should follow this JSON format: { "isBase64Encoded": true|false, "statusCode": httpStatusCode, "headers": { "headerName": "headerValue", ... },"body": "..."}\n Lambda output: %s',
+ parsed_result,
+ )
+ response.status_code = 502
+ response._content = json.dumps({"message": "Internal server error"})
+ return response
+
+ response.status_code = int(parsed_result.get("statusCode", 200))
+ parsed_headers = parsed_result.get("headers", {})
+ if parsed_headers is not None:
+ response.headers.update(parsed_headers)
+ try:
+ result_body = parsed_result.get("body")
+ if isinstance(result_body, dict):
+ response._content = json.dumps(result_body)
+ else:
+ body_bytes = to_bytes(result_body or "")
+ if parsed_result.get("isBase64Encoded", False):
+ body_bytes = base64.b64decode(body_bytes)
+ response._content = body_bytes
+ except Exception as e:
+ LOG.warning("Couldn't set Lambda response content: %s", e)
+ response._content = "{}"
+ response.multi_value_headers = parsed_result.get("multiValueHeaders") or {}
+
+ # apply custom response template
+ self.update_content_length(response)
+ invocation_context.response = response
+
+ return invocation_context.response
+
+
+class LambdaIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext):
+ invocation_context.stage_variables = get_stage_variables(invocation_context)
+ headers = invocation_context.headers
+
+ # resolve integration parameters
+ integration_parameters = self.request_params_resolver.resolve(context=invocation_context)
+ headers.update(integration_parameters.get("headers", {}))
+
+ if invocation_context.authorizer_type:
+ invocation_context.context["authorizer"] = invocation_context.authorizer_result
+
+ func_arn = self._lambda_integration_uri(invocation_context)
+ # integration type "AWS" is only supported for WebSocket APIs and REST
+ # API (v1), but the template selection expression is only supported for
+ # Websockets
+ if invocation_context.is_websocket_request():
+ template_key = self.render_template_selection_expression(invocation_context)
+ payload = self.request_templates.render(invocation_context, template_key)
+ else:
+ payload = self.request_templates.render(invocation_context)
+
+ asynchronous = headers.get("X-Amz-Invocation-Type", "").strip("'") == "Event"
+ try:
+ result = call_lambda(
+ function_arn=func_arn,
+ event=to_bytes(payload or ""),
+ asynchronous=asynchronous,
+ invocation_context=invocation_context,
+ )
+ except ClientError as e:
+ raise IntegrationAccessError() from e
+
+ # default lambda status code is 200
+ response = LambdaResponse()
+ response.status_code = 200
+ response._content = result
+
+ if asynchronous:
+ response._content = ""
+
+ # response template
+ invocation_context.response = response
+ self.response_templates.render(invocation_context)
+ invocation_context.response.headers["Content-Length"] = str(len(response.content or ""))
+
+ headers = self.response_params_resolver.resolve(invocation_context)
+ invocation_context.response.headers.update(headers)
+
+ return invocation_context.response
+
+ def _lambda_integration_uri(self, invocation_context: ApiInvocationContext):
+ """
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/aws-api-gateway-stage-variables-reference.html
+ """
+ uri = (
+ invocation_context.integration.get("uri")
+ or invocation_context.integration.get("integrationUri")
+ or ""
+ )
+ variables = {"stageVariables": invocation_context.stage_variables}
+ uri = VtlTemplate().render_vtl(uri, variables)
+ if ":lambda:path" in uri:
+ uri = uri.split(":lambda:path")[1].split("functions/")[1].split("/invocations")[0]
+ return uri
+
+
+class KinesisIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext):
+ integration = invocation_context.integration
+ integration_type_orig = integration.get("type") or integration.get("integrationType") or ""
+ integration_type = integration_type_orig.upper()
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
+ integration_subtype = integration.get("integrationSubtype")
+
+ if uri.endswith("kinesis:action/PutRecord") or integration_subtype == "Kinesis-PutRecord":
+ target = "Kinesis_20131202.PutRecord"
+ elif uri.endswith("kinesis:action/PutRecords"):
+ target = "Kinesis_20131202.PutRecords"
+ elif uri.endswith("kinesis:action/ListStreams"):
+ target = "Kinesis_20131202.ListStreams"
+ else:
+ LOG.info(
+ "Unexpected API Gateway integration URI '%s' for integration type %s",
+ uri,
+ integration_type,
+ )
+ target = ""
+
+ try:
+ # XXX: this "event" request context is used in multiple places; we probably
+ # want to refactor this into a model class.
+ # Arguably, we should not decide on the event_request_context inside the integration, because
+ # it differs between API types (REST, HTTP, WebSocket) and per event version.
+ invocation_context.context = get_event_request_context(invocation_context)
+ invocation_context.stage_variables = get_stage_variables(invocation_context)
+
+ # integration type "AWS" is only supported for WebSocket APIs and REST
+ # API (v1), but the template selection expression is only supported for
+ # Websockets
+ if invocation_context.is_websocket_request():
+ template_key = self.render_template_selection_expression(invocation_context)
+ payload = self.request_templates.render(invocation_context, template_key)
+ else:
+ # For HTTP APIs with a specified integration_subtype,
+ # a key-value map specifying parameters that are passed to AWS_PROXY integrations
+ if integration_type == "AWS_PROXY" and integration_subtype == "Kinesis-PutRecord":
+ payload = self._create_request_parameters(invocation_context)
+ else:
+ payload = self.request_templates.render(invocation_context)
+
+ except Exception as e:
+ LOG.warning("Unable to convert API Gateway payload to str", e)
+ raise
+
+ # forward records to target kinesis stream
+ headers = get_internal_mocked_headers(
+ service_name="kinesis",
+ region_name=invocation_context.region_name,
+ role_arn=invocation_context.integration.get("credentials"),
+ source_arn=get_source_arn(invocation_context),
+ )
+ headers["X-Amz-Target"] = target
+
+ result = common.make_http_request(
+ url=config.internal_service_url(), data=payload, headers=headers, method="POST"
+ )
+
+ # apply response template
+ invocation_context.response = result
+ self.response_templates.render(invocation_context)
+ return invocation_context.response
+
+ @classmethod
+ def _validate_required_params(cls, request_parameters: Dict[str, Any]) -> None:
+ if not request_parameters:
+ raise BadRequestException("Missing required parameters")
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-aws-services-reference.html#Kinesis-PutRecord
+ stream_name = request_parameters.get("StreamName")
+ partition_key = request_parameters.get("PartitionKey")
+ data = request_parameters.get("Data")
+
+ if not stream_name:
+ raise BadRequestException("StreamName")
+
+ if not partition_key:
+ raise BadRequestException("PartitionKey")
+
+ if not data:
+ raise BadRequestException("Data")
+
+ def _create_request_parameters(
+ self, invocation_context: ApiInvocationContext
+ ) -> Dict[str, Any]:
+ request_parameters = invocation_context.integration.get("requestParameters", {})
+ self._validate_required_params(request_parameters)
+
+ variables = {
+ "request": {
+ "header": invocation_context.headers,
+ "querystring": invocation_context.query_params(),
+ "body": invocation_context.data_as_string(),
+ "context": invocation_context.context or {},
+ "stage_variables": invocation_context.stage_variables or {},
+ }
+ }
+
+ if invocation_context.headers.get("Content-Type") == "application/json":
+ variables["request"]["body"] = json.loads(invocation_context.data_as_string())
+ else:
+ # AWS parity: requests without a content type still yield a valid response from Kinesis
+ variables["request"]["body"] = try_json(invocation_context.data_as_string())
+
+ # Required parameters
+ payload = {
+ "StreamName": VtlTemplate().render_vtl(request_parameters.get("StreamName"), variables),
+ "Data": VtlTemplate().render_vtl(request_parameters.get("Data"), variables),
+ "PartitionKey": VtlTemplate().render_vtl(
+ request_parameters.get("PartitionKey"), variables
+ ),
+ }
+ # Optional Parameters
+ if "ExplicitHashKey" in request_parameters:
+ payload["ExplicitHashKey"] = VtlTemplate().render_vtl(
+ request_parameters.get("ExplicitHashKey"), variables
+ )
+ if "SequenceNumberForOrdering" in request_parameters:
+ payload["SequenceNumberForOrdering"] = VtlTemplate().render_vtl(
+ request_parameters.get("SequenceNumberForOrdering"), variables
+ )
+ # TODO: XXX we don't support the Region parameter
+ # if "Region" in request_parameters:
+ # payload["Region"] = VtlTemplate().render_vtl(
+ # request_parameters.get("Region"), variables
+ # )
+ return json.dumps(payload)
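+
+ # Illustrative example: for an HTTP API AWS_PROXY integration with integrationSubtype "Kinesis-PutRecord"
+ # and requestParameters like {"StreamName": "my-stream", "PartitionKey": "pk-1", "Data": "$request.body.data"},
+ # each value is rendered as a VTL expression against the request, so the static values pass through
+ # unchanged while "Data" is expected to resolve to the "data" field of the JSON request body.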
+
+
+class DynamoDBIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext):
+ # TODO we might want to do it plain http instead of using boto here, like kinesis
+ integration = invocation_context.integration
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
+
+ # example: arn:aws:apigateway:us-east-1:dynamodb:action/PutItem&Table=MusicCollection
+ action = uri.split(":dynamodb:action/")[1].split("&")[0]
+
+ # render request template
+ payload = self.request_templates.render(invocation_context)
+ payload = json.loads(payload)
+
+ # determine target method via reflection
+ clients = get_service_factory(
+ region_name=invocation_context.region_name,
+ role_arn=invocation_context.integration.get("credentials"),
+ )
+ dynamo_client = clients.dynamodb.request_metadata(
+ service_principal=ServicePrincipal.apigateway,
+ source_arn=get_source_arn(invocation_context),
+ )
+ method_name = camel_to_snake_case(action)
+ client_method = getattr(dynamo_client, method_name, None)
+ if not client_method:
+ raise Exception(f"Unsupported action {action} in API Gateway integration URI {uri}")
+
+ # run request against DynamoDB backend
+ try:
+ response = client_method(**payload)
+ except ClientError as e:
+ response = e.response
+ # The request body is packed into the "Error" field. To make the response match AWS, we will remove that
+ # field and merge with the response dict
+ error = response.pop("Error", {})
+ error.pop("Code", None) # the Code is also something not relayed
+ response |= error
+
+ status_code = response.get("ResponseMetadata", {}).get("HTTPStatusCode", 200)
+ # apply response templates
+ response_content = json.dumps(remove_attributes(response, ["ResponseMetadata"]))
+ response_obj = requests_response(content=response_content)
+ response = self.response_templates.render(invocation_context, response=response_obj)
+
+ # construct final response
+ # TODO: set response header based on response templates
+ headers = {HEADER_CONTENT_TYPE: APPLICATION_JSON}
+ response = requests_response(response, headers=headers, status_code=status_code)
+
+ return response
+
+
+class S3Integration(BackendIntegration):
+ # target ARN patterns
+ TARGET_REGEX_PATH_S3_URI = rf"{ARN_PARTITION_REGEX}:apigateway:[a-zA-Z0-9\-]+:s3:path/(?P<bucket>[^/]+)/(?P<object>.+)$"
+ TARGET_REGEX_ACTION_S3_URI = rf"{ARN_PARTITION_REGEX}:apigateway:[a-zA-Z0-9\-]+:s3:action/(?:GetObject&Bucket\=(?P<bucket>[^&]+)&Key\=(?P<object>.+))$"
+
+ def invoke(self, invocation_context: ApiInvocationContext):
+ invocation_path = invocation_context.path_with_query_string
+ integration = invocation_context.integration
+ path_params = invocation_context.path_params
+ relative_path, query_string_params = extract_query_string_params(path=invocation_path)
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
+
+ s3 = connect_to().s3
+ uri = apply_request_parameters(
+ uri,
+ integration=integration,
+ path_params=path_params,
+ query_params=query_string_params,
+ )
+ uri_match = re.match(self.TARGET_REGEX_PATH_S3_URI, uri) or re.match(
+ self.TARGET_REGEX_ACTION_S3_URI, uri
+ )
+ if not uri_match:
+ msg = "Request URI does not match s3 specifications"
+ LOG.warning(msg)
+ return make_error_response(msg, 400)
+
+ bucket, object_key = uri_match.group("bucket", "object")
+ LOG.debug("Getting request for bucket %s object %s", bucket, object_key)
+
+ action = None
+ invoke_args = {"Bucket": bucket, "Key": object_key}
+ match invocation_context.method:
+ case HTTPMethod.GET:
+ action = s3.get_object
+ case HTTPMethod.PUT:
+ invoke_args["Body"] = invocation_context.data
+ action = s3.put_object
+ case HTTPMethod.DELETE:
+ action = s3.delete_object
+ case _:
+ return make_error_response(
+ "The specified method is not allowed against this resource.", 405
+ )
+
+ try:
+ object = action(**invoke_args)
+ except s3.exceptions.NoSuchKey:
+ msg = f"Object {object_key} not found"
+ LOG.debug(msg)
+ return make_error_response(msg, 404)
+
+ headers = mock_aws_request_headers(
+ service="s3",
+ aws_access_key_id=invocation_context.account_id,
+ region_name=invocation_context.region_name,
+ )
+
+ if object.get("ContentType"):
+ headers["Content-Type"] = object["ContentType"]
+
+ # stream used so large files do not fill memory
+ if body := object.get("Body"):
+ response = request_response_stream(stream=body, headers=headers)
+ else:
+ response = requests_response(content="", headers=headers)
+ return response
+
+
+class HTTPIntegration(BackendIntegration):
+ @staticmethod
+ def _set_http_apigw_headers(headers: Dict[str, Any], invocation_context: ApiInvocationContext):
+ del headers["host"]
+ headers["x-amzn-apigateway-api-id"] = invocation_context.api_id
+ return headers
+
+ def invoke(self, invocation_context: ApiInvocationContext):
+ invocation_path = invocation_context.path_with_query_string
+ integration = invocation_context.integration
+ path_params = invocation_context.path_params
+ method = invocation_context.method
+ headers = invocation_context.headers
+
+ relative_path, query_string_params = extract_query_string_params(path=invocation_path)
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
+
+ # resolve integration parameters
+ integration_parameters = self.request_params_resolver.resolve(context=invocation_context)
+ headers.update(integration_parameters.get("headers", {}))
+ self._set_http_apigw_headers(headers, invocation_context)
+
+ if ":servicediscovery:" in uri:
+ # check if this is a servicediscovery integration URI
+ client = connect_to().servicediscovery
+ service_id = uri.split("/")[-1]
+ instances = client.list_instances(ServiceId=service_id)["Instances"]
+ instance = (instances or [None])[0]
+ if instance and instance.get("Id"):
+ uri = "http://%s/%s" % (instance["Id"], invocation_path.lstrip("/"))
+
+ # apply custom request template
+ invocation_context.context = get_event_request_context(invocation_context)
+ invocation_context.stage_variables = get_stage_variables(invocation_context)
+ payload = self.request_templates.render(invocation_context)
+
+ if isinstance(payload, dict):
+ payload = json.dumps(payload)
+
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/aws-api-gateway-stage-variables-reference.html
+ # HTTP integration URIs
+ #
+ # A stage variable can be used as part of an HTTP integration URL, as shown in the following examples:
+ #
+ # A full URI without protocol – http://${stageVariables.<variable_name>}
+ # A full domain – http://${stageVariables.<variable_name>}/resource/operation
+ # A subdomain – http://${stageVariables.<variable_name>}.example.com/resource/operation
+ # A path – http://example.com/${stageVariables.<variable_name>}/bar
+ # A query string – http://example.com/foo?q=${stageVariables.<variable_name>}
+ render_vars = {"stageVariables": invocation_context.stage_variables}
+ rendered_uri = VtlTemplate().render_vtl(uri, render_vars)
+
+ uri = apply_request_parameters(
+ rendered_uri,
+ integration=integration,
+ path_params=path_params,
+ query_params=query_string_params,
+ )
+ result = requests.request(method=method, url=uri, data=payload, headers=headers)
+ if not result.ok:
+ LOG.debug(
+ "Upstream response from <%s> %s returned with status code: %s",
+ method,
+ uri,
+ result.status_code,
+ )
+ # apply custom response template for non-proxy integration
+ invocation_context.response = result
+ if integration["type"] != "HTTP_PROXY":
+ self.response_templates.render(invocation_context)
+ return invocation_context.response
+
+
+class SQSIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext):
+ integration = invocation_context.integration
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
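+ # the integration URI is expected to look like arn:aws:apigateway:<region>:sqs:path/<account_id>/<queue_name>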
+ account_id, queue = uri.split("/")[-2:]
+ region_name = uri.split(":")[3]
+
+ headers = get_internal_mocked_headers(
+ service_name="sqs",
+ region_name=region_name,
+ role_arn=invocation_context.integration.get("credentials"),
+ source_arn=get_source_arn(invocation_context),
+ )
+
+ # integration parameters can override headers
+ integration_parameters = self.request_params_resolver.resolve(context=invocation_context)
+ headers.update(integration_parameters.get("headers", {}))
+ if "Accept" not in headers:
+ headers["Accept"] = "application/json"
+
+ if invocation_context.is_websocket_request():
+ template_key = self.render_template_selection_expression(invocation_context)
+ payload = self.request_templates.render(invocation_context, template_key)
+ else:
+ payload = self.request_templates.render(invocation_context)
+
+ # not sure what the purpose of this is, but it's in the original code
+ # TODO: check if this is still needed
+ if "GetQueueUrl" in payload or "CreateQueue" in payload:
+ new_request = f"{payload}&QueueName={queue}"
+ else:
+ queue_url = f"{config.internal_service_url()}/queue/{region_name}/{account_id}/{queue}"
+ new_request = f"{payload}&QueueUrl={queue_url}"
+
+ url = urljoin(config.internal_service_url(), f"/queue/{region_name}/{account_id}/{queue}")
+ response = common.make_http_request(url, method="POST", headers=headers, data=new_request)
+
+ # apply response template
+ invocation_context.response = response
+ response._content = self.response_templates.render(invocation_context)
+ return response
+
+
+class SNSIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext) -> Response:
+ # TODO: check if the logic below is accurate - cover with snapshot tests!
+ invocation_context.context = get_event_request_context(invocation_context)
+ invocation_context.stage_variables = get_stage_variables(invocation_context)
+ integration = invocation_context.integration
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
+
+ try:
+ if invocation_context.is_websocket_request():
+ template_key = self.render_template_selection_expression(invocation_context)
+ payload = self.request_templates.render(invocation_context, template_key)
+ else:
+ payload = self.request_templates.render(invocation_context)
+ except Exception as e:
+ LOG.warning("Failed to apply template for SNS integration", e)
+ raise
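+ # the target region is taken from the colon-separated integration URI, e.g. arn:aws:apigateway:us-east-1:sns:path/... -> "us-east-1"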
+ region_name = uri.split(":")[3]
+ headers = mock_aws_request_headers(
+ service="sns", aws_access_key_id=invocation_context.account_id, region_name=region_name
+ )
+ response = make_http_request(
+ config.internal_service_url(), method="POST", headers=headers, data=payload
+ )
+
+ invocation_context.response = response
+ response._content = self.response_templates.render(invocation_context)
+ return self.apply_response_parameters(invocation_context, response)
+
+
+class StepFunctionIntegration(BackendIntegration):
+ @classmethod
+ def _validate_required_params(cls, request_parameters: Dict[str, Any]) -> None:
+ if not request_parameters:
+ raise BadRequestException("Missing required parameters")
+ # stateMachineArn and input are required
+ state_machine_arn_param = request_parameters.get("StateMachineArn")
+ input_param = request_parameters.get("Input")
+
+ if not state_machine_arn_param:
+ raise BadRequestException("StateMachineArn")
+
+ if not input_param:
+ raise BadRequestException("Input")
+
+ def invoke(self, invocation_context: ApiInvocationContext):
+ uri = (
+ invocation_context.integration.get("uri")
+ or invocation_context.integration.get("integrationUri")
+ or ""
+ )
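+ # e.g. a URI like arn:aws:apigateway:us-east-1:states:action/StartExecution yields the action "StartExecution"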
+ action = uri.split("/")[-1]
+
+ if invocation_context.integration.get("IntegrationType") == "AWS_PROXY":
+ payload = self._create_request_parameters(invocation_context)
+ elif APPLICATION_JSON in invocation_context.integration.get("requestTemplates", {}):
+ payload = self.request_templates.render(invocation_context)
+ payload = json.loads(payload)
+ else:
+ payload = json.loads(invocation_context.data)
+
+ client = get_service_factory(
+ region_name=invocation_context.region_name,
+ role_arn=invocation_context.integration.get("credentials"),
+ ).stepfunctions
+
+ if isinstance(payload.get("input"), dict):
+ payload["input"] = json.dumps(payload["input"])
+
+ # Hot fix: the Step Functions Local package responds with "Unsupported Operation: 'StartSyncExecution'",
+ # so we fall back to start_execution and poll for the execution result below
+ method_name = (
+ camel_to_snake_case(action) if action != "StartSyncExecution" else "start_execution"
+ )
+
+ try:
+ # call method on step function client
+ method = getattr(client, method_name)
+ except AttributeError:
+ msg = f"Invalid step function action: {method_name}"
+ LOG.error(msg)
+ return StepFunctionIntegration._create_response(
+ HTTPStatus.BAD_REQUEST.value,
+ headers={"Content-Type": APPLICATION_JSON},
+ data=json.dumps({"message": msg}),
+ )
+
+ result = method(**payload)
+ result = json_safe(remove_attributes(result, ["ResponseMetadata"]))
+ response = StepFunctionIntegration._create_response(
+ HTTPStatus.OK.value,
+ mock_aws_request_headers(
+ "stepfunctions",
+ aws_access_key_id=invocation_context.account_id,
+ region_name=invocation_context.region_name,
+ ),
+ data=json.dumps(result),
+ )
+ if action == "StartSyncExecution":
+ # poll for the execution result and return it
+ result = await_sfn_execution_result(result["executionArn"])
+ result_status = result.get("status")
+ if result_status != "SUCCEEDED":
+ return StepFunctionIntegration._create_response(
+ HTTPStatus.INTERNAL_SERVER_ERROR.value,
+ headers={"Content-Type": APPLICATION_JSON},
+ data=json.dumps(
+ {
+ "message": "StepFunctions execution %s failed with status '%s'"
+ % (result["executionArn"], result_status)
+ }
+ ),
+ )
+
+ result = json_safe(result)
+ response = requests_response(content=result)
+
+ # apply response templates
+ invocation_context.response = response
+ response._content = self.response_templates.render(invocation_context)
+ return response
+
+ def _create_request_parameters(self, invocation_context):
+ request_parameters = invocation_context.integration.get("requestParameters", {})
+ self._validate_required_params(request_parameters)
+
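+ # expose the incoming request (headers, querystring, body, context, stage variables) to the VTL template rendering the execution input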
+ variables = {
+ "request": {
+ "header": invocation_context.headers,
+ "querystring": invocation_context.query_params(),
+ "body": invocation_context.data_as_string(),
+ "context": invocation_context.context or {},
+ "stage_variables": invocation_context.stage_variables or {},
+ }
+ }
+ rendered_input = VtlTemplate().render_vtl(request_parameters.get("Input"), variables)
+ return {
+ "stateMachineArn": request_parameters.get("StateMachineArn"),
+ "input": rendered_input,
+ }
+
+
+class MockIntegration(BackendIntegration):
+ @classmethod
+ def check_passthrough_behavior(cls, passthrough_behavior: str, request_template: str):
+ return MappingTemplates(passthrough_behavior).check_passthrough_behavior(request_template)
+
+ def invoke(self, invocation_context: ApiInvocationContext) -> Response:
+ passthrough_behavior = invocation_context.integration.get("passthroughBehavior") or ""
+ request_template = invocation_context.integration.get("requestTemplates", {}).get(
+ invocation_context.headers.get(HEADER_CONTENT_TYPE, APPLICATION_JSON)
+ )
+
+ # based on the configured passthrough behavior and whether a request template exists,
+ # we either proceed with calling the integration or raise an exception.
+ try:
+ self.check_passthrough_behavior(passthrough_behavior, request_template)
+ except MappingTemplates.UnsupportedMediaType:
+ return MockIntegration._create_response(
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value,
+ headers={"Content-Type": APPLICATION_JSON},
+ data=json.dumps({"message": f"{HTTPStatus.UNSUPPORTED_MEDIA_TYPE.phrase}"}),
+ )
+
+ # request template rendering
+ request_payload = self.request_templates.render(invocation_context)
+
+ # the status code is mapped from the "statusCode" field of the rendered payload; we default to 200
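+ # e.g. a request template that renders to {"statusCode": 204} overrides the default below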
+ status_code = 200
+ if invocation_context.headers.get(HEADER_CONTENT_TYPE) == APPLICATION_JSON:
+ try:
+ mock_response = json.loads(request_payload)
+ status_code = mock_response.get("statusCode", status_code)
+ except Exception as e:
+ LOG.warning("failed to deserialize request payload after transformation: %s", e)
+ http_status = HTTPStatus(500)
+ return MockIntegration._create_response(
+ http_status.value,
+ headers={"Content-Type": APPLICATION_JSON},
+ data=json.dumps({"message": f"{http_status.phrase}"}),
+ )
+
+ # response template
+ response = MockIntegration._create_response(
+ status_code, invocation_context.headers, data=request_payload
+ )
+ response._content = self.response_templates.render(invocation_context, response=response)
+ # apply response parameters
+ response = self.apply_response_parameters(invocation_context, response)
+ if not invocation_context.headers.get(HEADER_CONTENT_TYPE):
+ invocation_context.headers.update({HEADER_CONTENT_TYPE: APPLICATION_JSON})
+ return response
+
+
+# TODO: remove once we migrate all usages to `apply_request_parameters` on BackendIntegration
+def apply_request_parameters(
+ uri: str, integration: Dict[str, Any], path_params: Dict[str, str], query_params: Dict[str, str]
+):
+ request_parameters = integration.get("requestParameters")
+ uri = uri or integration.get("uri") or integration.get("integrationUri") or ""
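+ # e.g. a mapping {"integration.request.path.petId": "method.request.path.petId"} replaces the
+ # "{petId}" placeholder in the integration URI with the value of the incoming path parameter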
+ if request_parameters:
+ for key in path_params:
+ # check if path_params is present in the integration request parameters
+ request_param_key = f"integration.request.path.{key}"
+ request_param_value = f"method.request.path.{key}"
+ if request_parameters.get(request_param_key) == request_param_value:
+ uri = uri.replace(f"{{{key}}}", path_params[key])
+
+ if integration.get("type") != "HTTP_PROXY" and request_parameters:
+ for key in query_params.copy():
+ request_query_key = f"integration.request.querystring.{key}"
+ request_param_val = f"method.request.querystring.{key}"
+ if request_parameters.get(request_query_key, None) != request_param_val:
+ query_params.pop(key)
+
+ return add_query_params_to_url(uri, query_params)
+
+
+class EventBridgeIntegration(BackendIntegration):
+ def invoke(self, invocation_context: ApiInvocationContext):
+ invocation_context.context = get_event_request_context(invocation_context)
+ try:
+ payload = self.request_templates.render(invocation_context)
+ except Exception as e:
+ LOG.warning("Failed to apply template for EventBridge integration: %s", e)
+ raise
+ uri = (
+ invocation_context.integration.get("uri")
+ or invocation_context.integration.get("integrationUri")
+ or ""
+ )
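+ # the target region is embedded in the integration URI, e.g. arn:aws:apigateway:eu-central-1:events:action/PutEvents -> "eu-central-1"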
+ region_name = uri.split(":")[3]
+ headers = get_internal_mocked_headers(
+ service_name="events",
+ region_name=region_name,
+ role_arn=invocation_context.integration.get("credentials"),
+ source_arn=get_source_arn(invocation_context),
+ )
+ headers.update({"X-Amz-Target": invocation_context.headers.get("X-Amz-Target")})
+ response = make_http_request(
+ config.internal_service_url(), method="POST", headers=headers, data=payload
+ )
+
+ invocation_context.response = response
+
+ self.response_templates.render(invocation_context)
+ invocation_context.response.headers["Content-Length"] = str(len(response.content or ""))
+ return invocation_context.response
diff --git a/localstack-core/localstack/services/apigateway/legacy/invocations.py b/localstack-core/localstack/services/apigateway/legacy/invocations.py
new file mode 100644
index 0000000000000..18085fc52e22e
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/invocations.py
@@ -0,0 +1,400 @@
+import json
+import logging
+import re
+
+from jsonschema import ValidationError, validate
+from requests.models import Response
+from werkzeug.exceptions import NotFound
+
+from localstack.aws.connect import connect_to
+from localstack.constants import APPLICATION_JSON
+from localstack.services.apigateway.helpers import (
+ EMPTY_MODEL,
+ ModelResolver,
+ get_apigateway_store_for_invocation,
+)
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.services.apigateway.legacy.helpers import (
+ get_cors_response,
+ get_event_request_context,
+ get_target_resource_details,
+ make_error_response,
+ set_api_id_stage_invocation_path,
+)
+from localstack.services.apigateway.legacy.integration import (
+ ApiGatewayIntegrationError,
+ DynamoDBIntegration,
+ EventBridgeIntegration,
+ HTTPIntegration,
+ KinesisIntegration,
+ LambdaIntegration,
+ LambdaProxyIntegration,
+ MockIntegration,
+ S3Integration,
+ SNSIntegration,
+ SQSIntegration,
+ StepFunctionIntegration,
+)
+from localstack.services.apigateway.models import ApiGatewayStore
+from localstack.utils.aws.arns import ARN_PARTITION_REGEX
+from localstack.utils.aws.aws_responses import requests_response
+
+LOG = logging.getLogger(__name__)
+
+
+class AuthorizationError(Exception):
+ message: str
+ status_code: int
+
+ def __init__(self, message: str, status_code: int):
+ super().__init__(message)
+ self.message = message
+ self.status_code = status_code
+
+ def to_response(self):
+ return requests_response({"message": self.message}, status_code=self.status_code)
+
+
+# we separate those 2 exceptions to allow better GatewayResponse support later on
+class BadRequestParameters(Exception):
+ message: str
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+ def to_response(self):
+ return requests_response({"message": self.message}, status_code=400)
+
+
+class BadRequestBody(Exception):
+ message: str
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+ def to_response(self):
+ return requests_response({"message": self.message}, status_code=400)
+
+
+class RequestValidator:
+ __slots__ = ["context", "rest_api_container"]
+
+ def __init__(self, context: ApiInvocationContext, store: ApiGatewayStore = None):
+ self.context = context
+ store = store or get_apigateway_store_for_invocation(context=context)
+ if not (container := store.rest_apis.get(context.api_id)):
+ # TODO: find the right exception
+ raise NotFound()
+ self.rest_api_container = container
+
+ def validate_request(self) -> None:
+ """
+ :raises BadRequestParameters: if the request is missing required request parameters
+ :raises BadRequestBody: if body validation against a model is required and the request body does not conform to it
+ :return: None
+ """
+ # make all the positive checks first
+ if self.context.resource is None or "resourceMethods" not in self.context.resource:
+ return
+
+ resource_methods = self.context.resource["resourceMethods"]
+ if self.context.method not in resource_methods and "ANY" not in resource_methods:
+ return
+
+ # check if a validator is configured for the resource method
+ resource = resource_methods.get(self.context.method, resource_methods.get("ANY", {}))
+ if not (resource.get("requestValidatorId") or "").strip():
+ return
+
+ # check if there is a validator for this request
+ validator = self.rest_api_container.validators.get(resource["requestValidatorId"])
+ if not validator:
+ return
+
+ if self.should_validate_request(validator) and (
+ missing_parameters := self._get_missing_required_parameters(resource)
+ ):
+ message = f"Missing required request parameters: [{', '.join(missing_parameters)}]"
+ raise BadRequestParameters(message=message)
+
+ if self.should_validate_body(validator) and not self._is_body_valid(resource):
+ raise BadRequestBody(message="Invalid request body")
+
+ return
+
+ def _is_body_valid(self, resource) -> bool:
+ # if there's no model to validate the body, use the Empty model
+ # https://docs.aws.amazon.com/cdk/api/v1/docs/@aws-cdk_aws-apigateway.EmptyModel.html
+ if not (request_models := resource.get("requestModels")):
+ model_name = EMPTY_MODEL
+ else:
+ model_name = request_models.get(
+ APPLICATION_JSON, request_models.get("$default", EMPTY_MODEL)
+ )
+
+ model_resolver = ModelResolver(
+ rest_api_container=self.rest_api_container,
+ model_name=model_name,
+ )
+
+ # try to get the resolved model first
+ resolved_schema = model_resolver.get_resolved_model()
+ if not resolved_schema:
+ LOG.exception(
+ "An exception occurred while trying to validate the request: could not find the model"
+ )
+ return False
+
+ try:
+ # if the body is empty, replace it with an empty JSON body
+ validate(
+ instance=json.loads(self.context.data or "{}"),
+ schema=resolved_schema,
+ )
+ return True
+ except ValidationError as e:
+ LOG.warning("failed to validate request body %s", e)
+ return False
+ except json.JSONDecodeError as e:
+ LOG.warning("failed to validate request body, request data is not valid JSON %s", e)
+ return False
+
+ def _get_missing_required_parameters(self, resource) -> list[str]:
+ missing_params = []
+ if not (request_parameters := resource.get("requestParameters")):
+ return missing_params
+
+ for request_parameter, required in sorted(request_parameters.items()):
+ if not required:
+ continue
+
+ param_type, param_value = request_parameter.removeprefix("method.request.").split(".")
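+ # e.g. "method.request.header.X-Api-Key" yields param_type "header" and param_value "X-Api-Key"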
+ match param_type:
+ case "header":
+ is_missing = param_value not in self.context.headers
+ case "path":
+ is_missing = param_value not in self.context.resource_path
+ case "querystring":
+ is_missing = param_value not in self.context.query_params()
+ case _:
+ # TODO: method.request.body is not specified in the documentation, and requestModels should do it
+ # verify this
+ is_missing = False
+
+ if is_missing:
+ missing_params.append(param_value)
+
+ return missing_params
+
+ @staticmethod
+ def should_validate_body(validator):
+ return validator["validateRequestBody"]
+
+ @staticmethod
+ def should_validate_request(validator):
+ return validator.get("validateRequestParameters")
+
+
+# ------------
+# API METHODS
+# ------------
+
+
+def validate_api_key(api_key: str, invocation_context: ApiInvocationContext):
+ usage_plan_ids = []
+ client = connect_to(
+ aws_access_key_id=invocation_context.account_id, region_name=invocation_context.region_name
+ ).apigateway
+
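+ # collect the IDs of all usage plans attached to the API stage being invoked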
+ usage_plans = client.get_usage_plans()
+ for item in usage_plans.get("items", []):
+ api_stages = item.get("apiStages", [])
+ usage_plan_ids.extend(
+ item.get("id")
+ for api_stage in api_stages
+ if (
+ api_stage.get("stage") == invocation_context.stage
+ and api_stage.get("apiId") == invocation_context.api_id
+ )
+ )
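+ # the key is valid only if it belongs to one of those usage plans and is enabled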
+ for usage_plan_id in usage_plan_ids:
+ usage_plan_keys = client.get_usage_plan_keys(usagePlanId=usage_plan_id)
+ for key in usage_plan_keys.get("items", []):
+ if key.get("value") == api_key:
+ # check if the key is enabled
+ api_key = client.get_api_key(apiKey=key.get("id"))
+ return api_key.get("enabled") in ("true", True)
+
+ return False
+
+
+def is_api_key_valid(invocation_context: ApiInvocationContext) -> bool:
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-api-key-source.html
+ client = connect_to(
+ aws_access_key_id=invocation_context.account_id, region_name=invocation_context.region_name
+ ).apigateway
+ rest_api = client.get_rest_api(restApiId=invocation_context.api_id)
+
+ # The source of the API key for metering requests according to a usage plan.
+ # Valid values are:
+ # - HEADER to read the API key from the X-API-Key header of a request.
+ # - AUTHORIZER to read the API key from the UsageIdentifierKey from a custom authorizer.
+
+ api_key_source = rest_api.get("apiKeySource")
+ match api_key_source:
+ case "HEADER":
+ api_key = invocation_context.headers.get("X-API-Key")
+ return validate_api_key(api_key, invocation_context) if api_key else False
+ case "AUTHORIZER":
+ api_key = invocation_context.auth_identity.get("apiKey")
+ return validate_api_key(api_key, invocation_context) if api_key else False
+
+
+def update_content_length(response: Response):
+ if response and response.content is not None:
+ response.headers["Content-Length"] = str(len(response.content))
+
+
+def invoke_rest_api_from_request(invocation_context: ApiInvocationContext):
+ set_api_id_stage_invocation_path(invocation_context)
+ try:
+ return invoke_rest_api(invocation_context)
+ except AuthorizationError as e:
+ LOG.warning(
+ "Authorization error while invoking API Gateway ID %s: %s",
+ invocation_context.api_id,
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+ return e.to_response()
+
+
+def invoke_rest_api(invocation_context: ApiInvocationContext):
+ invocation_path = invocation_context.path_with_query_string
+ raw_path = invocation_context.path or invocation_path
+ method = invocation_context.method
+ headers = invocation_context.headers
+
+ extracted_path, resource = get_target_resource_details(invocation_context)
+ if not resource:
+ return make_error_response("Unable to find path %s" % invocation_context.path, 404)
+
+ # validate request
+ validator = RequestValidator(invocation_context)
+ try:
+ validator.validate_request()
+ except (BadRequestParameters, BadRequestBody) as e:
+ return e.to_response()
+
+ api_key_required = resource.get("resourceMethods", {}).get(method, {}).get("apiKeyRequired")
+ if api_key_required and not is_api_key_valid(invocation_context):
+ raise AuthorizationError("Forbidden", 403)
+
+ resource_methods = resource.get("resourceMethods", {})
+ resource_method = resource_methods.get(method, {})
+ if not resource_method:
+ # HttpMethod: '*'
+ # ResourcePath: '/*' - produces 'X-AMAZON-APIGATEWAY-ANY-METHOD'
+ resource_method = resource_methods.get("ANY", {}) or resource_methods.get(
+ "X-AMAZON-APIGATEWAY-ANY-METHOD", {}
+ )
+ method_integration = resource_method.get("methodIntegration")
+ if not method_integration:
+ if method == "OPTIONS" and "Origin" in headers:
+ # default to returning CORS headers if this is an OPTIONS request
+ return get_cors_response(headers)
+ return make_error_response(
+ "Unable to find integration for: %s %s (%s)" % (method, invocation_path, raw_path),
+ 404,
+ )
+
+ # update fields in invocation context, then forward request to next handler
+ invocation_context.resource_path = extracted_path
+ invocation_context.integration = method_integration
+
+ return invoke_rest_api_integration(invocation_context)
+
+
+def invoke_rest_api_integration(invocation_context: ApiInvocationContext):
+ try:
+ response = invoke_rest_api_integration_backend(invocation_context)
+ # TODO remove this setter once all the integrations are migrated to the new response
+ # handling
+ invocation_context.response = response
+ return response
+ except ApiGatewayIntegrationError as e:
+ LOG.warning(
+ "Error while invoking integration for ApiGateway ID %s: %s",
+ invocation_context.api_id,
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+ return e.to_response()
+ except Exception as e:
+ msg = f"Error invoking integration for API Gateway ID '{invocation_context.api_id}': {e}"
+ LOG.exception(msg)
+ return make_error_response(msg, 400)
+
+
+# This function is patched downstream for backend integrations that are only available
+# in Pro (potentially to be replaced with a runtime hook in the future).
+def invoke_rest_api_integration_backend(invocation_context: ApiInvocationContext):
+ # define local aliases from invocation context
+ method = invocation_context.method
+ headers = invocation_context.headers
+ integration = invocation_context.integration
+ integration_type_orig = integration.get("type") or integration.get("integrationType") or ""
+ integration_type = integration_type_orig.upper()
+ integration_method = integration.get("httpMethod")
+ uri = integration.get("uri") or integration.get("integrationUri") or ""
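+ # Lambda integrations typically use a URI like arn:aws:apigateway:<region>:lambda:path/2015-03-31/functions/<function_arn>/invocations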
+
+ if (re.match(f"{ARN_PARTITION_REGEX}:apigateway:", uri) and ":lambda:path" in uri) or re.match(
+ f"{ARN_PARTITION_REGEX}:lambda", uri
+ ):
+ invocation_context.context = get_event_request_context(invocation_context)
+ if integration_type == "AWS_PROXY":
+ return LambdaProxyIntegration().invoke(invocation_context)
+ elif integration_type == "AWS":
+ return LambdaIntegration().invoke(invocation_context)
+
+ elif integration_type == "AWS":
+ if "kinesis:action/" in uri:
+ return KinesisIntegration().invoke(invocation_context)
+
+ if "states:action/" in uri:
+ return StepFunctionIntegration().invoke(invocation_context)
+
+ if ":dynamodb:action" in uri:
+ return DynamoDBIntegration().invoke(invocation_context)
+
+ if "s3:path/" in uri or "s3:action/" in uri:
+ return S3Integration().invoke(invocation_context)
+
+ if integration_method == "POST" and ":sqs:path" in uri:
+ return SQSIntegration().invoke(invocation_context)
+
+ if method == "POST" and ":sns:path" in uri:
+ return SNSIntegration().invoke(invocation_context)
+
+ if (
+ method == "POST"
+ and re.match(f"{ARN_PARTITION_REGEX}:apigateway:", uri)
+ and "events:action/PutEvents" in uri
+ ):
+ return EventBridgeIntegration().invoke(invocation_context)
+
+ elif integration_type in ["HTTP_PROXY", "HTTP"]:
+ return HTTPIntegration().invoke(invocation_context)
+
+ elif integration_type == "MOCK":
+ return MockIntegration().invoke(invocation_context)
+
+ if method == "OPTIONS":
+ # fall back to returning CORS headers if this is an OPTIONS request
+ return get_cors_response(headers)
+
+ raise Exception(
+ f'API Gateway integration type "{integration_type}", method "{method}", URI "{uri}" not yet implemented'
+ )
diff --git a/localstack-core/localstack/services/apigateway/legacy/provider.py b/localstack-core/localstack/services/apigateway/legacy/provider.py
new file mode 100644
index 0000000000000..25ff91ddfedc5
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/provider.py
@@ -0,0 +1,3017 @@
+import copy
+import io
+import json
+import logging
+import re
+from copy import deepcopy
+from datetime import datetime
+from typing import IO, Any
+
+from moto.apigateway import models as apigw_models
+from moto.apigateway.models import Resource as MotoResource
+from moto.apigateway.models import RestAPI as MotoRestAPI
+from moto.core.utils import camelcase_to_underscores
+
+from localstack.aws.api import CommonServiceException, RequestContext, ServiceRequest, handler
+from localstack.aws.api.apigateway import (
+ Account,
+ ApigatewayApi,
+ ApiKey,
+ ApiKeys,
+ Authorizer,
+ Authorizers,
+ BadRequestException,
+ BasePathMapping,
+ BasePathMappings,
+ Blob,
+ Boolean,
+ ClientCertificate,
+ ClientCertificates,
+ ConflictException,
+ ConnectionType,
+ CreateAuthorizerRequest,
+ CreateRestApiRequest,
+ CreateStageRequest,
+ Deployment,
+ DocumentationPart,
+ DocumentationPartIds,
+ DocumentationPartLocation,
+ DocumentationParts,
+ DocumentationVersion,
+ DocumentationVersions,
+ DomainName,
+ DomainNames,
+ DomainNameStatus,
+ EndpointConfiguration,
+ ExportResponse,
+ GatewayResponse,
+ GatewayResponses,
+ GatewayResponseType,
+ GetDocumentationPartsRequest,
+ Integration,
+ IntegrationResponse,
+ IntegrationType,
+ ListOfApiStage,
+ ListOfPatchOperation,
+ ListOfStageKeys,
+ ListOfString,
+ MapOfStringToBoolean,
+ MapOfStringToString,
+ Method,
+ MethodResponse,
+ Model,
+ Models,
+ MutualTlsAuthenticationInput,
+ NotFoundException,
+ NullableBoolean,
+ NullableInteger,
+ PutIntegrationRequest,
+ PutIntegrationResponseRequest,
+ PutMode,
+ PutRestApiRequest,
+ QuotaSettings,
+ RequestValidator,
+ RequestValidators,
+ Resource,
+ ResourceOwner,
+ RestApi,
+ RestApis,
+ SecurityPolicy,
+ Stage,
+ Stages,
+ StatusCode,
+ String,
+ Tags,
+ TestInvokeMethodRequest,
+ TestInvokeMethodResponse,
+ ThrottleSettings,
+ UsagePlan,
+ UsagePlanKeys,
+ UsagePlans,
+ VpcLink,
+ VpcLinks,
+)
+from localstack.aws.connect import connect_to
+from localstack.aws.forwarder import NotImplementedAvoidFallbackError, create_aws_request_context
+from localstack.constants import APPLICATION_JSON
+from localstack.services.apigateway.exporter import OpenApiExporter
+from localstack.services.apigateway.helpers import (
+ EMPTY_MODEL,
+ ERROR_MODEL,
+ OpenAPIExt,
+ apply_json_patch_safe,
+ get_apigateway_store,
+ get_moto_backend,
+ get_moto_rest_api,
+ get_regional_domain_name,
+ get_rest_api_container,
+ import_api_from_openapi_spec,
+ is_greedy_path,
+ is_variable_path,
+ log_template,
+ resolve_references,
+)
+from localstack.services.apigateway.legacy.helpers import multi_value_dict_for_list
+from localstack.services.apigateway.legacy.invocations import invoke_rest_api_from_request
+from localstack.services.apigateway.legacy.router_asf import ApigatewayRouter, to_invocation_context
+from localstack.services.apigateway.models import ApiGatewayStore, RestApiContainer
+from localstack.services.apigateway.next_gen.execute_api.router import (
+ ApiGatewayRouter as ApiGatewayRouterNextGen,
+)
+from localstack.services.apigateway.patches import apply_patches
+from localstack.services.edge import ROUTER
+from localstack.services.moto import call_moto, call_moto_with_request
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.aws.arns import get_partition
+from localstack.utils.collections import (
+ DelSafeDict,
+ PaginatedList,
+ ensure_list,
+ select_from_typed_dict,
+)
+from localstack.utils.json import parse_json_or_yaml
+from localstack.utils.strings import md5, short_uid, str_to_bool, to_bytes, to_str
+from localstack.utils.time import TIMESTAMP_FORMAT_TZ, now_utc, timestamp
+
+LOG = logging.getLogger(__name__)
+
+# list of valid paths for Stage update patch operations (extracted from AWS responses via snapshot tests)
+STAGE_UPDATE_PATHS = [
+ "/deploymentId",
+ "/description",
+ "/cacheClusterEnabled",
+ "/cacheClusterSize",
+ "/clientCertificateId",
+ "/accessLogSettings",
+ "/accessLogSettings/destinationArn",
+ "/accessLogSettings/format",
+ "/{resourcePath}/{httpMethod}/metrics/enabled",
+ "/{resourcePath}/{httpMethod}/logging/dataTrace",
+ "/{resourcePath}/{httpMethod}/logging/loglevel",
+ "/{resourcePath}/{httpMethod}/throttling/burstLimit",
+ "/{resourcePath}/{httpMethod}/throttling/rateLimit",
+ "/{resourcePath}/{httpMethod}/caching/ttlInSeconds",
+ "/{resourcePath}/{httpMethod}/caching/enabled",
+ "/{resourcePath}/{httpMethod}/caching/dataEncrypted",
+ "/{resourcePath}/{httpMethod}/caching/requireAuthorizationForCacheControl",
+ "/{resourcePath}/{httpMethod}/caching/unauthorizedCacheControlHeaderStrategy",
+ "/*/*/metrics/enabled",
+ "/*/*/logging/dataTrace",
+ "/*/*/logging/loglevel",
+ "/*/*/throttling/burstLimit",
+ "/*/*/throttling/rateLimit",
+ "/*/*/caching/ttlInSeconds",
+ "/*/*/caching/enabled",
+ "/*/*/caching/dataEncrypted",
+ "/*/*/caching/requireAuthorizationForCacheControl",
+ "/*/*/caching/unauthorizedCacheControlHeaderStrategy",
+ "/variables/{variable_name}",
+ "/tracingEnabled",
+]
+
+VALID_INTEGRATION_TYPES = {
+ IntegrationType.AWS,
+ IntegrationType.AWS_PROXY,
+ IntegrationType.HTTP,
+ IntegrationType.HTTP_PROXY,
+ IntegrationType.MOCK,
+}
+
+
+class ApigatewayProvider(ApigatewayApi, ServiceLifecycleHook):
+ router: ApigatewayRouter | ApiGatewayRouterNextGen
+
+ def __init__(self, router: ApigatewayRouter | ApiGatewayRouterNextGen = None):
+ self.router = router or ApigatewayRouter(ROUTER)
+
+ def on_after_init(self):
+ apply_patches()
+ self.router.register_routes()
+
+ @handler("TestInvokeMethod", expand=False)
+ def test_invoke_method(
+ self, context: RequestContext, request: TestInvokeMethodRequest
+ ) -> TestInvokeMethodResponse:
+ invocation_context = to_invocation_context(context.request)
+ invocation_context.method = request.get("httpMethod")
+ invocation_context.api_id = request.get("restApiId")
+ invocation_context.path_with_query_string = request.get("pathWithQueryString")
+ invocation_context.region_name = context.region
+ invocation_context.account_id = context.account_id
+
+ moto_rest_api = get_moto_rest_api(context=context, rest_api_id=invocation_context.api_id)
+ resource = moto_rest_api.resources.get(request["resourceId"])
+ if not resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ invocation_context.resource = {"id": resource.id}
+ invocation_context.resource_path = resource.path_part
+
+ if data := parse_json_or_yaml(to_str(invocation_context.data or b"")):
+ invocation_context.data = data.get("body")
+ invocation_context.headers = data.get("headers", {})
+
+ req_start_time = datetime.now()
+ result = invoke_rest_api_from_request(invocation_context)
+ req_end_time = datetime.now()
+
+ # TODO: add the missing fields to the log. Next iteration will add helpers to extract the missing fields
+ # from the API invocation context
+ log = log_template(
+ request_id=invocation_context.context["requestId"],
+ date=req_start_time,
+ http_method=invocation_context.method,
+ resource_path=invocation_context.invocation_path,
+ request_path="",
+ query_string="",
+ request_headers="",
+ request_body="",
+ response_body="",
+ response_headers=result.headers,
+ status_code=result.status_code,
+ )
+ return TestInvokeMethodResponse(
+ status=result.status_code,
+ headers=dict(result.headers),
+ body=to_str(result.content),
+ log=log,
+ latency=int((req_end_time - req_start_time).total_seconds()),
+ multiValueHeaders=multi_value_dict_for_list(result.headers),
+ )
+
+ @handler("CreateRestApi", expand=False)
+ def create_rest_api(self, context: RequestContext, request: CreateRestApiRequest) -> RestApi:
+ if request.get("description") == "":
+ raise BadRequestException("Description cannot be an empty string")
+
+ minimum_compression_size = request.get("minimumCompressionSize")
+ if minimum_compression_size is not None and (
+ minimum_compression_size < 0 or minimum_compression_size > 10485760
+ ):
+ raise BadRequestException(
+ "Invalid minimum compression size, must be between 0 and 10485760"
+ )
+
+ result = call_moto(context)
+ rest_api = get_moto_rest_api(context, rest_api_id=result["id"])
+ rest_api.version = request.get("version")
+ response: RestApi = rest_api.to_dict()
+ remove_empty_attributes_from_rest_api(response)
+ store = get_apigateway_store(context=context)
+ rest_api_container = RestApiContainer(rest_api=response)
+ store.rest_apis[result["id"]] = rest_api_container
+ # add the 2 default models
+ rest_api_container.models[EMPTY_MODEL] = DEFAULT_EMPTY_MODEL
+ rest_api_container.models[ERROR_MODEL] = DEFAULT_ERROR_MODEL
+
+ return response
+
+ def create_api_key(
+ self,
+ context: RequestContext,
+ name: String = None,
+ description: String = None,
+ enabled: Boolean = None,
+ generate_distinct_id: Boolean = None,
+ value: String = None,
+ stage_keys: ListOfStageKeys = None,
+ customer_id: String = None,
+ tags: MapOfStringToString = None,
+ **kwargs,
+ ) -> ApiKey:
+ api_key = call_moto(context)
+
+ # transform array of stage keys [{'restApiId': '0iscapk09u', 'stageName': 'dev'}] into
+ # array of strings ['0iscapk09u/dev']
+ stage_keys = api_key.get("stageKeys", [])
+ api_key["stageKeys"] = [f"{sk['restApiId']}/{sk['stageName']}" for sk in stage_keys]
+
+ return api_key
+
+ def get_rest_api(self, context: RequestContext, rest_api_id: String, **kwargs) -> RestApi:
+ rest_api: RestApi = call_moto(context)
+ remove_empty_attributes_from_rest_api(rest_api)
+ return rest_api
+
+ def update_rest_api(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> RestApi:
+ rest_api = get_moto_rest_api(context, rest_api_id)
+
+ fixed_patch_ops = []
+ binary_media_types_path = "/binaryMediaTypes"
+ # TODO: validate a bit more patch operations
+ for patch_op in patch_operations:
+ patch_op_path = patch_op.get("path", "")
+ # binaryMediaTypes has a specific way of being set
+ # see https://docs.aws.amazon.com/apigateway/latest/api/API_PatchOperation.html
+ # TODO: maybe implement a more generalized way if this happens anywhere else
+ if patch_op_path.startswith(binary_media_types_path):
+ if patch_op_path == binary_media_types_path:
+ raise BadRequestException(f"Invalid patch path {patch_op_path}")
+ value = patch_op_path.rsplit("/", maxsplit=1)[-1]
+ path_value = value.replace("~1", "/")
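+ # e.g. a patch path "/binaryMediaTypes/image~1png" yields the media type "image/png"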
+ patch_op["path"] = binary_media_types_path
+
+ if patch_op["op"] == "add":
+ patch_op["value"] = path_value
+
+ elif patch_op["op"] == "remove":
+ remove_index = rest_api.binaryMediaTypes.index(path_value)
+ patch_op["path"] = f"{binary_media_types_path}/{remove_index}"
+
+ elif patch_op["op"] == "replace":
+ # AWS is behaving weirdly, and will actually remove/add instead of replacing in place
+ # it will put the replaced value last in the array
+ replace_index = rest_api.binaryMediaTypes.index(path_value)
+ fixed_patch_ops.append(
+ {"op": "remove", "path": f"{binary_media_types_path}/{replace_index}"}
+ )
+ patch_op["op"] = "add"
+
+ elif patch_op_path == "/minimumCompressionSize":
+ if patch_op["op"] != "replace":
+ raise BadRequestException(
+ "Invalid patch operation specified. Must be one of: [replace]"
+ )
+
+ try:
+ # try to cast the value to integer if truthy, else reject
+ value = int(val) if (val := patch_op.get("value")) else None
+ except ValueError:
+ raise BadRequestException(
+ "Invalid minimum compression size, must be between 0 and 10485760"
+ )
+
+ if value is not None and (value < 0 or value > 10485760):
+ raise BadRequestException(
+ "Invalid minimum compression size, must be between 0 and 10485760"
+ )
+ patch_op["value"] = value
+
+ fixed_patch_ops.append(patch_op)
+
+ _patch_api_gateway_entity(rest_api, fixed_patch_ops)
+
+ # fix data types after patches have been applied
+ endpoint_configs = rest_api.endpoint_configuration or {}
+ if isinstance(endpoint_configs.get("vpcEndpointIds"), str):
+ endpoint_configs["vpcEndpointIds"] = [endpoint_configs["vpcEndpointIds"]]
+
+ # minimum_compression_size is a unique path as it's a nullable integer,
+ # it would throw an error if it stays an empty string
+ if rest_api.minimum_compression_size == "":
+ rest_api.minimum_compression_size = None
+
+ response = rest_api.to_dict()
+
+ remove_empty_attributes_from_rest_api(response, remove_tags=False)
+ store = get_apigateway_store(context=context)
+ store.rest_apis[rest_api_id].rest_api = response
+ return response
+
+ @handler("PutRestApi", expand=False)
+ def put_rest_api(self, context: RequestContext, request: PutRestApiRequest) -> RestApi:
+ # TODO: take into account the mode: overwrite or merge
+ # the default is now `merge`, but we are removing everything
+ rest_api = get_moto_rest_api(context, request["restApiId"])
+ rest_api, warnings = import_api_from_openapi_spec(
+ rest_api, context=context, request=request
+ )
+
+ rest_api.root_resource_id = get_moto_rest_api_root_resource(rest_api)
+ response = rest_api.to_dict()
+ remove_empty_attributes_from_rest_api(response)
+ store = get_apigateway_store(context=context)
+ store.rest_apis[request["restApiId"]].rest_api = response
+ # TODO: verify this
+ response = to_rest_api_response_json(response)
+ response.setdefault("tags", {})
+
+ # TODO Failing still keeps all applied mutations. We need to revert to the previous state instead
+ if warnings:
+ response["warnings"] = warnings
+
+ return response
+
+ @handler("CreateDomainName")
+ def create_domain_name(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ certificate_name: String = None,
+ certificate_body: String = None,
+ certificate_private_key: String = None,
+ certificate_chain: String = None,
+ certificate_arn: String = None,
+ regional_certificate_name: String = None,
+ regional_certificate_arn: String = None,
+ endpoint_configuration: EndpointConfiguration = None,
+ tags: MapOfStringToString = None,
+ security_policy: SecurityPolicy = None,
+ mutual_tls_authentication: MutualTlsAuthenticationInput = None,
+ ownership_verification_certificate_arn: String = None,
+ policy: String = None,
+ **kwargs,
+ ) -> DomainName:
+ if not domain_name:
+ raise BadRequestException("No Domain Name specified")
+
+ store: ApiGatewayStore = get_apigateway_store(context=context)
+ if store.domain_names.get(domain_name):
+ raise ConflictException(f"Domain name with ID {domain_name} already exists")
+
+ # find matching hosted zone
+ zone_id = None
+ # TODO check if this call is IAM enforced
+ route53 = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).route53
+ hosted_zones = route53.list_hosted_zones().get("HostedZones", [])
+ hosted_zones = [hz for hz in hosted_zones if domain_name.endswith(hz["Name"].strip("."))]
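+ # use the first hosted zone whose name is a suffix of the requested domain name, if any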
+ zone_id = hosted_zones[0]["Id"].replace("/hostedzone/", "") if hosted_zones else zone_id
+
+ domain: DomainName = DomainName(
+ domainName=domain_name,
+ certificateName=certificate_name,
+ certificateArn=certificate_arn,
+ regionalDomainName=get_regional_domain_name(domain_name),
+ domainNameStatus=DomainNameStatus.AVAILABLE,
+ regionalHostedZoneId=zone_id,
+ regionalCertificateName=regional_certificate_name,
+ regionalCertificateArn=regional_certificate_arn,
+ securityPolicy=SecurityPolicy.TLS_1_2,
+ endpointConfiguration=endpoint_configuration,
+ )
+ store.domain_names[domain_name] = domain
+ return domain
+
+ @handler("GetDomainName")
+ def get_domain_name(
+ self, context: RequestContext, domain_name: String, domain_name_id: String = None, **kwargs
+ ) -> DomainName:
+ store: ApiGatewayStore = get_apigateway_store(context=context)
+ if domain := store.domain_names.get(domain_name):
+ return domain
+ raise NotFoundException("Invalid domain name identifier specified")
+
+ @handler("GetDomainNames")
+ def get_domain_names(
+ self,
+ context: RequestContext,
+ position: String = None,
+ limit: NullableInteger = None,
+ resource_owner: ResourceOwner = None,
+ **kwargs,
+ ) -> DomainNames:
+ store = get_apigateway_store(context=context)
+ domain_names = store.domain_names.values()
+ return DomainNames(items=list(domain_names), position=position)
+
+ @handler("DeleteDomainName")
+ def delete_domain_name(
+ self, context: RequestContext, domain_name: String, domain_name_id: String = None, **kwargs
+ ) -> None:
+ store: ApiGatewayStore = get_apigateway_store(context=context)
+ if not store.domain_names.pop(domain_name, None):
+ raise NotFoundException("Invalid domain name identifier specified")
+
+ def delete_rest_api(self, context: RequestContext, rest_api_id: String, **kwargs) -> None:
+ try:
+ store = get_apigateway_store(context=context)
+ store.rest_apis.pop(rest_api_id, None)
+ call_moto(context)
+ except KeyError as e:
+ # moto raises a key error if we're trying to delete an API that doesn't exist
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ ) from e
+
+ def get_rest_apis(
+ self,
+ context: RequestContext,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> RestApis:
+ response: RestApis = call_moto(context)
+ for rest_api in response["items"]:
+ remove_empty_attributes_from_rest_api(rest_api)
+ return response
+
+ # resources
+
+ def create_resource(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ parent_id: String,
+ path_part: String,
+ **kwargs,
+ ) -> Resource:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ parent_moto_resource: MotoResource = moto_rest_api.resources.get(parent_id, None)
+ # validate here if the parent exists. Moto would first create then validate, which would lead to the resource
+ # being created anyway
+ if not parent_moto_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ parent_path = parent_moto_resource.path_part
+ if is_greedy_path(parent_path):
+ raise BadRequestException(
+ f"Cannot create a child of a resource with a greedy path variable: {parent_path}"
+ )
+
+ store = get_apigateway_store(context=context)
+ rest_api = store.rest_apis.get(rest_api_id)
+ children = rest_api.resource_children.setdefault(parent_id, [])
+
+ if is_variable_path(path_part):
+ for sibling in children:
+ sibling_resource: MotoResource = moto_rest_api.resources.get(sibling, None)
+ if is_variable_path(sibling_resource.path_part):
+ raise BadRequestException(
+ f"A sibling ({sibling_resource.path_part}) of this resource already has a variable path part -- only one is allowed"
+ )
+
+ response: Resource = call_moto(context)
+
+ # save children to allow easy deletion of all children if we delete a parent route
+ children.append(response["id"])
+
+ return response
+
+ def delete_resource(
+ self, context: RequestContext, rest_api_id: String, resource_id: String, **kwargs
+ ) -> None:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+
+ moto_resource: MotoResource = moto_rest_api.resources.pop(resource_id, None)
+ if not moto_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ store = get_apigateway_store(context=context)
+ rest_api = store.rest_apis.get(rest_api_id)
+ api_resources = rest_api.resource_children
+ # we need to recursively delete all children resources of the resource we're deleting
+
+ def _delete_children(resource_to_delete: str):
+ children = api_resources.get(resource_to_delete, [])
+ for child in children:
+ moto_rest_api.resources.pop(child)
+ _delete_children(child)
+
+ api_resources.pop(resource_to_delete, None)
+
+ _delete_children(resource_id)
+
+ # remove the resource as a child from its parent
+ parent_id = moto_resource.parent_id
+ api_resources[parent_id].remove(resource_id)
+
+ def update_integration_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ status_code: StatusCode,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> IntegrationResponse:
+ # XXX: THIS IS NOT A COMPLETE IMPLEMENTATION, just the minimum required to get tests going
+ # TODO: validate patch operations
+
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ moto_resource = moto_rest_api.resources.get(resource_id)
+ if not moto_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ moto_method = moto_resource.resource_methods.get(http_method)
+ if not moto_method:
+ raise NotFoundException("Invalid Method identifier specified")
+
+ integration_response = moto_method.method_integration.integration_responses.get(status_code)
+ if not integration_response:
+ raise NotFoundException("Invalid Integration Response identifier specified")
+
+ for patch_operation in patch_operations:
+ op = patch_operation.get("op")
+ path = patch_operation.get("path")
+
+ # for path "/responseTemplates/application~1json"
+ if "/responseTemplates" in path:
+ value = patch_operation.get("value")
+ if not isinstance(value, str):
+ raise BadRequestException(
+ f"Invalid patch value '{value}' specified for op '{op}'. Must be a string"
+ )
+ param = path.removeprefix("/responseTemplates/")
+ param = param.replace("~1", "/")
+ integration_response.response_templates.pop(param)
+
+ def update_resource(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Resource:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ moto_resource = moto_rest_api.resources.get(resource_id)
+ if not moto_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ store = get_apigateway_store(context=context)
+
+ rest_api = store.rest_apis.get(rest_api_id)
+ api_resources = rest_api.resource_children
+
+ future_path_part = moto_resource.path_part
+ current_parent_id = moto_resource.parent_id
+
+ for patch_operation in patch_operations:
+ op = patch_operation.get("op")
+ if (path := patch_operation.get("path")) not in ("/pathPart", "/parentId"):
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op '{op}'. Must be one of: [/parentId, /pathPart]"
+ )
+ if op != "replace":
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op '{op}'. Please choose supported operations"
+ )
+
+ if path == "/parentId":
+ value = patch_operation.get("value")
+ future_parent_resource = moto_rest_api.resources.get(value)
+ if not future_parent_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ children_resources = api_resources.get(resource_id, [])
+ if value in children_resources:
+ raise BadRequestException("Resources cannot be cyclical.")
+
+ new_sibling_resources = api_resources.get(value, [])
+
+ else: # path == "/pathPart"
+ future_path_part = patch_operation.get("value")
+ new_sibling_resources = api_resources.get(moto_resource.parent_id, [])
+
+ for sibling in new_sibling_resources:
+ sibling_resource = moto_rest_api.resources[sibling]
+ if sibling_resource.path_part == future_path_part:
+ raise ConflictException(
+ f"Another resource with the same parent already has this name: {future_path_part}"
+ )
+
+ # TODO: test with multiple patch operations which would not be compatible between each other
+ _patch_api_gateway_entity(moto_resource, patch_operations)
+
+ # after setting it, mutate the store
+ if moto_resource.parent_id != current_parent_id:
+ current_sibling_resources = api_resources.get(current_parent_id)
+ if current_sibling_resources:
+ current_sibling_resources.remove(resource_id)
+ # if the parent does not have children anymore, remove from the list
+ if not current_sibling_resources:
+ api_resources.pop(current_parent_id)
+
+ # add it to the new parent children
+ future_sibling_resources = api_resources[moto_resource.parent_id]
+ future_sibling_resources.append(resource_id)
+
+ response = moto_resource.to_dict()
+ return response
+
+ # resource method
+
+ def get_method(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ **kwargs,
+ ) -> Method:
+ response: Method = call_moto(context)
+ remove_empty_attributes_from_method(response)
+ if method_integration := response.get("methodIntegration"):
+ remove_empty_attributes_from_integration(method_integration)
+ # moto will not return `responseParameters` field if it's not truthy, but AWS will return an empty dict
+ # if it was set to an empty dict
+ if "responseParameters" not in method_integration:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ moto_resource = moto_rest_api.resources[resource_id]
+ moto_method_integration = moto_resource.resource_methods[
+ http_method
+ ].method_integration
+ if moto_method_integration.integration_responses:
+ for (
+ status_code,
+ integration_response,
+ ) in moto_method_integration.integration_responses.items():
+ if integration_response.response_parameters == {}:
+ method_integration["integrationResponses"][str(status_code)][
+ "responseParameters"
+ ] = {}
+
+ return response
+
+ def put_method(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ authorization_type: String,
+ authorizer_id: String = None,
+ api_key_required: Boolean = None,
+ operation_name: String = None,
+ request_parameters: MapOfStringToBoolean = None,
+ request_models: MapOfStringToString = None,
+ request_validator_id: String = None,
+ authorization_scopes: ListOfString = None,
+ **kwargs,
+ ) -> Method:
+ # TODO: add missing validation? check order of validation as well
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ moto_rest_api: MotoRestAPI = moto_backend.apis.get(rest_api_id)
+ if not moto_rest_api or not (moto_resource := moto_rest_api.resources.get(resource_id)):
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ if http_method not in ("GET", "PUT", "POST", "DELETE", "PATCH", "OPTIONS", "HEAD", "ANY"):
+ raise BadRequestException(
+ "Invalid HttpMethod specified. "
+ "Valid options are GET,PUT,POST,DELETE,PATCH,OPTIONS,HEAD,ANY"
+ )
+
+ if request_parameters:
+ request_parameters_names = {
+ name.rsplit(".", maxsplit=1)[-1] for name in request_parameters.keys()
+ }
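+ # e.g. "method.request.querystring.petId" and "method.request.header.petId" collide on the name "petId"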
+ if len(request_parameters_names) != len(request_parameters):
+ raise BadRequestException(
+ "Parameter names must be unique across querystring, header and path"
+ )
+ need_authorizer_id = authorization_type in ("CUSTOM", "COGNITO_USER_POOLS")
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis[rest_api_id]
+ if need_authorizer_id and (
+ not authorizer_id or authorizer_id not in rest_api_container.authorizers
+ ):
+ # TODO: will be cleaner with https://github.com/localstack/localstack/pull/7750
+ raise BadRequestException(
+ "Invalid authorizer ID specified. "
+ "Setting the authorization type to CUSTOM or COGNITO_USER_POOLS requires a valid authorizer."
+ )
+
+ if request_validator_id and request_validator_id not in rest_api_container.validators:
+ raise BadRequestException("Invalid Request Validator identifier specified")
+
+ if request_models:
+ for content_type, model_name in request_models.items():
+ # FIXME: add Empty model to rest api at creation
+ if model_name == EMPTY_MODEL:
+ continue
+ if model_name not in rest_api_container.models:
+ raise BadRequestException(f"Invalid model identifier specified: {model_name}")
+
+ response: Method = call_moto(context)
+ remove_empty_attributes_from_method(response)
+ moto_http_method = moto_resource.resource_methods[http_method]
+ moto_http_method.authorization_type = moto_http_method.authorization_type.upper()
+
+ # this is straight from the moto patch, did not test it yet but has the same functionality
+ # FIXME: check if still necessary after testing Authorizers
+ if need_authorizer_id and "authorizerId" not in response:
+ response["authorizerId"] = authorizer_id
+
+ response["authorizationType"] = response["authorizationType"].upper()
+
+ return response
+
+ def update_method(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Method:
+ # see https://www.linkedin.com/pulse/updating-aws-cli-patch-operations-rest-api-yitzchak-meirovich/
+ # for path construction
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ moto_rest_api: MotoRestAPI = moto_backend.apis.get(rest_api_id)
+ if not moto_rest_api or not (moto_resource := moto_rest_api.resources.get(resource_id)):
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ if not (moto_method := moto_resource.resource_methods.get(http_method)):
+ raise NotFoundException("Invalid Method identifier specified")
+ store = get_apigateway_store(context=context)
+ rest_api = store.rest_apis[rest_api_id]
+ applicable_patch_operations = []
+ modifying_auth_type = False
+ modified_authorizer_id = False
+ had_req_params = bool(moto_method.request_parameters)
+ had_req_models = bool(moto_method.request_models)
+
+ for patch_operation in patch_operations:
+ op = patch_operation.get("op")
+ path = patch_operation.get("path")
+ # if the path is not supported at all, raise an Exception
+ if len(path.split("/")) > 3 or not any(
+ path.startswith(s_path) for s_path in UPDATE_METHOD_PATCH_PATHS["supported_paths"]
+ ):
+ raise BadRequestException(f"Invalid patch path {path}")
+
+ # if the path is not supported by this specific operation, raise an error listing the operations that do support it
+ op_supported_path = UPDATE_METHOD_PATCH_PATHS.get(op, [])
+ if not any(path.startswith(s_path) for s_path in op_supported_path):
+ available_ops = [
+ available_op
+ for available_op in ("add", "replace", "delete")
+ if available_op != op
+ ]
+ supported_ops = ", ".join(
+ [
+ supported_op
+ for supported_op in available_ops
+ if any(
+ path.startswith(s_path)
+ for s_path in UPDATE_METHOD_PATCH_PATHS.get(supported_op, [])
+ )
+ ]
+ )
+ raise BadRequestException(
+ f"Invalid patch operation specified. Must be one of: [{supported_ops}]"
+ )
+
+ value = patch_operation.get("value")
+ if op not in ("add", "replace"):
+ # ops other than add/replace need no further value validation; apply them as-is
+ applicable_patch_operations.append(patch_operation)
+ continue
+
+ if path == "/authorizationType" and value in ("CUSTOM", "COGNITO_USER_POOLS"):
+ modifying_auth_type = True
+
+ elif path == "/authorizerId":
+ modified_authorizer_id = value
+
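+ # apiKeyRequired and requestParameters values are booleans but arrive as strings in the patch, so coerce them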
+ if any(
+ path.startswith(s_path) for s_path in ("/apiKeyRequired", "/requestParameters/")
+ ):
+ patch_op = {"op": op, "path": path, "value": str_to_bool(value)}
+ applicable_patch_operations.append(patch_op)
+ continue
+
+ elif path == "/requestValidatorId" and value not in rest_api.validators:
+ if not value:
+ # you can remove a requestValidator by passing an empty string as a value
+ patch_op = {"op": "remove", "path": path, "value": value}
+ applicable_patch_operations.append(patch_op)
+ continue
+ raise BadRequestException("Invalid Request Validator identifier specified")
+
+ elif path.startswith("/requestModels/"):
+ if value != EMPTY_MODEL and value not in rest_api.models:
+ raise BadRequestException(f"Invalid model identifier specified: {value}")
+
+ applicable_patch_operations.append(patch_operation)
+
+ if modifying_auth_type:
+ if not modified_authorizer_id or modified_authorizer_id not in rest_api.authorizers:
+ raise BadRequestException(
+ "Invalid authorizer ID specified. "
+ "Setting the authorization type to CUSTOM or COGNITO_USER_POOLS requires a valid authorizer."
+ )
+ elif modified_authorizer_id:
+ if moto_method.authorization_type not in ("CUSTOM", "COGNITO_USER_POOLS"):
+ # AWS will ignore this patch if the method does not have a proper authorization type
+ # filter the patches to remove the modified authorizerId
+ applicable_patch_operations = [
+ op for op in applicable_patch_operations if op.get("path") != "/authorizerId"
+ ]
+
+ # TODO: test with multiple patch operations which would not be compatible between each other
+ _patch_api_gateway_entity(moto_method, applicable_patch_operations)
+
+ # if we removed all values of those fields, set them to None so that they're not returned anymore
+ if had_req_params and len(moto_method.request_parameters) == 0:
+ moto_method.request_parameters = None
+ if had_req_models and len(moto_method.request_models) == 0:
+ moto_method.request_models = None
+
+ response = moto_method.to_json()
+ remove_empty_attributes_from_method(response)
+ remove_empty_attributes_from_integration(response.get("methodIntegration"))
+ return response
+
+ def delete_method(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ **kwargs,
+ ) -> None:
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ moto_rest_api: MotoRestAPI = moto_backend.apis.get(rest_api_id)
+ if not moto_rest_api or not (moto_resource := moto_rest_api.resources.get(resource_id)):
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ if not (moto_resource.resource_methods.get(http_method)):
+ raise NotFoundException("Invalid Method identifier specified")
+
+ call_moto(context)
+
+ # method responses
+
+ def get_method_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ status_code: StatusCode,
+ **kwargs,
+ ) -> MethodResponse:
+ # this could probably be easier in a patch?
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ moto_rest_api: MotoRestAPI = moto_backend.apis.get(rest_api_id)
+ # TODO: snapshot test different possibilities
+ if not moto_rest_api or not (moto_resource := moto_rest_api.resources.get(resource_id)):
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ if not (moto_method := moto_resource.resource_methods.get(http_method)):
+ raise NotFoundException("Invalid Method identifier specified")
+
+ if not (moto_method_response := moto_method.get_response(status_code)):
+ raise NotFoundException("Invalid Response status code specified")
+
+ method_response = moto_method_response.to_json()
+ return method_response
+
+ @handler("UpdateMethodResponse", expand=False)
+ def update_method_response(
+ self, context: RequestContext, request: TestInvokeMethodRequest
+ ) -> MethodResponse:
+ # this operation is not implemented by moto, which raises a 500 error (instead of a 501).
+ # avoid falling back to moto and return the 501 to the client directly instead.
+ raise NotImplementedAvoidFallbackError
+
+ # stages
+
+ # TODO: add createdDate / lastUpdatedDate in Stage operations below!
+ @handler("CreateStage", expand=False)
+ def create_stage(self, context: RequestContext, request: CreateStageRequest) -> Stage:
+ call_moto(context)
+ moto_api = get_moto_rest_api(context, rest_api_id=request["restApiId"])
+ stage = moto_api.stages.get(request["stageName"])
+ if not stage:
+ raise NotFoundException("Invalid Stage identifier specified")
+
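+ # moto's Stage object may not define documentation_version; set it from the request if missing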
+ if not hasattr(stage, "documentation_version"):
+ stage.documentation_version = request.get("documentationVersion")
+
+ # make sure we update the stage_name on the deployment entity in moto
+ deployment = moto_api.deployments.get(request["deploymentId"])
+ deployment.stage_name = stage.name
+
+ response = stage.to_json()
+ self._patch_stage_response(response)
+ return response
+
+ def get_stage(
+ self, context: RequestContext, rest_api_id: String, stage_name: String, **kwargs
+ ) -> Stage:
+ response = call_moto(context)
+ self._patch_stage_response(response)
+ return response
+
+ def get_stages(
+ self, context: RequestContext, rest_api_id: String, deployment_id: String = None, **kwargs
+ ) -> Stages:
+ response = call_moto(context)
+ for stage in response["item"]:
+ self._patch_stage_response(stage)
+ if not stage.get("description"):
+ stage.pop("description", None)
+ return Stages(**response)
+
+ @handler("UpdateStage")
+ def update_stage(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ stage_name: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Stage:
+ call_moto(context)
+
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ moto_rest_api: MotoRestAPI = moto_backend.apis.get(rest_api_id)
+ if not (moto_stage := moto_rest_api.stages.get(stage_name)):
+ raise NotFoundException("Invalid Stage identifier specified")
+
+ # construct list of path regexes for validation
+ path_regexes = [re.sub("{[^}]+}", ".+", path) for path in STAGE_UPDATE_PATHS]
+
+ # copy the patch operations to not mutate them, so that we're logging the correct input
+ patch_operations = copy.deepcopy(patch_operations) or []
+ for patch_operation in patch_operations:
+ patch_path = patch_operation["path"]
+
+ # special case: handle updates (op=remove) for wildcard method settings
+ patch_path_stripped = patch_path.strip("/")
+ if patch_path_stripped == "*/*" and patch_operation["op"] == "remove":
+ if not moto_stage.method_settings.pop(patch_path_stripped, None):
+ raise BadRequestException(
+ "Cannot remove method setting */* because there is no method setting for this method "
+ )
+ response = moto_stage.to_json()
+ self._patch_stage_response(response)
+ return response
+
+ path_valid = patch_path in STAGE_UPDATE_PATHS or any(
+ re.match(regex, patch_path) for regex in path_regexes
+ )
+ if not path_valid:
+ valid_paths = f"[{', '.join(STAGE_UPDATE_PATHS)}]"
+ # note: weird formatting in AWS - required for snapshot testing
+ valid_paths = valid_paths.replace(
+ "/{resourcePath}/{httpMethod}/throttling/burstLimit, /{resourcePath}/{httpMethod}/throttling/rateLimit, /{resourcePath}/{httpMethod}/caching/ttlInSeconds",
+ "/{resourcePath}/{httpMethod}/throttling/burstLimit/{resourcePath}/{httpMethod}/throttling/rateLimit/{resourcePath}/{httpMethod}/caching/ttlInSeconds",
+ )
+ valid_paths = valid_paths.replace("/burstLimit, /", "/burstLimit /")
+ valid_paths = valid_paths.replace("/rateLimit, /", "/rateLimit /")
+ raise BadRequestException(
+ f"Invalid method setting path: {patch_operation['path']}. Must be one of: {valid_paths}"
+ )
+
+ # TODO: check if there are other booleans; maybe add a global step in _patch_api_gateway_entity
+ if patch_path == "/tracingEnabled" and (value := patch_operation.get("value")):
+ patch_operation["value"] = value and value.lower() == "true" or False
+
+ _patch_api_gateway_entity(moto_stage, patch_operations)
+ moto_stage.apply_operations(patch_operations)
+
+ response = moto_stage.to_json()
+ self._patch_stage_response(response)
+ return response
+
+ def _patch_stage_response(self, response: dict):
+ """Apply a few patches required for AWS parity"""
+ response.setdefault("cacheClusterStatus", "NOT_AVAILABLE")
+ response.setdefault("tracingEnabled", False)
+ if not response.get("variables"):
+ response.pop("variables", None)
+
+ def update_deployment(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ deployment_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Deployment:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ try:
+ deployment = moto_rest_api.get_deployment(deployment_id)
+ except KeyError:
+ raise NotFoundException("Invalid Deployment identifier specified")
+
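+ # only the "replace" operation on /description is currently applied; other paths are ignored for now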
+ for patch_operation in patch_operations:
+ # TODO: add validation for unsupported paths
+ # see https://docs.aws.amazon.com/apigateway/latest/api/patch-operations.html#UpdateDeployment-Patch
+ if (
+ patch_operation.get("path") == "/description"
+ and patch_operation.get("op") == "replace"
+ ):
+ deployment.description = patch_operation["value"]
+
+ deployment_response: Deployment = deployment.to_json() or {}
+ return deployment_response
+
+ # authorizers
+
+ @handler("CreateAuthorizer", expand=False)
+ def create_authorizer(
+ self, context: RequestContext, request: CreateAuthorizerRequest
+ ) -> Authorizer:
+ # TODO: add validation
+ api_id = request["restApiId"]
+ store = get_apigateway_store(context=context)
+ if api_id not in store.rest_apis:
+ # this seems like a weird exception to throw, but we could not get AWS to return anything different
+ # we might need to revisit this
+ raise ConflictException(
+ "Unable to complete operation due to concurrent modification. Please try again later."
+ )
+
+ authorizer_id = short_uid()[:6] # length 6 to make TF tests pass
+ authorizer = deepcopy(select_from_typed_dict(Authorizer, request))
+ authorizer["id"] = authorizer_id
+ authorizer["authorizerResultTtlInSeconds"] = int(
+ authorizer.get("authorizerResultTtlInSeconds", 300)
+ )
+ store.rest_apis[api_id].authorizers[authorizer_id] = authorizer
+
+ response = to_authorizer_response_json(api_id, authorizer)
+ return response
+
+ def get_authorizers(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> Authorizers:
+ # TODO add paging, validation
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+ result = [
+ to_authorizer_response_json(rest_api_id, a)
+ for a in rest_api_container.authorizers.values()
+ ]
+ return Authorizers(items=result)
+
+ def get_authorizer(
+ self, context: RequestContext, rest_api_id: String, authorizer_id: String, **kwargs
+ ) -> Authorizer:
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ authorizer = (
+ rest_api_container.authorizers.get(authorizer_id) if rest_api_container else None
+ )
+
+ if authorizer is None:
+ raise NotFoundException(f"Authorizer not found: {authorizer_id}")
+ return to_authorizer_response_json(rest_api_id, authorizer)
+
+ def delete_authorizer(
+ self, context: RequestContext, rest_api_id: String, authorizer_id: String, **kwargs
+ ) -> None:
+ # TODO: add validation if authorizer does not exist
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ if rest_api_container:
+ rest_api_container.authorizers.pop(authorizer_id, None)
+
+ def update_authorizer(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ authorizer_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Authorizer:
+ # TODO: add validation
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ authorizer = (
+ rest_api_container.authorizers.get(authorizer_id) if rest_api_container else None
+ )
+
+ if authorizer is None:
+ raise NotFoundException(f"Authorizer not found: {authorizer_id}")
+
+ patched_authorizer = apply_json_patch_safe(authorizer, patch_operations)
+ # terraform sends this as a string in patch, so convert to int
+ patched_authorizer["authorizerResultTtlInSeconds"] = int(
+ patched_authorizer.get("authorizerResultTtlInSeconds", 300)
+ )
+
+ # store the updated Authorizer
+ rest_api_container.authorizers[authorizer_id] = patched_authorizer
+
+ result = to_authorizer_response_json(rest_api_id, patched_authorizer)
+ return result
+
+ # accounts
+
+ def get_account(self, context: RequestContext, **kwargs) -> Account:
+ region_details = get_apigateway_store(context=context)
+ result = to_account_response_json(region_details.account)
+ return Account(**result)
+
+ def update_account(
+ self, context: RequestContext, patch_operations: ListOfPatchOperation = None, **kwargs
+ ) -> Account:
+ region_details = get_apigateway_store(context=context)
+ apply_json_patch_safe(region_details.account, patch_operations, in_place=True)
+ result = to_account_response_json(region_details.account)
+ return Account(**result)
+
+ # documentation parts
+
+ def get_documentation_parts(
+ self, context: RequestContext, request: GetDocumentationPartsRequest, **kwargs
+ ) -> DocumentationParts:
+ # TODO: add validation
+ api_id = request["restApiId"]
+ rest_api_container = get_rest_api_container(context, rest_api_id=api_id)
+
+ result = [
+ to_documentation_part_response_json(api_id, a)
+ for a in rest_api_container.documentation_parts.values()
+ ]
+ return DocumentationParts(items=result)
+
+ def get_documentation_part(
+ self, context: RequestContext, rest_api_id: String, documentation_part_id: String, **kwargs
+ ) -> DocumentationPart:
+ # TODO: add validation
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ documentation_part = (
+ rest_api_container.documentation_parts.get(documentation_part_id)
+ if rest_api_container
+ else None
+ )
+
+ if documentation_part is None:
+ raise NotFoundException("Invalid Documentation part identifier specified")
+ return to_documentation_part_response_json(rest_api_id, documentation_part)
+
+ def create_documentation_part(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ location: DocumentationPartLocation,
+ properties: String,
+ **kwargs,
+ ) -> DocumentationPart:
+ entity_id = short_uid()[:6] # length 6 for AWS parity / Terraform compatibility
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ # TODO: add complete validation for the location parameter:
+ # https://docs.aws.amazon.com/apigateway/latest/api/API_DocumentationPartLocation.html
+ # for now we only validate "type"
+ location_type = location.get("type")
+ valid_location_types = [
+ "API",
+ "AUTHORIZER",
+ "MODEL",
+ "RESOURCE",
+ "METHOD",
+ "PATH_PARAMETER",
+ "QUERY_PARAMETER",
+ "REQUEST_HEADER",
+ "REQUEST_BODY",
+ "RESPONSE",
+ "RESPONSE_HEADER",
+ "RESPONSE_BODY",
+ ]
+ if location_type not in valid_location_types:
+ raise CommonServiceException(
+ "ValidationException",
+ f"1 validation error detected: Value '{location_type}' at "
+ f"'createDocumentationPartInput.location.type' failed to satisfy constraint: "
+ f"Member must satisfy enum value set: "
+ f"[RESPONSE_BODY, RESPONSE, METHOD, MODEL, AUTHORIZER, RESPONSE_HEADER, "
+ f"RESOURCE, PATH_PARAMETER, REQUEST_BODY, QUERY_PARAMETER, API, REQUEST_HEADER]",
+ )
+
+ doc_part = DocumentationPart(
+ id=entity_id,
+ location=location,
+ properties=properties,
+ )
+ rest_api_container.documentation_parts[entity_id] = doc_part
+
+ result = to_documentation_part_response_json(rest_api_id, doc_part)
+ return DocumentationPart(**result)
+
+ def update_documentation_part(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ documentation_part_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> DocumentationPart:
+ # TODO: add validation
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ doc_part = (
+ rest_api_container.documentation_parts.get(documentation_part_id)
+ if rest_api_container
+ else None
+ )
+
+ if doc_part is None:
+ raise NotFoundException("Invalid Documentation part identifier specified")
+
+ for patch_operation in patch_operations:
+ path = patch_operation.get("path")
+ operation = patch_operation.get("op")
+ if operation != "replace":
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op '{operation}'. "
+ f"Please choose supported operations"
+ )
+
+ if path != "/properties":
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op 'replace'. "
+ f"Must be one of: [/properties]"
+ )
+
+ key = path[1:]
+ if key == "properties" and not patch_operation.get("value"):
+ raise BadRequestException("Documentation part properties must be non-empty")
+
+ patched_doc_part = apply_json_patch_safe(doc_part, patch_operations)
+
+ rest_api_container.documentation_parts[documentation_part_id] = patched_doc_part
+
+ return to_documentation_part_response_json(rest_api_id, patched_doc_part)
+
+ def delete_documentation_part(
+ self, context: RequestContext, rest_api_id: String, documentation_part_id: String, **kwargs
+ ) -> None:
+ # TODO: add validation if document_part does not exist, or rest_api
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ documentation_part = rest_api_container.documentation_parts.get(documentation_part_id)
+
+ if documentation_part is None:
+ raise NotFoundException("Invalid Documentation part identifier specified")
+
+ if rest_api_container:
+ rest_api_container.documentation_parts.pop(documentation_part_id, None)
+
+ def import_documentation_parts(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ body: IO[Blob],
+ mode: PutMode = None,
+ fail_on_warnings: Boolean = None,
+ **kwargs,
+ ) -> DocumentationPartIds:
+ body_data = body.read()
+ openapi_spec = parse_json_or_yaml(to_str(body_data))
+
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-documenting-api-quick-start-import-export.html
+ resolved_schema = resolve_references(openapi_spec, rest_api_id=rest_api_id)
+ documentation = resolved_schema.get(OpenAPIExt.DOCUMENTATION)
+
+ ids = []
+ # overwrite mode
+ if mode == PutMode.overwrite:
+ rest_api_container.documentation_parts.clear()
+ for doc_part in documentation["documentationParts"]:
+ entity_id = short_uid()[:6]
+ rest_api_container.documentation_parts[entity_id] = DocumentationPart(
+ id=entity_id, **doc_part
+ )
+ ids.append(entity_id)
+ # TODO: implement the merge mode
+ return DocumentationPartIds(ids=ids)
+
+ # documentation versions
+
+ def create_documentation_version(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ documentation_version: String,
+ stage_name: String = None,
+ description: String = None,
+ **kwargs,
+ ) -> DocumentationVersion:
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ result = DocumentationVersion(
+ version=documentation_version, createdDate=datetime.now(), description=description
+ )
+ rest_api_container.documentation_versions[documentation_version] = result
+
+ return result
+
+ def get_documentation_version(
+ self, context: RequestContext, rest_api_id: String, documentation_version: String, **kwargs
+ ) -> DocumentationVersion:
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ result = rest_api_container.documentation_versions.get(documentation_version)
+ if not result:
+ raise NotFoundException(f"Documentation version not found: {documentation_version}")
+
+ return result
+
+ def get_documentation_versions(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> DocumentationVersions:
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+ result = list(rest_api_container.documentation_versions.values())
+ return DocumentationVersions(items=result)
+
+ def delete_documentation_version(
+ self, context: RequestContext, rest_api_id: String, documentation_version: String, **kwargs
+ ) -> None:
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ result = rest_api_container.documentation_versions.pop(documentation_version, None)
+ if not result:
+ raise NotFoundException(f"Documentation version not found: {documentation_version}")
+
+ def update_documentation_version(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ documentation_version: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> DocumentationVersion:
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+
+ result = rest_api_container.documentation_versions.get(documentation_version)
+ if not result:
+ raise NotFoundException(f"Documentation version not found: {documentation_version}")
+
+ _patch_api_gateway_entity(result, patch_operations)
+
+ return result
+
+ # base path mappings
+
+ def get_base_path_mappings(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ domain_name_id: String = None,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> BasePathMappings:
+ region_details = get_apigateway_store(context=context)
+
+ mappings_list = region_details.base_path_mappings.get(domain_name) or []
+
+ result = [
+ to_base_mapping_response_json(domain_name, m["basePath"], m) for m in mappings_list
+ ]
+ return BasePathMappings(items=result)
+
+ def get_base_path_mapping(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ base_path: String,
+ domain_name_id: String = None,
+ **kwargs,
+ ) -> BasePathMapping:
+ region_details = get_apigateway_store(context=context)
+
+ mappings_list = region_details.base_path_mappings.get(domain_name) or []
+ mapping = ([m for m in mappings_list if m["basePath"] == base_path] or [None])[0]
+ if mapping is None:
+ raise NotFoundException(f"Base path mapping not found: {domain_name} - {base_path}")
+
+ result = to_base_mapping_response_json(domain_name, base_path, mapping)
+ return BasePathMapping(**result)
+
+ def create_base_path_mapping(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ rest_api_id: String,
+ domain_name_id: String = None,
+ base_path: String = None,
+ stage: String = None,
+ **kwargs,
+ ) -> BasePathMapping:
+ region_details = get_apigateway_store(context=context)
+
+ # Note: "(none)" is a special value in API GW:
+ # https://docs.aws.amazon.com/apigateway/api-reference/link-relation/basepathmapping-by-base-path
+ base_path = base_path or "(none)"
+
+ entry = {
+ "domainName": domain_name,
+ "restApiId": rest_api_id,
+ "basePath": base_path,
+ "stage": stage,
+ }
+ region_details.base_path_mappings.setdefault(domain_name, []).append(entry)
+
+ result = to_base_mapping_response_json(domain_name, base_path, entry)
+ return BasePathMapping(**result)
+
+ def update_base_path_mapping(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ base_path: String,
+ domain_name_id: String = None,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> BasePathMapping:
+ region_details = get_apigateway_store(context=context)
+
+ mappings_list = region_details.base_path_mappings.get(domain_name) or []
+
+ mapping = ([m for m in mappings_list if m["basePath"] == base_path] or [None])[0]
+ if mapping is None:
+ raise NotFoundException(
+ f"Not found: mapping for domain name {domain_name}, "
+ f"base path {base_path} in list {mappings_list}"
+ )
+
+ patch_operations = ensure_list(patch_operations)
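+ # normalize the casing of the restApiId path before applying the patch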
+ for operation in patch_operations:
+ if operation["path"] == "/restapiId":
+ operation["path"] = "/restApiId"
+ result = apply_json_patch_safe(mapping, patch_operations)
+
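+ # write the patched mapping back into the list of mappings for this domain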
+ for i in range(len(mappings_list)):
+ if mappings_list[i]["basePath"] == base_path:
+ mappings_list[i] = result
+
+ result = to_base_mapping_response_json(domain_name, base_path, result)
+ return BasePathMapping(**result)
+
+ def delete_base_path_mapping(
+ self,
+ context: RequestContext,
+ domain_name: String,
+ base_path: String,
+ domain_name_id: String = None,
+ **kwargs,
+ ) -> None:
+ region_details = get_apigateway_store(context=context)
+
+ mappings_list = region_details.base_path_mappings.get(domain_name) or []
+ for i in range(len(mappings_list)):
+ if mappings_list[i]["basePath"] == base_path:
+ del mappings_list[i]
+ return
+
+ raise NotFoundException(f"Base path mapping {base_path} for domain {domain_name} not found")
+
+ # client certificates
+
+ def get_client_certificate(
+ self, context: RequestContext, client_certificate_id: String, **kwargs
+ ) -> ClientCertificate:
+ region_details = get_apigateway_store(context=context)
+ result = region_details.client_certificates.get(client_certificate_id)
+ if result is None:
+ raise NotFoundException(f"Client certificate ID {client_certificate_id} not found")
+ return ClientCertificate(**result)
+
+ def get_client_certificates(
+ self,
+ context: RequestContext,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> ClientCertificates:
+ region_details = get_apigateway_store(context=context)
+ result = list(region_details.client_certificates.values())
+ return ClientCertificates(items=result)
+
+ def generate_client_certificate(
+ self,
+ context: RequestContext,
+ description: String = None,
+ tags: MapOfStringToString = None,
+ **kwargs,
+ ) -> ClientCertificate:
+ region_details = get_apigateway_store(context=context)
+ cert_id = short_uid()
+ creation_time = now_utc()
+ entry = {
+ "description": description,
+ "tags": tags,
+ "clientCertificateId": cert_id,
+ "createdDate": creation_time,
+ "expirationDate": creation_time + 60 * 60 * 24 * 30, # assume 30 days validity
+ "pemEncodedCertificate": "testcert-123", # TODO return proper certificate!
+ }
+ region_details.client_certificates[cert_id] = entry
+ result = to_client_cert_response_json(entry)
+ return ClientCertificate(**result)
+
+ def update_client_certificate(
+ self,
+ context: RequestContext,
+ client_certificate_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> ClientCertificate:
+ region_details = get_apigateway_store(context=context)
+ entity = region_details.client_certificates.get(client_certificate_id)
+ if entity is None:
+ raise NotFoundException(f'Client certificate ID "{client_certificate_id}" not found')
+ result = apply_json_patch_safe(entity, patch_operations)
+ result = to_client_cert_response_json(result)
+ return ClientCertificate(**result)
+
+ def delete_client_certificate(
+ self, context: RequestContext, client_certificate_id: String, **kwargs
+ ) -> None:
+ region_details = get_apigateway_store(context=context)
+ entity = region_details.client_certificates.pop(client_certificate_id, None)
+ if entity is None:
+ raise NotFoundException(f'VPC link ID "{client_certificate_id}" not found for deletion')
+
+ # VPC links
+
+ def create_vpc_link(
+ self,
+ context: RequestContext,
+ name: String,
+ target_arns: ListOfString,
+ description: String = None,
+ tags: MapOfStringToString = None,
+ **kwargs,
+ ) -> VpcLink:
+ region_details = get_apigateway_store(context=context)
+ link_id = short_uid()
+ entry = {"id": link_id, "status": "AVAILABLE"}
+ region_details.vpc_links[link_id] = entry
+ result = to_vpc_link_response_json(entry)
+ return VpcLink(**result)
+
+ def get_vpc_links(
+ self,
+ context: RequestContext,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> VpcLinks:
+ region_details = get_apigateway_store(context=context)
+ result = region_details.vpc_links.values()
+ result = [to_vpc_link_response_json(r) for r in result]
+ result = {"items": result}
+ return result
+
+ def get_vpc_link(self, context: RequestContext, vpc_link_id: String, **kwargs) -> VpcLink:
+ region_details = get_apigateway_store(context=context)
+ vpc_link = region_details.vpc_links.get(vpc_link_id)
+ if vpc_link is None:
+ raise NotFoundException(f'VPC link ID "{vpc_link_id}" not found')
+ result = to_vpc_link_response_json(vpc_link)
+ return VpcLink(**result)
+
+ def update_vpc_link(
+ self,
+ context: RequestContext,
+ vpc_link_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> VpcLink:
+ region_details = get_apigateway_store(context=context)
+ vpc_link = region_details.vpc_links.get(vpc_link_id)
+ if vpc_link is None:
+ raise NotFoundException(f'VPC link ID "{vpc_link_id}" not found')
+ result = apply_json_patch_safe(vpc_link, patch_operations)
+ result = to_vpc_link_response_json(result)
+ return VpcLink(**result)
+
+ def delete_vpc_link(self, context: RequestContext, vpc_link_id: String, **kwargs) -> None:
+ region_details = get_apigateway_store(context=context)
+ vpc_link = region_details.vpc_links.pop(vpc_link_id, None)
+ if vpc_link is None:
+ raise NotFoundException(f'VPC link ID "{vpc_link_id}" not found for deletion')
+
+ # request validators
+
+ def get_request_validators(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> RequestValidators:
+ # TODO: add validation and pagination?
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ result = [
+ to_validator_response_json(rest_api_id, a)
+ for a in rest_api_container.validators.values()
+ ]
+ return RequestValidators(items=result)
+
+ def get_request_validator(
+ self, context: RequestContext, rest_api_id: String, request_validator_id: String, **kwargs
+ ) -> RequestValidator:
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ validator = (
+ rest_api_container.validators.get(request_validator_id) if rest_api_container else None
+ )
+
+ if validator is None:
+ raise NotFoundException("Invalid Request Validator identifier specified")
+
+ result = to_validator_response_json(rest_api_id, validator)
+ return result
+
+ def create_request_validator(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ name: String = None,
+ validate_request_body: Boolean = None,
+ validate_request_parameters: Boolean = None,
+ **kwargs,
+ ) -> RequestValidator:
+ # TODO: add validation (ex: name cannot be blank)
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise BadRequestException("Invalid REST API identifier specified")
+ # length 6 for AWS parity and TF compatibility
+ validator_id = short_uid()[:6]
+
+ validator = RequestValidator(
+ id=validator_id,
+ name=name,
+ validateRequestBody=validate_request_body or False,
+ validateRequestParameters=validate_request_parameters or False,
+ )
+
+ rest_api_container.validators[validator_id] = validator
+
+ # TODO: missing a call to to_validator_response_json here?
+ return validator
+
+ def update_request_validator(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ request_validator_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> RequestValidator:
+ # TODO: add validation
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ # TODO: validate the restAPI id to remove the conditional
+ validator = (
+ rest_api_container.validators.get(request_validator_id) if rest_api_container else None
+ )
+
+ if validator is None:
+ raise NotFoundException(
+ f"Validator {request_validator_id} for API Gateway {rest_api_id} not found"
+ )
+
+ for patch_operation in patch_operations:
+ path = patch_operation.get("path")
+ operation = patch_operation.get("op")
+ if operation != "replace":
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op '{operation}'. "
+ f"Please choose supported operations"
+ )
+ if path not in ("/name", "/validateRequestBody", "/validateRequestParameters"):
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op 'replace'. "
+ f"Must be one of: [/name, /validateRequestParameters, /validateRequestBody]"
+ )
+
+ key = path[1:]
+ value = patch_operation.get("value")
+ if key == "name" and not value:
+ raise BadRequestException("Request Validator name cannot be blank")
+
+ elif key in ("validateRequestParameters", "validateRequestBody"):
+ value = value and value.lower() == "true" or False
+
+ rest_api_container.validators[request_validator_id][key] = value
+
+ return to_validator_response_json(
+ rest_api_id, rest_api_container.validators[request_validator_id]
+ )
+
+ def delete_request_validator(
+ self, context: RequestContext, rest_api_id: String, request_validator_id: String, **kwargs
+ ) -> None:
+ # TODO: add validation if rest api does not exist
+ store = get_apigateway_store(context=context)
+ rest_api_container = store.rest_apis.get(rest_api_id)
+ if not rest_api_container:
+ raise NotFoundException("Invalid Request Validator identifier specified")
+
+ validator = rest_api_container.validators.pop(request_validator_id, None)
+ if not validator:
+ raise NotFoundException("Invalid Request Validator identifier specified")
+
+ # tags
+
+ def get_tags(
+ self,
+ context: RequestContext,
+ resource_arn: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> Tags:
+ result = get_apigateway_store(context=context).TAGS.get(resource_arn, {})
+ return Tags(tags=result)
+
+ def tag_resource(
+ self, context: RequestContext, resource_arn: String, tags: MapOfStringToString, **kwargs
+ ) -> None:
+ resource_tags = get_apigateway_store(context=context).TAGS.setdefault(resource_arn, {})
+ resource_tags.update(tags)
+
+ def untag_resource(
+ self, context: RequestContext, resource_arn: String, tag_keys: ListOfString, **kwargs
+ ) -> None:
+ resource_tags = get_apigateway_store(context=context).TAGS.setdefault(resource_arn, {})
+ for key in tag_keys:
+ resource_tags.pop(key, None)
+
+ def import_rest_api(
+ self,
+ context: RequestContext,
+ body: IO[Blob],
+ fail_on_warnings: Boolean = None,
+ parameters: MapOfStringToString = None,
+ **kwargs,
+ ) -> RestApi:
+ body_data = body.read()
+
+ # create rest api
+ openapi_spec = parse_json_or_yaml(to_str(body_data))
+ create_api_request = CreateRestApiRequest(name=openapi_spec.get("info").get("title"))
+ create_api_context = create_custom_context(
+ context,
+ "CreateRestApi",
+ create_api_request,
+ )
+ response = self.create_rest_api(create_api_context, create_api_request)
+ api_id = response.get("id")
+ # remove the 2 default models that are automatically created on API creation; they should not be present when importing
+ store = get_apigateway_store(context=context)
+ store.rest_apis[api_id].models = {}
+
+ # put rest api
+ put_api_request = PutRestApiRequest(
+ restApiId=api_id,
+ failOnWarnings=str_to_bool(fail_on_warnings) or False,
+ parameters=parameters or {},
+ body=io.BytesIO(body_data),
+ )
+ put_api_context = create_custom_context(
+ context,
+ "PutRestApi",
+ put_api_request,
+ )
+ put_api_response = self.put_rest_api(put_api_context, put_api_request)
+ if not put_api_response.get("tags"):
+ put_api_response.pop("tags", None)
+ return put_api_response
+
+ # integrations
+
+ def get_integration(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ **kwargs,
+ ) -> Integration:
+ try:
+ response: Integration = call_moto(context)
+ except CommonServiceException as e:
+ # the Exception raised by moto does not have the right message nor status code
+ if e.code == "NotFoundException":
+ raise NotFoundException("Invalid Integration identifier specified")
+ raise
+
+ if integration_responses := response.get("integrationResponses"):
+ for integration_response in integration_responses.values():
+ remove_empty_attributes_from_integration_response(integration_response)
+
+ return response
+
+ def put_integration(
+ self, context: RequestContext, request: PutIntegrationRequest, **kwargs
+ ) -> Integration:
+ if (integration_type := request.get("type")) not in VALID_INTEGRATION_TYPES:
+ raise CommonServiceException(
+ "ValidationException",
+ f"1 validation error detected: Value '{integration_type}' at "
+ f"'putIntegrationInput.type' failed to satisfy constraint: "
+ f"Member must satisfy enum value set: [HTTP, MOCK, AWS_PROXY, HTTP_PROXY, AWS]",
+ )
+
+ elif integration_type == IntegrationType.AWS_PROXY:
+ integration_uri = request.get("uri") or ""
+ if ":lambda:" not in integration_uri and ":firehose:" not in integration_uri:
+ raise BadRequestException(
+ "Integrations of type 'AWS_PROXY' currently only supports "
+ "Lambda function and Firehose stream invocations."
+ )
+ moto_rest_api = get_moto_rest_api(context=context, rest_api_id=request.get("restApiId"))
+ resource = moto_rest_api.resources.get(request.get("resourceId"))
+ if not resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ method = resource.resource_methods.get(request.get("httpMethod"))
+ if not method:
+ raise NotFoundException("Invalid Method identifier specified")
+
+ # TODO: if the IntegrationType is AWS, `credentials` is mandatory
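+ # set default values on the request copy before forwarding it to moto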
+ moto_request = copy.copy(request)
+ moto_request.setdefault("passthroughBehavior", "WHEN_NO_MATCH")
+ moto_request.setdefault("timeoutInMillis", 29000)
+ if integration_type in (IntegrationType.HTTP, IntegrationType.HTTP_PROXY):
+ moto_request.setdefault("connectionType", ConnectionType.INTERNET)
+ response = call_moto_with_request(context, moto_request)
+ remove_empty_attributes_from_integration(integration=response)
+
+ # TODO: should be fixed fundamentally once we move away from moto
+ if integration_type == "MOCK":
+ response.pop("uri", None)
+
+ return response
+
+ def update_integration(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Integration:
+ moto_rest_api = get_moto_rest_api(context=context, rest_api_id=rest_api_id)
+ resource = moto_rest_api.resources.get(resource_id)
+ if not resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ method = resource.resource_methods.get(http_method)
+ if not method:
+ raise NotFoundException("Invalid Integration identifier specified")
+
+ integration = method.method_integration
+ _patch_api_gateway_entity(integration, patch_operations)
+
+ # fix data types
+ if integration.timeout_in_millis:
+ integration.timeout_in_millis = int(integration.timeout_in_millis)
+ if skip_verification := (integration.tls_config or {}).get("insecureSkipVerification"):
+ integration.tls_config["insecureSkipVerification"] = str_to_bool(skip_verification)
+
+ integration_dict: Integration = integration.to_json()
+ return integration_dict
+
+ def delete_integration(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ **kwargs,
+ ) -> None:
+ try:
+ call_moto(context)
+ except Exception as e:
+ raise NotFoundException("Invalid Resource identifier specified") from e
+
+ # integration responses
+
+ def get_integration_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ resource_id: String,
+ http_method: String,
+ status_code: StatusCode,
+ **kwargs,
+ ) -> IntegrationResponse:
+ response: IntegrationResponse = call_moto(context)
+ remove_empty_attributes_from_integration_response(response)
+ # moto does not return selectionPattern if it is set to an empty string
+ # TODO: fix upstream
+ if "selectionPattern" not in response:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ moto_resource = moto_rest_api.resources.get(resource_id)
+ method_integration = moto_resource.resource_methods[http_method].method_integration
+ integration_response = method_integration.integration_responses[status_code]
+ if integration_response.selection_pattern is not None:
+ response["selectionPattern"] = integration_response.selection_pattern
+ return response
+
+ @handler("PutIntegrationResponse", expand=False)
+ def put_integration_response(
+ self,
+ context: RequestContext,
+ request: PutIntegrationResponseRequest,
+ ) -> IntegrationResponse:
+ moto_rest_api = get_moto_rest_api(context=context, rest_api_id=request.get("restApiId"))
+ moto_resource = moto_rest_api.resources.get(request.get("resourceId"))
+ if not moto_resource:
+ raise NotFoundException("Invalid Resource identifier specified")
+
+ method = moto_resource.resource_methods.get(request.get("httpMethod"))
+ if not method:
+ raise NotFoundException("Invalid Method identifier specified")
+
+ response = call_moto(context)
+ # Moto has a specific case where it will turn a None value into an empty dict, but AWS does not behave the same
+ if request.get("responseTemplates") is None:
+ method_integration = moto_resource.resource_methods[
+ request["httpMethod"]
+ ].method_integration
+ integration_response = method_integration.integration_responses[request["statusCode"]]
+ integration_response.response_templates = None
+ response.pop("responseTemplates", None)
+
+ # Moto also does not return the selection pattern if it is set to an empty string
+ # TODO: fix upstream
+ if (selection_pattern := request.get("selectionPattern")) is not None:
+ response["selectionPattern"] = selection_pattern
+
+ return response
+
+ def get_export(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ stage_name: String,
+ export_type: String,
+ parameters: MapOfStringToString = None,
+ accepts: String = None,
+ **kwargs,
+ ) -> ExportResponse:
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ openapi_exporter = OpenApiExporter()
+ # FIXME: look into the parser to see why `parameters` is always None
+ has_extension = context.request.values.get("extensions") == "apigateway"
+ result = openapi_exporter.export_api(
+ api_id=rest_api_id,
+ stage=stage_name,
+ export_type=export_type,
+ export_format=accepts,
+ with_extension=has_extension,
+ account_id=context.account_id,
+ region_name=context.region,
+ )
+
+ accepts = accepts or APPLICATION_JSON
+
+ if accepts == APPLICATION_JSON:
+ result = json.dumps(result, indent=2)
+
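+ # derive the file extension from the accepted content type and a version string for the export filename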
+ file_ext = accepts.split("/")[-1]
+ version = moto_rest_api.version or timestamp(
+ moto_rest_api.create_date, format=TIMESTAMP_FORMAT_TZ
+ )
+ return ExportResponse(
+ body=to_bytes(result),
+ contentType="application/octet-stream",
+ contentDisposition=f'attachment; filename="{export_type}_{version}.{file_ext}"',
+ )
+
+ def get_api_keys(
+ self,
+ context: RequestContext,
+ position: String = None,
+ limit: NullableInteger = None,
+ name_query: String = None,
+ customer_id: String = None,
+ include_values: NullableBoolean = None,
+ **kwargs,
+ ) -> ApiKeys:
+ # TODO: migrate API keys into our store
+ moto_backend = get_moto_backend(context.account_id, context.region)
+ api_keys = [api_key.to_json() for api_key in reversed(moto_backend.keys.values())]
+ if not include_values:
+ for api_key in api_keys:
+ api_key.pop("value")
+
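+ # paginate manually: the md5 of the item id serves as the pagination token, with an optional name prefix filter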
+ item_list = PaginatedList(api_keys)
+
+ def token_generator(item):
+ return md5(item["id"])
+
+ def filter_function(item):
+ return item["name"].startswith(name_query)
+
+ paginated_list, next_token = item_list.get_page(
+ token_generator=token_generator,
+ next_token=position,
+ page_size=limit,
+ filter_function=filter_function if name_query else None,
+ )
+
+ return ApiKeys(items=paginated_list, position=next_token)
+
+ def update_api_key(
+ self,
+ context: RequestContext,
+ api_key: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> ApiKey:
+ response: ApiKey = call_moto(context)
+ if "value" in response:
+ response.pop("value", None)
+
+ if "tags" not in response:
+ response["tags"] = {}
+
+ return response
+
+ def create_model(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ name: String,
+ content_type: String,
+ description: String = None,
+ schema: String = None,
+ **kwargs,
+ ) -> Model:
+ store = get_apigateway_store(context=context)
+ if rest_api_id not in store.rest_apis:
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if not name:
+ raise BadRequestException("Model name must be non-empty")
+
+ if name in store.rest_apis[rest_api_id].models:
+ raise ConflictException("Model name already exists for this REST API")
+
+ if not schema:
+ # TODO: maybe add more validation around the schema, valid json string?
+ raise BadRequestException(
+ "Model schema must have at least 1 property or array items defined"
+ )
+
+ model_id = short_uid()[:6] # length 6 to make TF tests pass
+ model = Model(
+ id=model_id, name=name, contentType=content_type, description=description, schema=schema
+ )
+ store.rest_apis[rest_api_id].models[name] = model
+ remove_empty_attributes_from_model(model)
+ return model
+
+ def get_models(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> Models:
+ store = get_apigateway_store(context=context)
+ if rest_api_id not in store.rest_apis:
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ models = [
+ remove_empty_attributes_from_model(model)
+ for model in store.rest_apis[rest_api_id].models.values()
+ ]
+ return Models(items=models)
+
+ def get_model(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ model_name: String,
+ flatten: Boolean = None,
+ **kwargs,
+ ) -> Model:
+ store = get_apigateway_store(context=context)
+ if rest_api_id not in store.rest_apis or not (
+ model := store.rest_apis[rest_api_id].models.get(model_name)
+ ):
+ raise NotFoundException(f"Invalid model name specified: {model_name}")
+
+ return model
+
+ def update_model(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ model_name: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Model:
+ # manually update the model; no need for JSON patch, as only 2 paths are supported with the replace operation:
+ # /schema
+ # /description
+ store = get_apigateway_store(context=context)
+ if rest_api_id not in store.rest_apis or not (
+ model := store.rest_apis[rest_api_id].models.get(model_name)
+ ):
+ raise NotFoundException(f"Invalid model name specified: {model_name}")
+
+ for operation in patch_operations:
+ path = operation.get("path")
+ if operation.get("op") != "replace":
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op 'add'. Please choose supported operations"
+ )
+ if path not in ("/schema", "/description"):
+ raise BadRequestException(
+ f"Invalid patch path '{path}' specified for op 'replace'. Must be one of: [/description, /schema]"
+ )
+
+ key = path[1:] # remove the leading slash
+ value = operation.get("value")
+ if key == "schema":
+ if not value:
+ raise BadRequestException(
+ "Model schema must have at least 1 property or array items defined"
+ )
+ # delete the resolved model to invalidate it
+ store.rest_apis[rest_api_id].resolved_models.pop(model_name, None)
+ model[key] = value
+ remove_empty_attributes_from_model(model)
+ return model
+
+ def delete_model(
+ self, context: RequestContext, rest_api_id: String, model_name: String, **kwargs
+ ) -> None:
+ store = get_apigateway_store(context=context)
+
+ if (
+ rest_api_id not in store.rest_apis
+ or model_name not in store.rest_apis[rest_api_id].models
+ ):
+ raise NotFoundException(f"Invalid model name specified: {model_name}")
+
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ validate_model_in_use(moto_rest_api, model_name)
+
+ store.rest_apis[rest_api_id].models.pop(model_name, None)
+ store.rest_apis[rest_api_id].resolved_models.pop(model_name, None)
+
+ @handler("CreateUsagePlan")
+ def create_usage_plan(
+ self,
+ context: RequestContext,
+ name: String,
+ description: String = None,
+ api_stages: ListOfApiStage = None,
+ throttle: ThrottleSettings = None,
+ quota: QuotaSettings = None,
+ tags: MapOfStringToString = None,
+ **kwargs,
+ ) -> UsagePlan:
+ usage_plan: UsagePlan = call_moto(context=context)
+ if not usage_plan.get("quota"):
+ usage_plan.pop("quota", None)
+
+ fix_throttle_and_quota_from_usage_plan(usage_plan)
+
+ return usage_plan
+
+ def update_usage_plan(
+ self,
+ context: RequestContext,
+ usage_plan_id: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> UsagePlan:
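+ # validate "remove" operations on /apiStages: the value must be "<restApiId>:<stageName>" and reference an existing stage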
+ for patch_op in patch_operations:
+ if patch_op.get("op") == "remove" and patch_op.get("path") == "/apiStages":
+ if not (api_stage_id := patch_op.get("value")):
+ raise BadRequestException("Invalid API Stage specified")
+ if not len(split_stage_id := api_stage_id.split(":")) == 2:
+ raise BadRequestException("Invalid API Stage specified")
+ rest_api_id, stage_name = split_stage_id
+ moto_backend = apigw_models.apigateway_backends[context.account_id][context.region]
+ if not (rest_api := moto_backend.apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API Stage {{api: {rest_api_id}, stage: {stage_name}}} specified for usageplan {usage_plan_id}"
+ )
+ if stage_name not in rest_api.stages:
+ raise NotFoundException(
+ f"Invalid API Stage {{api: {rest_api_id}, stage: {stage_name}}} specified for usageplan {usage_plan_id}"
+ )
+
+ usage_plan = call_moto(context=context)
+ if not usage_plan.get("quota"):
+ usage_plan.pop("quota", None)
+
+ usage_plan_arn = f"arn:{get_partition(context.region)}:apigateway:{context.region}::/usageplans/{usage_plan_id}"
+ existing_tags = get_apigateway_store(context=context).TAGS.get(usage_plan_arn, {})
+ if "tags" not in usage_plan:
+ usage_plan["tags"] = existing_tags
+ else:
+ usage_plan["tags"].update(existing_tags)
+
+ fix_throttle_and_quota_from_usage_plan(usage_plan)
+
+ return usage_plan
+
+ def get_usage_plan(self, context: RequestContext, usage_plan_id: String, **kwargs) -> UsagePlan:
+ usage_plan: UsagePlan = call_moto(context=context)
+ if not usage_plan.get("quota"):
+ usage_plan.pop("quota", None)
+
+ fix_throttle_and_quota_from_usage_plan(usage_plan)
+
+ usage_plan_arn = f"arn:{get_partition(context.region)}:apigateway:{context.region}::/usageplans/{usage_plan_id}"
+ existing_tags = get_apigateway_store(context=context).TAGS.get(usage_plan_arn, {})
+ if "tags" not in usage_plan:
+ usage_plan["tags"] = existing_tags
+ else:
+ usage_plan["tags"].update(existing_tags)
+
+ return usage_plan
+
+ @handler("GetUsagePlans")
+ def get_usage_plans(
+ self,
+ context: RequestContext,
+ position: String = None,
+ key_id: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> UsagePlans:
+ usage_plans: UsagePlans = call_moto(context=context)
+ if not usage_plans.get("items"):
+ usage_plans["items"] = []
+
+ items = usage_plans["items"]
+ for up in items:
+ if not up.get("quota"):
+ up.pop("quota", None)
+
+ fix_throttle_and_quota_from_usage_plan(up)
+
+ if "tags" not in up:
+ up.pop("tags", None)
+
+ return usage_plans
+
+ def get_usage_plan_keys(
+ self,
+ context: RequestContext,
+ usage_plan_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ name_query: String = None,
+ **kwargs,
+ ) -> UsagePlanKeys:
+ # TODO: migrate Usage Plan and UsagePlan Keys to our store
+ moto_backend = get_moto_backend(context.account_id, context.region)
+
+ if not (usage_plan_keys := moto_backend.usage_plan_keys.get(usage_plan_id)):
+ return UsagePlanKeys(items=[])
+
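+ # only return usage plan keys whose API key still exists in the backend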
+ usage_plan_keys = [
+ usage_plan_key.to_json()
+ for usage_plan_key in reversed(usage_plan_keys.values())
+ if usage_plan_key.id in moto_backend.keys
+ ]
+
+ item_list = PaginatedList(usage_plan_keys)
+
+ def token_generator(item):
+ return md5(item["id"])
+
+ def filter_function(item):
+ return item["name"].startswith(name_query)
+
+ paginated_list, next_token = item_list.get_page(
+ token_generator=token_generator,
+ next_token=position,
+ page_size=limit,
+ filter_function=filter_function if name_query else None,
+ )
+
+ return UsagePlanKeys(items=paginated_list, position=next_token)
+
+ def put_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ status_code: StatusCode = None,
+ response_parameters: MapOfStringToString = None,
+ response_templates: MapOfStringToString = None,
+ **kwargs,
+ ) -> GatewayResponse:
+ # There was no validation in moto, so implementing as-is
+ # TODO: add validation
+ # TODO: this is only the CRUD implementation, implement it in the invocation part of the code
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ gateway_response = GatewayResponse(
+ statusCode=status_code,
+ responseParameters=response_parameters,
+ responseTemplates=response_templates,
+ responseType=response_type,
+ defaultResponse=False,
+ )
+ rest_api_container.gateway_responses[response_type] = gateway_response
+ return gateway_response
+
+ def get_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ **kwargs,
+ ) -> GatewayResponse:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ gateway_response = rest_api_container.gateway_responses.get(
+ response_type, DEFAULT_GATEWAY_RESPONSES[response_type]
+ )
+ # TODO: add validation for the parameters? It seems to be validated client-side; how can we verify this?
+ return gateway_response
+
+ def get_gateway_responses(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> GatewayResponses:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ user_gateway_resp = rest_api_container.gateway_responses
+ gateway_responses = [
+ user_gateway_resp.get(key) or value for key, value in DEFAULT_GATEWAY_RESPONSES.items()
+ ]
+ return GatewayResponses(items=gateway_responses)
+
+ def delete_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ **kwargs,
+ ) -> None:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ if not rest_api_container.gateway_responses.pop(response_type, None):
+ raise NotFoundException("Gateway response type not defined on api")
+
+ def update_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> GatewayResponse:
+ """
+ Support operations table:
+ Path | op:add | op:replace | op:remove | op:copy
+ /statusCode | Not supported | Supported | Not supported | Not supported
+ /responseParameters | Supported | Supported | Supported | Not supported
+ /responseTemplates | Supported | Supported | Supported | Not supported
+ See https://docs.aws.amazon.com/apigateway/latest/api/patch-operations.html#UpdateGatewayResponse-Patch
+ """
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ if response_type not in rest_api_container.gateway_responses:
+ # deep copy to avoid in-place mutation of the default response when updating via JSON patch
+ rest_api_container.gateway_responses[response_type] = copy.deepcopy(
+ DEFAULT_GATEWAY_RESPONSES[response_type]
+ )
+ rest_api_container.gateway_responses[response_type]["defaultResponse"] = False
+
+ patched_entity = rest_api_container.gateway_responses[response_type]
+
+ for index, operation in enumerate(patch_operations):
+ if (op := operation.get("op")) not in VALID_PATCH_OPERATIONS:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{op}' at 'updateGatewayResponseInput.patchOperations.{index + 1}.member.op' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(VALID_PATCH_OPERATIONS)}]",
+ )
+
+ path = operation.get("path", "null")
+ if not any(
+ path.startswith(s_path)
+ for s_path in ("/statusCode", "/responseParameters", "/responseTemplates")
+ ):
+ raise BadRequestException(f"Invalid patch path {path}")
+
+ if op in ("add", "remove") and path == "/statusCode":
+ raise BadRequestException(f"Invalid patch path {path}")
+
+ elif op in ("add", "replace"):
+ for param_type in ("responseParameters", "responseTemplates"):
+ if path.startswith(f"/{param_type}"):
+ if op == "replace":
+ param = path.removeprefix(f"/{param_type}/")
+ param = param.replace("~1", "/")
+ if param not in patched_entity.get(param_type):
+ raise NotFoundException("Invalid parameter name specified")
+ if operation.get("value") is None:
+ raise BadRequestException(
+ f"Invalid null or empty value in {param_type}"
+ )
+
+ _patch_api_gateway_entity(patched_entity, patch_operations)
+
+ return patched_entity
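+ # Illustrative sketch (hypothetical API id and values): a patch_operations payload that stays within
+ # the support table above; note the JSON-pointer escaping of "/" as "~1" in the template key, which is
+ # undone by the replace() call above:
+ #
+ #   provider.update_gateway_response(
+ #       context,
+ #       rest_api_id="a1b2c3",
+ #       response_type=GatewayResponseType.DEFAULT_4XX,
+ #       patch_operations=[
+ #           {"op": "replace", "path": "/statusCode", "value": "404"},
+ #           {"op": "add", "path": "/responseTemplates/application~1json", "value": '{"error": $context.error.messageString}'},
+ #       ],
+ #   )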
+
+ # TODO
+
+
+# ---------------
+# UTIL FUNCTIONS
+# ---------------
+
+
+def remove_empty_attributes_from_rest_api(rest_api: RestApi, remove_tags=True) -> RestApi:
+ if not rest_api.get("binaryMediaTypes"):
+ rest_api.pop("binaryMediaTypes", None)
+
+ if not isinstance(rest_api.get("minimumCompressionSize"), int):
+ rest_api.pop("minimumCompressionSize", None)
+
+ if not rest_api.get("tags"):
+ if remove_tags:
+ rest_api.pop("tags", None)
+ else:
+ # if `tags` is falsy, set it to an empty dict
+ rest_api["tags"] = {}
+
+ if not rest_api.get("version"):
+ rest_api.pop("version", None)
+ if not rest_api.get("description"):
+ rest_api.pop("description", None)
+
+ return rest_api
+
+
+def remove_empty_attributes_from_method(method: Method) -> Method:
+ if not method.get("methodResponses"):
+ method.pop("methodResponses", None)
+
+ if method.get("requestModels") is None:
+ method.pop("requestModels", None)
+
+ if method.get("requestParameters") is None:
+ method.pop("requestParameters", None)
+
+ return method
+
+
+def remove_empty_attributes_from_integration(integration: Integration):
+ if not integration:
+ return integration
+
+ if not integration.get("integrationResponses"):
+ integration.pop("integrationResponses", None)
+
+ if integration.get("requestParameters") is None:
+ integration.pop("requestParameters", None)
+
+ return integration
+
+
+def remove_empty_attributes_from_model(model: Model) -> Model:
+ if not model.get("description"):
+ model.pop("description", None)
+
+ return model
+
+
+def remove_empty_attributes_from_integration_response(integration_response: IntegrationResponse):
+ if integration_response.get("responseTemplates") is None:
+ integration_response.pop("responseTemplates", None)
+
+ return integration_response
+
+
+def fix_throttle_and_quota_from_usage_plan(usage_plan: UsagePlan) -> None:
+ if quota := usage_plan.get("quota"):
+ if "offset" not in quota:
+ quota["offset"] = 0
+ else:
+ usage_plan.pop("quota", None)
+
+ if throttle := usage_plan.get("throttle"):
+ if rate_limit := throttle.get("rateLimit"):
+ throttle["rateLimit"] = float(rate_limit)
+
+ if burst_limit := throttle.get("burstLimit"):
+ throttle["burstLimit"] = int(burst_limit)
+ else:
+ usage_plan.pop("throttle", None)
+
+
+def validate_model_in_use(moto_rest_api: MotoRestAPI, model_name: str) -> None:
+ for resource in moto_rest_api.resources.values():
+ for method in resource.resource_methods.values():
+ if method.request_models and model_name in set(method.request_models.values()):
+ path = f"{resource.get_path()}/{method.http_method}"
+ raise ConflictException(
+ f"Cannot delete model '{model_name}', is referenced in method request: {path}"
+ )
+
+
+def get_moto_rest_api_root_resource(moto_rest_api: MotoRestAPI) -> str:
+ for res_id, res_obj in moto_rest_api.resources.items():
+ if res_obj.path_part == "/" and not res_obj.parent_id:
+ return res_id
+ raise Exception(f"Unable to find root resource for API {moto_rest_api.id}")
+
+
+def create_custom_context(
+ context: RequestContext, action: str, parameters: ServiceRequest
+) -> RequestContext:
+ ctx = create_aws_request_context(
+ service_name=context.service.service_name,
+ action=action,
+ parameters=parameters,
+ region=context.region,
+ )
+ ctx.request.headers.update(context.request.headers)
+ ctx.account_id = context.account_id
+ return ctx
+
+
+def _patch_api_gateway_entity(entity: Any, patch_operations: ListOfPatchOperation):
+ patch_operations = patch_operations or []
+
+ if isinstance(entity, dict):
+ entity_dict = entity
+ else:
+ if not isinstance(entity.__dict__, DelSafeDict):
+ entity.__dict__ = DelSafeDict(entity.__dict__)
+ entity_dict = entity.__dict__
+
+ not_supported_attributes = {"/id", "/region_name", "/create_date"}
+
+ model_attributes = list(entity_dict.keys())
+ for operation in patch_operations:
+ path_start = operation["path"].strip("/").split("/")[0]
+ path_start_usc = camelcase_to_underscores(path_start)
+ if path_start not in model_attributes and path_start_usc in model_attributes:
+ operation["path"] = operation["path"].replace(path_start, path_start_usc)
+ if operation["path"] in not_supported_attributes:
+ raise BadRequestException(f"Invalid patch path {operation['path']}")
+
+ apply_json_patch_safe(entity_dict, patch_operations, in_place=True)
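+ # Illustrative sketch (hypothetical moto object): the helper above rewrites camelCase patch paths to
+ # the snake_case attribute names used by moto entities before applying the JSON patch:
+ #
+ #   _patch_api_gateway_entity(moto_stage, [{"op": "replace", "path": "/documentationVersion", "value": "v2"}])
+ #   # the path is internally rewritten to "/documentation_version" before apply_json_patch_safe runs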
+
+
+def to_authorizer_response_json(api_id, data):
+ result = to_response_json("authorizer", data, api_id=api_id)
+ result = select_from_typed_dict(Authorizer, result)
+ return result
+
+
+def to_validator_response_json(api_id, data):
+ result = to_response_json("validator", data, api_id=api_id)
+ result = select_from_typed_dict(RequestValidator, result)
+ return result
+
+
+def to_documentation_part_response_json(api_id, data):
+ result = to_response_json("documentationpart", data, api_id=api_id)
+ result = select_from_typed_dict(DocumentationPart, result)
+ return result
+
+
+def to_base_mapping_response_json(domain_name, base_path, data):
+ self_link = "/domainnames/%s/basepathmappings/%s" % (domain_name, base_path)
+ result = to_response_json("basepathmapping", data, self_link=self_link)
+ result = select_from_typed_dict(BasePathMapping, result)
+ return result
+
+
+def to_account_response_json(data):
+ result = to_response_json("account", data, self_link="/account")
+ result = select_from_typed_dict(Account, result)
+ return result
+
+
+def to_vpc_link_response_json(data):
+ result = to_response_json("vpclink", data)
+ result = select_from_typed_dict(VpcLink, result)
+ return result
+
+
+def to_client_cert_response_json(data):
+ result = to_response_json("clientcertificate", data, id_attr="clientCertificateId")
+ result = select_from_typed_dict(ClientCertificate, result)
+ return result
+
+
+def to_rest_api_response_json(data):
+ result = to_response_json("restapi", data)
+ result = select_from_typed_dict(RestApi, result)
+ return result
+
+
+def to_response_json(model_type, data, api_id=None, self_link=None, id_attr=None):
+ if isinstance(data, list) and len(data) == 1:
+ data = data[0]
+ id_attr = id_attr or "id"
+ result = deepcopy(data)
+ if not self_link:
+ self_link = "/%ss/%s" % (model_type, data[id_attr])
+ if api_id:
+ self_link = "/restapis/%s/%s" % (api_id, self_link)
+ # TODO: check if this is still required - "_links" are listed in the sample responses in the docs, but
+ # recent parity tests indicate that this field is not returned by real AWS...
+ # https://docs.aws.amazon.com/apigateway/latest/api/API_GetAuthorizers.html#API_GetAuthorizers_Example_1_Response
+ if "_links" not in result:
+ result["_links"] = {}
+ result["_links"]["self"] = {"href": self_link}
+ result["_links"]["curies"] = {
+ "href": "https://docs.aws.amazon.com/apigateway/latest/developerguide/restapi-authorizer-latest.html",
+ "name": model_type,
+ "templated": True,
+ }
+ result["_links"]["%s:delete" % model_type] = {"href": self_link}
+ return result
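+ # Illustrative sketch (hypothetical ids): for an authorizer {"id": "auth01", ...} of API "a1b2c3",
+ # the helper above yields HAL-style links such as:
+ #
+ #   {"id": "auth01", "_links": {"self": {"href": "/restapis/a1b2c3//authorizers/auth01"}, ...}}
+ #
+ # (the double slash is produced verbatim by the string formatting above, since self_link already
+ # starts with "/")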
+
+
+DEFAULT_EMPTY_MODEL = Model(
+ id=short_uid()[:6],
+ name=EMPTY_MODEL,
+ contentType="application/json",
+ description="This is a default empty schema model",
+ schema=json.dumps(
+ {
+ "$schema": "http://json-schema.org/draft-04/schema#",
+ "title": "Empty Schema",
+ "type": "object",
+ }
+ ),
+)
+
+DEFAULT_ERROR_MODEL = Model(
+ id=short_uid()[:6],
+ name=ERROR_MODEL,
+ contentType="application/json",
+ description="This is a default error schema model",
+ schema=json.dumps(
+ {
+ "$schema": "http://json-schema.org/draft-04/schema#",
+ "title": "Error Schema",
+ "type": "object",
+ "properties": {"message": {"type": "string"}},
+ }
+ ),
+)
+
+
+ # TODO: maybe extract this into its own file, or find a better, more generalizable way
+UPDATE_METHOD_PATCH_PATHS = {
+ "supported_paths": [
+ "/authorizationScopes",
+ "/authorizationType",
+ "/authorizerId",
+ "/apiKeyRequired",
+ "/operationName",
+ "/requestParameters/",
+ "/requestModels/",
+ "/requestValidatorId",
+ ],
+ "add": [
+ "/authorizationScopes",
+ "/requestParameters/",
+ "/requestModels/",
+ ],
+ "remove": [
+ "/authorizationScopes",
+ "/requestParameters/",
+ "/requestModels/",
+ ],
+ "replace": [
+ "/authorizationType",
+ "/authorizerId",
+ "/apiKeyRequired",
+ "/operationName",
+ "/requestParameters/",
+ "/requestModels/",
+ "/requestValidatorId",
+ ],
+}
+
+DEFAULT_GATEWAY_RESPONSES: dict[GatewayResponseType, GatewayResponse] = {
+ GatewayResponseType.REQUEST_TOO_LARGE: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "REQUEST_TOO_LARGE",
+ "statusCode": "413",
+ },
+ GatewayResponseType.RESOURCE_NOT_FOUND: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "RESOURCE_NOT_FOUND",
+ "statusCode": "404",
+ },
+ GatewayResponseType.AUTHORIZER_CONFIGURATION_ERROR: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "AUTHORIZER_CONFIGURATION_ERROR",
+ "statusCode": "500",
+ },
+ GatewayResponseType.MISSING_AUTHENTICATION_TOKEN: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "MISSING_AUTHENTICATION_TOKEN",
+ "statusCode": "403",
+ },
+ GatewayResponseType.BAD_REQUEST_BODY: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "BAD_REQUEST_BODY",
+ "statusCode": "400",
+ },
+ GatewayResponseType.INVALID_SIGNATURE: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "INVALID_SIGNATURE",
+ "statusCode": "403",
+ },
+ GatewayResponseType.INVALID_API_KEY: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "INVALID_API_KEY",
+ "statusCode": "403",
+ },
+ GatewayResponseType.BAD_REQUEST_PARAMETERS: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "BAD_REQUEST_PARAMETERS",
+ "statusCode": "400",
+ },
+ GatewayResponseType.AUTHORIZER_FAILURE: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "AUTHORIZER_FAILURE",
+ "statusCode": "500",
+ },
+ GatewayResponseType.UNAUTHORIZED: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "UNAUTHORIZED",
+ "statusCode": "401",
+ },
+ GatewayResponseType.INTEGRATION_TIMEOUT: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "INTEGRATION_TIMEOUT",
+ "statusCode": "504",
+ },
+ GatewayResponseType.ACCESS_DENIED: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "ACCESS_DENIED",
+ "statusCode": "403",
+ },
+ GatewayResponseType.DEFAULT_4XX: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "DEFAULT_4XX",
+ },
+ GatewayResponseType.DEFAULT_5XX: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "DEFAULT_5XX",
+ },
+ GatewayResponseType.WAF_FILTERED: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "WAF_FILTERED",
+ "statusCode": "403",
+ },
+ GatewayResponseType.QUOTA_EXCEEDED: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "QUOTA_EXCEEDED",
+ "statusCode": "429",
+ },
+ GatewayResponseType.THROTTLED: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "THROTTLED",
+ "statusCode": "429",
+ },
+ GatewayResponseType.API_CONFIGURATION_ERROR: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "API_CONFIGURATION_ERROR",
+ "statusCode": "500",
+ },
+ GatewayResponseType.UNSUPPORTED_MEDIA_TYPE: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "UNSUPPORTED_MEDIA_TYPE",
+ "statusCode": "415",
+ },
+ GatewayResponseType.INTEGRATION_FAILURE: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "INTEGRATION_FAILURE",
+ "statusCode": "504",
+ },
+ GatewayResponseType.EXPIRED_TOKEN: {
+ "defaultResponse": True,
+ "responseParameters": {},
+ "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
+ "responseType": "EXPIRED_TOKEN",
+ "statusCode": "403",
+ },
+}
+
+VALID_PATCH_OPERATIONS = ["add", "remove", "move", "test", "replace", "copy"]
diff --git a/localstack-core/localstack/services/apigateway/legacy/router_asf.py b/localstack-core/localstack/services/apigateway/legacy/router_asf.py
new file mode 100644
index 0000000000000..0664c98c56f20
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/router_asf.py
@@ -0,0 +1,160 @@
+import logging
+from typing import Any, Dict
+
+from requests.models import Response as RequestsResponse
+from werkzeug.datastructures import Headers
+from werkzeug.exceptions import NotFound
+
+from localstack.constants import HEADER_LOCALSTACK_EDGE_URL
+from localstack.http import Request, Response, Router
+from localstack.http.dispatcher import Handler
+from localstack.http.request import restore_payload
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.services.apigateway.legacy.helpers import get_api_account_id_and_region
+from localstack.services.apigateway.legacy.invocations import invoke_rest_api_from_request
+from localstack.utils.aws.aws_responses import LambdaResponse
+from localstack.utils.strings import remove_leading_extra_slashes
+
+LOG = logging.getLogger(__name__)
+
+
+ # TODO: with the latest snapshot tests, we might start moving away from the
+ # invocation context property decorators and use the url_params directly,
+ # something that has been requested for a long time.
+def to_invocation_context(
+ request: Request, url_params: Dict[str, Any] = None
+) -> ApiInvocationContext:
+ """
+ Converts an HTTP Request object into an ApiInvocationContext.
+
+ :param request: the original request
+ :param url_params: the parameters extracted from the URL matching rules
+ :return: the ApiInvocationContext
+ """
+ if url_params is None:
+ url_params = {}
+
+ method = request.method
+ # The base path is not URL-decoded.
+ # Example: test%2Balias@gmail.com stays test%2Balias@gmail.com (it is not decoded to test+alias@gmail.com)
+ raw_uri = path = request.environ.get("RAW_URI")
+ if raw_uri.startswith("//"):
+ # if the path starts with //, replace the leading // with a single /
+ path = remove_leading_extra_slashes(raw_uri)
+
+ data = restore_payload(request)
+ headers = Headers(request.headers)
+
+ # TODO: verify that this is needed
+ # adjust the X-Forwarded-For header
+ x_forwarded_for = headers.getlist("X-Forwarded-For")
+ x_forwarded_for.append(request.remote_addr)
+ x_forwarded_for.append(request.host)
+ headers["X-Forwarded-For"] = ", ".join(x_forwarded_for)
+
+ # set the x-localstack-edge header; it is used to parse the domain
+ headers[HEADER_LOCALSTACK_EDGE_URL] = request.host_url.strip("/")
+
+ # FIXME: Use the already parsed url params instead of parsing them into the ApiInvocationContext part-by-part.
+ # We would already have all params at hand to avoid _all_ the parsing, but the parsing
+ # has side effects (e.g. setting the region in a thread-local)!
+ # It would be best to use a small (immutable) context for the already parsed params and the Request object
+ # and use it everywhere.
+ ctx = ApiInvocationContext(method, path, data, headers, stage=url_params.get("stage"))
+ ctx.raw_uri = raw_uri
+ ctx.auth_identity["sourceIp"] = request.remote_addr
+
+ return ctx
+
+
+def convert_response(result: RequestsResponse) -> Response:
+ """
+ Utility function to convert a response from the requests library into our internal (Werkzeug-based) Response object.
+ """
+ if result is None:
+ return Response()
+
+ if isinstance(result, LambdaResponse):
+ headers = Headers(dict(result.headers))
+ for k, values in result.multi_value_headers.items():
+ for value in values:
+ headers.add(k, value)
+ else:
+ headers = dict(result.headers)
+
+ response = Response(status=result.status_code, headers=headers)
+
+ if isinstance(result.content, dict):
+ response.set_json(result.content)
+ elif isinstance(result.content, (str, bytes)):
+ response.data = result.content
+ else:
+ raise ValueError(f"Unhandled content type {type(result.content)}")
+
+ return response
+
+
+class ApigatewayRouter:
+ """
+ Simple implementation around a Router to manage dynamic restapi routes (routes added by a user through the
+ apigateway API).
+ """
+
+ router: Router[Handler]
+
+ def __init__(self, router: Router[Handler]):
+ self.router = router
+ self.registered = False
+
+ def register_routes(self) -> None:
+ """Registers parameterized routes for API Gateway user invocations."""
+ if self.registered:
+ LOG.debug("Skipped API Gateway route registration (routes already registered).")
+ return
+ self.registered = True
+ LOG.debug("Registering parameterized API Gateway routes.")
+ host_pattern = "<api_id>.execute-api.<regex('.*'):server>"
+ self.router.add(
+ "/",
+ host=host_pattern,
+ endpoint=self.invoke_rest_api,
+ defaults={"path": "", "stage": None},
+ strict_slashes=True,
+ )
+ self.router.add(
+ "//",
+ host=host_pattern,
+ endpoint=self.invoke_rest_api,
+ defaults={"path": ""},
+ strict_slashes=False,
+ )
+ self.router.add(
+ "//",
+ host=host_pattern,
+ endpoint=self.invoke_rest_api,
+ strict_slashes=True,
+ )
+
+ # add the localstack-specific _user_request_ routes
+ self.router.add(
+ "/restapis///_user_request_",
+ endpoint=self.invoke_rest_api,
+ defaults={"path": ""},
+ )
+ self.router.add(
+ "/restapis///_user_request_/",
+ endpoint=self.invoke_rest_api,
+ strict_slashes=True,
+ )
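+ # Illustrative sketch (assuming the default LocalStack edge port and domain): the host-based rules
+ # above serve invocations such as
+ #   curl "http://abc123.execute-api.localhost.localstack.cloud:4566/dev/pets"
+ # while the _user_request_ rules serve the path-based form
+ #   curl "http://localhost:4566/restapis/abc123/dev/_user_request_/pets"
+ # In both cases the handler receives api_id="abc123" and stage="dev" as URL parameters.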
+
+ def invoke_rest_api(self, request: Request, **url_params: str) -> Response:
+ account_id, region_name = get_api_account_id_and_region(url_params["api_id"])
+ if not region_name:
+ return Response(status=404)
+ invocation_context = to_invocation_context(request, url_params)
+ invocation_context.region_name = region_name
+ invocation_context.account_id = account_id
+ result = invoke_rest_api_from_request(invocation_context)
+ if result is not None:
+ return convert_response(result)
+ raise NotFound()
diff --git a/localstack-core/localstack/services/apigateway/legacy/templates.py b/localstack-core/localstack/services/apigateway/legacy/templates.py
new file mode 100644
index 0000000000000..0ae853981ac02
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/legacy/templates.py
@@ -0,0 +1,381 @@
+import base64
+import copy
+import json
+import logging
+from enum import Enum
+from typing import Any, Dict, Union
+from urllib.parse import quote_plus, unquote_plus
+
+import xmltodict
+
+from localstack import config
+from localstack.constants import APPLICATION_JSON, APPLICATION_XML
+from localstack.services.apigateway.legacy.context import ApiInvocationContext
+from localstack.services.apigateway.legacy.helpers import select_integration_response
+from localstack.utils.aws.templating import APIGW_SOURCE, VelocityUtil, VtlTemplate
+from localstack.utils.json import extract_jsonpath, json_safe, try_json
+from localstack.utils.strings import to_str
+
+LOG = logging.getLogger(__name__)
+
+
+class PassthroughBehavior(Enum):
+ WHEN_NO_MATCH = "WHEN_NO_MATCH"
+ WHEN_NO_TEMPLATES = "WHEN_NO_TEMPLATES"
+ NEVER = "NEVER"
+
+
+class MappingTemplates:
+ """
+ API Gateway uses mapping templates to transform incoming requests before they are sent to the
+ integration back end. With API Gateway, you can define one mapping template for each possible
+ content type. The content type selection is based on the Content-Type header of the incoming
+ request. If no content type is specified in the request, API Gateway uses an application/json
+ mapping template. By default, mapping templates are configured to simply pass through the
+ request input. Mapping templates use Apache Velocity to generate a request to your back end.
+ """
+
+ passthrough_behavior: PassthroughBehavior
+
+ class UnsupportedMediaType(Exception):
+ pass
+
+ def __init__(self, passthrough_behaviour: str):
+ self.passthrough_behavior = self.get_passthrough_behavior(passthrough_behaviour)
+
+ def check_passthrough_behavior(self, request_template):
+ """
+ Specifies how the method request body of an unmapped content type will be passed through
+ the integration request to the back end without transformation.
+ A content type is unmapped if no mapping template is defined in the integration or the
+ content type does not match any of the mapped content types, as specified in requestTemplates
+ """
+ if not request_template and self.passthrough_behavior in {
+ PassthroughBehavior.NEVER,
+ PassthroughBehavior.WHEN_NO_TEMPLATES,
+ }:
+ raise MappingTemplates.UnsupportedMediaType()
+
+ @staticmethod
+ def get_passthrough_behavior(passthrough_behaviour: str):
+ return getattr(PassthroughBehavior, passthrough_behaviour, None)
+
+
+class AttributeDict(dict):
+ """
+ Wrapper returned by VelocityUtilApiGateway.parseJson to allow access to dict values as attributes (dot notation),
+ e.g.: $util.parseJson('$.foo').bar
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(AttributeDict, self).__init__(*args, **kwargs)
+ for key, value in self.items():
+ if isinstance(value, dict):
+ self[key] = AttributeDict(value)
+
+ def __getattr__(self, name):
+ if name in self:
+ return self[name]
+ raise AttributeError(f"'AttributeDict' object has no attribute '{name}'")
+
+ def __setattr__(self, name, value):
+ self[name] = value
+
+ def __delattr__(self, name):
+ if name in self:
+ del self[name]
+ else:
+ raise AttributeError(f"'AttributeDict' object has no attribute '{name}'")
+
+
+class VelocityUtilApiGateway(VelocityUtil):
+ """
+ Simple class to mimic the behavior of variable '$util' in AWS API Gateway integration
+ velocity templates.
+ See: https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+
+ def base64Encode(self, s):
+ if not isinstance(s, str):
+ s = json.dumps(s)
+ encoded_str = s.encode(config.DEFAULT_ENCODING)
+ encoded_b64_str = base64.b64encode(encoded_str)
+ return encoded_b64_str.decode(config.DEFAULT_ENCODING)
+
+ def base64Decode(self, s):
+ if not isinstance(s, str):
+ s = json.dumps(s)
+ return base64.b64decode(s)
+
+ def toJson(self, obj):
+ return obj and json.dumps(obj)
+
+ def urlEncode(self, s):
+ return quote_plus(s)
+
+ def urlDecode(self, s):
+ return unquote_plus(s)
+
+ def escapeJavaScript(self, obj: Any) -> str:
+ """
+ Converts the given object to a string and escapes any regular single quotes (') into escaped ones (\').
+ JSON dumps will escape the single quotes.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+ if obj is None:
+ return "null"
+ if isinstance(obj, str):
+ # empty string escapes to empty object
+ if len(obj.strip()) == 0:
+ return "{}"
+ return json.dumps(obj)[1:-1]
+ if obj in (True, False):
+ return str(obj).lower()
+ return str(obj)
+
+ def parseJson(self, s: str):
+ obj = json.loads(s)
+ return AttributeDict(obj) if isinstance(obj, dict) else obj
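+ # Illustrative sketch (values assumed): the $util helpers above behave roughly as follows:
+ #
+ #   util = VelocityUtilApiGateway()
+ #   util.base64Encode("foo")               # 'Zm9v'
+ #   util.escapeJavaScript(None)            # 'null'
+ #   util.escapeJavaScript("")              # '{}' (empty strings map to an empty object)
+ #   util.parseJson('{"a": {"b": 1}}').a.b  # 1 (AttributeDict enables dot access)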
+
+
+class VelocityInput:
+ """
+ Simple class to mimic the behavior of variable '$input' in AWS API Gateway integration
+ velocity templates.
+ See: http://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+
+ def __init__(self, body, params):
+ self.parameters = params or {}
+ self.value = body
+
+ def path(self, path):
+ if not self.value:
+ return {}
+ value = self.value if isinstance(self.value, dict) else json.loads(self.value)
+ return extract_jsonpath(value, path)
+
+ def json(self, path):
+ path = path or "$"
+ matching = self.path(path)
+ if isinstance(matching, (list, dict)):
+ matching = json_safe(matching)
+ return json.dumps(matching)
+
+ @property
+ def body(self):
+ return self.value
+
+ def params(self, name=None):
+ if not name:
+ return self.parameters
+ for k in ["path", "querystring", "header"]:
+ if val := self.parameters.get(k).get(name):
+ return val
+ return ""
+
+ def __getattr__(self, name):
+ return self.value.get(name)
+
+ def __repr__(self):
+ return "$input"
+
+
+class ApiGatewayVtlTemplate(VtlTemplate):
+ """Util class for rendering VTL templates with API Gateway specific extensions"""
+
+ def prepare_namespace(self, variables, source: str = APIGW_SOURCE) -> Dict[str, Any]:
+ namespace = super().prepare_namespace(variables, source)
+ if stage_var := variables.get("stage_variables") or {}:
+ namespace["stageVariables"] = stage_var
+ input_var = variables.get("input") or {}
+ variables = {
+ "input": VelocityInput(input_var.get("body"), input_var.get("params")),
+ "util": VelocityUtilApiGateway(),
+ }
+ namespace.update(variables)
+ return namespace
+
+
+class Templates:
+ __slots__ = ["vtl"]
+
+ def __init__(self):
+ self.vtl = ApiGatewayVtlTemplate()
+
+ def render(self, api_context: ApiInvocationContext) -> Union[bytes, str]:
+ pass
+
+ def render_vtl(self, template, variables):
+ return self.vtl.render_vtl(template, variables=variables)
+
+ @staticmethod
+ def build_variables_mapping(api_context: ApiInvocationContext) -> dict[str, Any]:
+ # TODO: make this dict a proper object so the variables used by "render_vtl" are well-defined
+ ctx = copy.deepcopy(api_context.context or {})
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-override-request-response-parameters.html
+ # create namespace for request override
+ ctx["requestOverride"] = {
+ "header": {},
+ "path": {},
+ "querystring": {},
+ }
+
+ ctx["responseOverride"] = {
+ "header": {},
+ "status": 200,
+ }
+
+ return {
+ "context": ctx,
+ "stage_variables": api_context.stage_variables or {},
+ "input": {
+ "body": api_context.data_as_string(),
+ "params": {
+ "path": api_context.path_params,
+ "querystring": api_context.query_params(),
+ # Sometimes we get a werkzeug.datastructures.Headers object, sometimes a dict
+ # depending on the request. We need to convert to a dict to be able to render
+ # the template.
+ "header": dict(api_context.headers),
+ },
+ },
+ }
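+ # Illustrative sketch (hypothetical api_context and template): these variables are what the VTL
+ # templates below consume; a request template could, for example, set a request override and echo
+ # a parameter:
+ #
+ #   template = '#set($context.requestOverride.header.custom = "1"){"user": "$input.params(\'user\')"}'
+ #   variables = Templates.build_variables_mapping(api_context)
+ #   ApiGatewayVtlTemplate().render_vtl(template, variables=variables)
+ #   # RequestTemplates.render() then copies context.requestOverride.header onto the outgoing headers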
+
+
+class RequestTemplates(Templates):
+ """
+ Handles request template rendering
+ """
+
+ def render(
+ self, api_context: ApiInvocationContext, template_key: str = APPLICATION_JSON
+ ) -> Union[bytes, str]:
+ LOG.debug(
+ "Method request body before transformations: %s", to_str(api_context.data_as_string())
+ )
+ request_templates = api_context.integration.get("requestTemplates", {})
+ template = request_templates.get(template_key)
+ if not template:
+ return api_context.data_as_string()
+
+ variables = self.build_variables_mapping(api_context)
+ result = self.render_vtl(template.strip(), variables=variables)
+
+ # set the request overrides into context
+ api_context.headers.update(
+ variables.get("context", {}).get("requestOverride", {}).get("header", {})
+ )
+
+ LOG.debug("Endpoint request body after transformations:\n%s", result)
+ return result
+
+
+class ResponseTemplates(Templates):
+ """
+ Handles response template rendering. The integration response status code is used to select
+ the correct template to render, if there is no template for the status code, the default
+ template is used.
+ """
+
+ def render(self, api_context: ApiInvocationContext, **kwargs) -> Union[bytes, str]:
+ # XXX: keep backwards compatibility until we migrate all integrations to these new classes
+ # api_context contains a response object that we want to slowly phase out
+ data = kwargs.get("response", "")
+ response = data or api_context.response
+ integration = api_context.integration
+ # we set context data with the response content because later on we use context data as
+ # the body field in the template. We need to improve this by using the right source
+ # depending on the type of template.
+ api_context.data = response._content
+
+ # status code returned by the integration
+ status_code = str(response.status_code)
+
+ # get the integration responses configuration from the integration object
+ integration_responses = integration.get("integrationResponses")
+ if not integration_responses:
+ return response._content
+
+ # get the configured integration response status codes,
+ # e.g. ["200", "400", "500"]
+ integration_status_codes = [str(code) for code in list(integration_responses.keys())]
+ # if there are no integration responses, we return the response as is
+ if not integration_status_codes:
+ return response.content
+
+ # The following code handles two use cases. If there is an integration response for the status code returned
+ # by the integration, we use the template configured for that status code (1), or the errorMessage (2) for
+ # Lambda integrations.
+ # For an HTTP integration, API Gateway matches the regex to the HTTP status code to return
+ # For a Lambda function, API Gateway matches the regex to the errorMessage header to
+ # return a status code.
+ # For example, to set a 400 response for any error that starts with Malformed,
+ # set the method response status code to 400 and the Lambda error regex to Malformed.*.
+ match_resp = status_code
+ if isinstance(try_json(response._content), dict):
+ resp_dict = try_json(response._content)
+ if "errorMessage" in resp_dict:
+ match_resp = resp_dict.get("errorMessage")
+
+ selected_integration_response = select_integration_response(match_resp, api_context)
+ response.status_code = int(selected_integration_response.get("statusCode", 200))
+ response_templates = selected_integration_response.get("responseTemplates", {})
+
+ # we only support JSON and XML templates for now - if there is no template, we return the response as-is.
+ # If the content type is not supported, we always use application/json as the default value
+ # TODO - support other content types, besides application/json and application/xml
+ # see https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html#selecting-mapping-templates
+ accept = api_context.headers.get("accept", APPLICATION_JSON)
+ supported_types = [APPLICATION_JSON, APPLICATION_XML]
+ media_type = accept if accept in supported_types else APPLICATION_JSON
+ if not (template := response_templates.get(media_type, {})):
+ return response._content
+
+ # we render the template with the context data and the response content
+ variables = self.build_variables_mapping(api_context)
+ # update the response body
+ response._content = self._render_as_text(template, variables)
+ if media_type == APPLICATION_JSON:
+ self._validate_json(response.content)
+ elif media_type == APPLICATION_XML:
+ self._validate_xml(response.content)
+
+ if response_overrides := variables.get("context", {}).get("responseOverride", {}):
+ response.headers.update(response_overrides.get("header", {}).items())
+ response.status_code = response_overrides.get("status", 200)
+
+ LOG.debug("Endpoint response body after transformations:\n%s", response._content)
+ return response._content
+
+ def _render_as_text(self, template: str, variables: dict[str, Any]) -> str:
+ """
+ Render the given Velocity template string + variables into a plain string.
+ :return: the template rendering result as a string
+ """
+ rendered_tpl = self.render_vtl(template, variables=variables)
+ return rendered_tpl.strip()
+
+ @staticmethod
+ def _validate_json(content: str):
+ """
+ Checks that the content received is a valid JSON.
+ :raise JSONDecodeError: if content is not valid JSON
+ """
+ try:
+ json.loads(content)
+ except Exception as e:
+ LOG.info("Unable to parse template result as JSON: %s - %s", e, content)
+ raise
+
+ @staticmethod
+ def _validate_xml(content: str):
+ """
+ Checks that the content received is a valid XML.
+ :raise xml.parsers.expat.ExpatError: if content is not valid XML
+ """
+ try:
+ xmltodict.parse(content)
+ except Exception as e:
+ LOG.info("Unable to parse template result as XML: %s - %s", e, content)
+ raise
diff --git a/localstack-core/localstack/services/apigateway/models.py b/localstack-core/localstack/services/apigateway/models.py
new file mode 100644
index 0000000000000..44fca6b65ae29
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/models.py
@@ -0,0 +1,155 @@
+from typing import Any, Dict, List
+
+from requests.structures import CaseInsensitiveDict
+
+from localstack.aws.api.apigateway import (
+ Authorizer,
+ DocumentationPart,
+ DocumentationVersion,
+ DomainName,
+ GatewayResponse,
+ GatewayResponseType,
+ Model,
+ RequestValidator,
+ Resource,
+ RestApi,
+)
+from localstack.services.stores import (
+ AccountRegionBundle,
+ BaseStore,
+ CrossAccountAttribute,
+ CrossRegionAttribute,
+ LocalAttribute,
+)
+from localstack.utils.aws import arns
+
+
+class RestApiContainer:
+ # contains the RestApi dictionary. We're not making use of it yet, still using moto data.
+ rest_api: RestApi
+ # maps AuthorizerId -> Authorizer
+ authorizers: Dict[str, Authorizer]
+ # maps RequestValidatorId -> RequestValidator
+ validators: Dict[str, RequestValidator]
+ # map DocumentationPartId -> DocumentationPart
+ documentation_parts: Dict[str, DocumentationPart]
+ # map doc version name -> DocumentationVersion
+ documentation_versions: Dict[str, DocumentationVersion]
+ # not used yet, still in moto
+ gateway_responses: Dict[GatewayResponseType, GatewayResponse]
+ # maps Model name -> Model
+ models: Dict[str, Model]
+ # maps Model name -> resolved dict Model, so we don't need to load the JSON every time
+ resolved_models: Dict[str, dict]
+ # maps ResourceId of a Resource to its children ResourceIds
+ resource_children: Dict[str, List[str]]
+
+ def __init__(self, rest_api: RestApi):
+ self.rest_api = rest_api
+ self.authorizers = {}
+ self.validators = {}
+ self.documentation_parts = {}
+ self.documentation_versions = {}
+ self.gateway_responses = {}
+ self.models = {}
+ self.resolved_models = {}
+ self.resource_children = {}
+
+
+class MergedRestApi(RestApiContainer):
+ """Merged REST API between Moto data and LocalStack data, used in our Invocation logic"""
+
+ # TODO: when migrating away from Moto, RestApiContainer and MergedRestApi will have the same signature, so we can
+ # safely remove MergedRestApi and only use RestApiContainer in our invocation logic
+ resources: dict[str, Resource]
+
+ def __init__(self, rest_api: RestApi):
+ super().__init__(rest_api)
+ self.resources = {}
+
+ @classmethod
+ def from_rest_api_container(
+ cls,
+ rest_api_container: RestApiContainer,
+ resources: dict[str, Resource],
+ ) -> "MergedRestApi":
+ merged = cls(rest_api=rest_api_container.rest_api)
+ merged.authorizers = rest_api_container.authorizers
+ merged.validators = rest_api_container.validators
+ merged.documentation_parts = rest_api_container.documentation_parts
+ merged.documentation_versions = rest_api_container.documentation_versions
+ merged.gateway_responses = rest_api_container.gateway_responses
+ merged.models = rest_api_container.models
+ merged.resolved_models = rest_api_container.resolved_models
+ merged.resource_children = rest_api_container.resource_children
+ merged.resources = resources
+
+ return merged
+
+
+class RestApiDeployment:
+ def __init__(
+ self,
+ account_id: str,
+ region: str,
+ rest_api: MergedRestApi,
+ ):
+ self.rest_api = rest_api
+ self.account_id = account_id
+ self.region = region
+
+
+class ApiGatewayStore(BaseStore):
+ # maps (API id) -> RestApiContainer
+ # TODO: remove CaseInsensitiveDict, and lowercase the ID value when getting it from the tags
+ rest_apis: Dict[str, RestApiContainer] = LocalAttribute(default=CaseInsensitiveDict)
+
+ # account details
+ _account: Dict[str, Any] = LocalAttribute(default=dict)
+
+ # maps (domain_name) -> [path_mappings]
+ base_path_mappings: Dict[str, List[Dict]] = LocalAttribute(default=dict)
+
+ # maps ID to VPC link details
+ vpc_links: Dict[str, Dict] = LocalAttribute(default=dict)
+
+ # maps cert ID to client certificate details
+ client_certificates: Dict[str, Dict] = LocalAttribute(default=dict)
+
+ # maps domain name to domain name model
+ domain_names: Dict[str, DomainName] = LocalAttribute(default=dict)
+
+ # maps resource ARN to tags
+ TAGS: Dict[str, Dict[str, str]] = CrossRegionAttribute(default=dict)
+
+ # internal deployments, represents a frozen REST API for a deployment, used in our router
+ # TODO: make sure API IDs are unique across all accounts
+ # maps ApiID to a map of deploymentId and RestApiDeployment, an executable/snapshot of a REST API
+ internal_deployments: dict[str, dict[str, RestApiDeployment]] = CrossAccountAttribute(
+ default=dict
+ )
+
+ # active deployments, mapping API ID to a map of Stage and deployment ID
+ # TODO: make sure API IDs are unique across all accounts
+ active_deployments: dict[str, dict[str, str]] = CrossAccountAttribute(dict)
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def account(self):
+ if not self._account:
+ self._account.update(
+ {
+ "cloudwatchRoleArn": arns.iam_role_arn(
+ "api-gw-cw-role", self._account_id, self._region_name
+ ),
+ "throttleSettings": {"burstLimit": 1000, "rateLimit": 500},
+ "features": ["UsagePlans"],
+ "apiKeyVersion": "1",
+ }
+ )
+ return self._account
+
+
+apigateway_stores = AccountRegionBundle("apigateway", ApiGatewayStore)
diff --git a/localstack/utils/aws/__init__.py b/localstack-core/localstack/services/apigateway/next_gen/__init__.py
similarity index 100%
rename from localstack/utils/aws/__init__.py
rename to localstack-core/localstack/services/apigateway/next_gen/__init__.py
diff --git a/localstack/utils/cloudformation/__init__.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/__init__.py
similarity index 100%
rename from localstack/utils/cloudformation/__init__.py
rename to localstack-core/localstack/services/apigateway/next_gen/execute_api/__init__.py
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/api.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/api.py
new file mode 100644
index 0000000000000..843938e0611ed
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/api.py
@@ -0,0 +1,17 @@
+from typing import Callable, Type
+
+from rolo import Response
+from rolo.gateway.chain import HandlerChain as RoloHandlerChain
+
+from .context import RestApiInvocationContext
+
+RestApiGatewayHandler = Callable[
+ [RoloHandlerChain[RestApiInvocationContext], RestApiInvocationContext, Response], None
+]
+
+RestApiGatewayExceptionHandler = Callable[
+ [RoloHandlerChain[RestApiInvocationContext], Exception, RestApiInvocationContext, Response],
+ None,
+]
+
+RestApiGatewayHandlerChain: Type[RoloHandlerChain[RestApiInvocationContext]] = RoloHandlerChain
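+ # Illustrative sketch (hypothetical handler): a conforming handler only needs to match the callable
+ # signature above, e.g. a no-op logging handler:
+ #
+ #   def log_invocation(chain: RestApiGatewayHandlerChain, context: RestApiInvocationContext, response: Response) -> None:
+ #       print(context.api_id, context.stage)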
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/context.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/context.py
new file mode 100644
index 0000000000000..03632d0829aaa
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/context.py
@@ -0,0 +1,131 @@
+from http import HTTPMethod
+from typing import Optional, TypedDict
+
+from rolo import Request
+from rolo.gateway import RequestContext
+from werkzeug.datastructures import Headers
+
+from localstack.aws.api.apigateway import Integration, Method, Resource
+from localstack.services.apigateway.models import RestApiDeployment
+
+from .variables import ContextVariables, LoggingContextVariables
+
+
+class InvocationRequest(TypedDict, total=False):
+ http_method: HTTPMethod
+ """HTTP Method of the incoming request"""
+ raw_path: Optional[str]
+ # TODO: verify if raw_path is needed
+ """Raw path of the incoming request with no modification, needed to keep double forward slashes"""
+ path: Optional[str]
+ """Path of the request with no URL decoding"""
+ path_parameters: Optional[dict[str, str]]
+ """Path parameters of the request"""
+ query_string_parameters: dict[str, str]
+ """Query string parameters of the request"""
+ headers: Headers
+ """Raw headers using the Headers datastructure which allows access with no regards to casing"""
+ multi_value_query_string_parameters: dict[str, list[str]]
+ """Multi value query string parameters of the request"""
+ body: bytes
+ """Body content of the request"""
+
+
+class IntegrationRequest(TypedDict, total=False):
+ http_method: HTTPMethod
+ """HTTP Method of the incoming request"""
+ uri: str
+ """URI of the integration"""
+ query_string_parameters: dict[str, str | list[str]]
+ """Query string parameters of the request"""
+ headers: Headers
+ """Headers of the request"""
+ body: bytes
+ """Body content of the request"""
+
+
+class BaseResponse(TypedDict):
+ """Base class for Response objects in the context"""
+
+ status_code: int
+ """Status code of the response"""
+ headers: Headers
+ """Headers of the response"""
+ body: bytes
+ """Body content of the response"""
+
+
+class EndpointResponse(BaseResponse):
+ """Represents the response coming from an integration, called Endpoint Response in AWS"""
+
+ pass
+
+
+class InvocationResponse(BaseResponse):
+ """Represents the response coming after being serialized in an Integration Response in AWS"""
+
+ pass
+
+
+class RestApiInvocationContext(RequestContext):
+ """
+ This context is going to be used to pass relevant information across an API Gateway invocation.
+ """
+
+ deployment: Optional[RestApiDeployment]
+ """Contains the invoked REST API Resources"""
+ integration: Optional[Integration]
+ """The Method Integration for the invoked request"""
+ api_id: Optional[str]
+ """The REST API identifier of the invoked API"""
+ stage: Optional[str]
+ """The REST API stage linked to this invocation"""
+ base_path: Optional[str]
+ """The REST API base path mapped to the stage of this invocation"""
+ deployment_id: Optional[str]
+ """The REST API deployment linked to this invocation"""
+ region: Optional[str]
+ """The region the REST API is living in."""
+ account_id: Optional[str]
+ """The account the REST API is living in."""
+ trace_id: Optional[str]
+ """The X-Ray trace ID for the request."""
+ resource: Optional[Resource]
+ """The resource the invocation matched"""
+ resource_method: Optional[Method]
+ """The method of the resource the invocation matched"""
+ stage_variables: Optional[dict[str, str]]
+ """The Stage variables, also used in parameters mapping and mapping templates"""
+ context_variables: Optional[ContextVariables]
+ """The $context used in data models, authorizers, mapping templates, and CloudWatch access logging"""
+ logging_context_variables: Optional[LoggingContextVariables]
+ """Additional $context variables available only for access logging, not yet implemented"""
+ invocation_request: Optional[InvocationRequest]
+ """Contains the data relative to the invocation request"""
+ integration_request: Optional[IntegrationRequest]
+ """Contains the data needed to construct an HTTP request to an Integration"""
+ endpoint_response: Optional[EndpointResponse]
+ """Contains the data returned by an Integration"""
+ invocation_response: Optional[InvocationResponse]
+ """Contains the data serialized and to be returned by an invocation"""
+
+ def __init__(self, request: Request):
+ super().__init__(request)
+ self.deployment = None
+ self.api_id = None
+ self.stage = None
+ self.base_path = None
+ self.deployment_id = None
+ self.account_id = None
+ self.region = None
+ self.invocation_request = None
+ self.resource = None
+ self.resource_method = None
+ self.integration = None
+ self.stage_variables = None
+ self.context_variables = None
+ self.logging_context_variables = None
+ self.integration_request = None
+ self.endpoint_response = None
+ self.invocation_response = None
+ self.trace_id = None
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway.py
new file mode 100644
index 0000000000000..48b4b4eeacf42
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway.py
@@ -0,0 +1,51 @@
+from rolo import Response
+from rolo.gateway import Gateway
+
+from . import handlers
+from .context import RestApiInvocationContext
+
+
+class RestApiGateway(Gateway):
+ """
+ This class controls the main path of an API Gateway REST API. It contains the definitions of the different handlers
+ to be called as part of the different steps of the invocation of the API.
+
+ For now, you can extend the behavior of the invocation by adding handlers to the `preprocess_request`
+ CompositeHandler.
+ The documentation of this class will be extended as more behavior is added to its handlers, along with more
+ ways to extend it.
+ """
+
+ def __init__(self):
+ super().__init__(context_class=RestApiInvocationContext)
+ self.request_handlers.extend(
+ [
+ handlers.parse_request,
+ handlers.modify_request,
+ handlers.route_request,
+ handlers.preprocess_request,
+ handlers.api_key_validation_handler,
+ handlers.method_request_handler,
+ handlers.integration_request_handler,
+ handlers.integration_handler,
+ handlers.integration_response_handler,
+ handlers.method_response_handler,
+ ]
+ )
+ self.exception_handlers.extend(
+ [
+ handlers.gateway_exception_handler,
+ ]
+ )
+ self.response_handlers.extend(
+ [
+ handlers.response_enricher,
+ handlers.cors_response_enricher,
+ handlers.usage_counter,
+ # add composite response handlers?
+ ]
+ )
+
+ def process_with_context(self, context: RestApiInvocationContext, response: Response):
+ chain = self.new_chain()
+ chain.handle(context, response)
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway_response.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway_response.py
new file mode 100644
index 0000000000000..a0e9935ccf775
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/gateway_response.py
@@ -0,0 +1,298 @@
+from enum import Enum
+
+from localstack.aws.api.apigateway import (
+ GatewayResponse,
+ GatewayResponseType,
+ MapOfStringToString,
+ StatusCode,
+)
+from localstack.constants import APPLICATION_JSON
+
+
+class GatewayResponseCode(StatusCode, Enum):
+ REQUEST_TOO_LARGE = "413"
+ RESOURCE_NOT_FOUND = "404"
+ AUTHORIZER_CONFIGURATION_ERROR = "500"
+ MISSING_AUTHENTICATION_TOKEN = "403"
+ BAD_REQUEST_BODY = "400"
+ INVALID_SIGNATURE = "403"
+ INVALID_API_KEY = "403"
+ BAD_REQUEST_PARAMETERS = "400"
+ AUTHORIZER_FAILURE = "500"
+ UNAUTHORIZED = "401"
+ INTEGRATION_TIMEOUT = "504"
+ ACCESS_DENIED = "403"
+ DEFAULT_4XX = ""
+ DEFAULT_5XX = ""
+ WAF_FILTERED = "403"
+ QUOTA_EXCEEDED = "429"
+ THROTTLED = "429"
+ API_CONFIGURATION_ERROR = "500"
+ UNSUPPORTED_MEDIA_TYPE = "415"
+ INTEGRATION_FAILURE = "504"
+ EXPIRED_TOKEN = "403"
+
+
+class BaseGatewayException(Exception):
+ """
+ Base class for all Gateway exceptions
+ Do not raise from this class directly. Instead, raise the specific Exception
+ """
+
+ message: str = "Unimplemented Response"
+ type: GatewayResponseType = None
+ status_code: int | str = None
+ code: str = ""
+
+ def __init__(self, message: str = None, status_code: int | str = None):
+ if message is not None:
+ self.message = message
+ if status_code is not None:
+ self.status_code = status_code
+ elif self.status_code is None and self.type:
+ # Fallback to the default value
+ self.status_code = GatewayResponseCode[self.type]
+
+
+class Default4xxError(BaseGatewayException):
+ """Do not raise from this class directly.
+ Use one of the subclasses instead, as they contain the appropriate header
+ """
+
+ type = GatewayResponseType.DEFAULT_4XX
+ status_code = 400
+
+
+class Default5xxError(BaseGatewayException):
+ """Do not raise from this class directly.
+ Use one of the subclasses instead, as they contain the appropriate header
+ """
+
+ type = GatewayResponseType.DEFAULT_5XX
+ status_code = 500
+
+
+class BadRequestException(Default4xxError):
+ code = "BadRequestException"
+
+
+class InternalFailureException(Default5xxError):
+ code = "InternalFailureException"
+
+
+class InternalServerError(Default5xxError):
+ code = "InternalServerErrorException"
+
+
+class AccessDeniedError(BaseGatewayException):
+ type = GatewayResponseType.ACCESS_DENIED
+ # TODO validate this header with aws validated tests
+ code = "AccessDeniedException"
+
+
+class ApiConfigurationError(BaseGatewayException):
+ type = GatewayResponseType.API_CONFIGURATION_ERROR
+ # TODO validate this header with aws validated tests
+ code = "ApiConfigurationException"
+
+
+class AuthorizerConfigurationError(BaseGatewayException):
+ type = GatewayResponseType.AUTHORIZER_CONFIGURATION_ERROR
+ # TODO validate this header with aws validated tests
+ code = "AuthorizerConfigurationException"
+ # the message is set to None by default in AWS
+ message = None
+
+
+class AuthorizerFailureError(BaseGatewayException):
+ type = GatewayResponseType.AUTHORIZER_FAILURE
+ # TODO validate this header with aws validated tests
+ code = "AuthorizerFailureException"
+
+
+class BadRequestParametersError(BaseGatewayException):
+ type = GatewayResponseType.BAD_REQUEST_PARAMETERS
+ code = "BadRequestException"
+
+
+class BadRequestBodyError(BaseGatewayException):
+ type = GatewayResponseType.BAD_REQUEST_BODY
+ code = "BadRequestException"
+
+
+class ExpiredTokenError(BaseGatewayException):
+ type = GatewayResponseType.EXPIRED_TOKEN
+ # TODO validate this header with aws validated tests
+ code = "ExpiredTokenException"
+
+
+class IntegrationFailureError(BaseGatewayException):
+ type = GatewayResponseType.INTEGRATION_FAILURE
+ code = "InternalServerErrorException"
+ status_code = 500
+
+
+class IntegrationTimeoutError(BaseGatewayException):
+ type = GatewayResponseType.INTEGRATION_TIMEOUT
+ code = "InternalServerErrorException"
+
+
+class InvalidAPIKeyError(BaseGatewayException):
+ type = GatewayResponseType.INVALID_API_KEY
+ code = "ForbiddenException"
+
+
+class InvalidSignatureError(BaseGatewayException):
+ type = GatewayResponseType.INVALID_SIGNATURE
+ # TODO validate this header with aws validated tests
+ code = "InvalidSignatureException"
+
+
+class MissingAuthTokenError(BaseGatewayException):
+ type = GatewayResponseType.MISSING_AUTHENTICATION_TOKEN
+ code = "MissingAuthenticationTokenException"
+
+
+class QuotaExceededError(BaseGatewayException):
+ type = GatewayResponseType.QUOTA_EXCEEDED
+ code = "LimitExceededException"
+
+
+class RequestTooLargeError(BaseGatewayException):
+ type = GatewayResponseType.REQUEST_TOO_LARGE
+ # TODO validate this header with aws validated tests
+ code = "RequestTooLargeException"
+
+
+class ResourceNotFoundError(BaseGatewayException):
+ type = GatewayResponseType.RESOURCE_NOT_FOUND
+ # TODO validate this header with aws validated tests
+ code = "ResourceNotFoundException"
+
+
+class ThrottledError(BaseGatewayException):
+ type = GatewayResponseType.THROTTLED
+ code = "TooManyRequestsException"
+
+
+class UnauthorizedError(BaseGatewayException):
+ type = GatewayResponseType.UNAUTHORIZED
+ code = "UnauthorizedException"
+
+
+class UnsupportedMediaTypeError(BaseGatewayException):
+ type = GatewayResponseType.UNSUPPORTED_MEDIA_TYPE
+ code = "BadRequestException"
+
+
+class WafFilteredError(BaseGatewayException):
+ type = GatewayResponseType.WAF_FILTERED
+ # TODO validate this header with aws validated tests
+ code = "WafFilteredException"
+
+
+def build_gateway_response(
+ response_type: GatewayResponseType,
+ status_code: StatusCode = None,
+ response_parameters: MapOfStringToString = None,
+ response_templates: MapOfStringToString = None,
+ default_response: bool = True,
+) -> GatewayResponse:
+ """Building a Gateway Response. Non provided attributes will use default."""
+ response = GatewayResponse(
+ responseParameters=response_parameters or {},
+ responseTemplates=response_templates
+ or {APPLICATION_JSON: '{"message":$context.error.messageString}'},
+ responseType=response_type,
+ defaultResponse=default_response,
+ statusCode=status_code,
+ )
+
+ return response
+
+
+def get_gateway_response_or_default(
+ response_type: GatewayResponseType,
+ gateway_responses: dict[GatewayResponseType, GatewayResponse],
+) -> GatewayResponse:
+ """Utility function that will look for a matching Gateway Response in the following order.
+ - If provided in the gateway_response, return the dicts value
+ - If the DEFAULT_XXX was configured will create a new response
+ - Otherwise we return from DEFAULT_GATEWAY_RESPONSE"""
+
+ if response := gateway_responses.get(response_type):
+ # User configured response
+ return response
+ response_code = GatewayResponseCode[response_type]
+ if response_code == "":
+ # DEFAULT_XXX responses do not have a default code
+ return DEFAULT_GATEWAY_RESPONSES.get(response_type)
+ if response_code >= "500":
+ # 5XX response will either get a user configured DEFAULT_5XX or the DEFAULT_GATEWAY_RESPONSES
+ default = gateway_responses.get(GatewayResponseType.DEFAULT_5XX)
+ else:
+ # 4XX response will either get a user configured DEFAULT_4XX or the DEFAULT_GATEWAY_RESPONSES
+ default = gateway_responses.get(GatewayResponseType.DEFAULT_4XX)
+
+ if not default:
+ # If DEFAULT_XXX was not provided return default
+ return DEFAULT_GATEWAY_RESPONSES.get(response_type)
+
+ return build_gateway_response(
+ # Build a new response from default
+ response_type,
+ status_code=default.get("statusCode"),
+ response_parameters=default.get("responseParameters"),
+ response_templates=default.get("responseTemplates"),
+ )
+
+
+DEFAULT_GATEWAY_RESPONSES = {
+ GatewayResponseType.REQUEST_TOO_LARGE: build_gateway_response(
+ GatewayResponseType.REQUEST_TOO_LARGE
+ ),
+ GatewayResponseType.RESOURCE_NOT_FOUND: build_gateway_response(
+ GatewayResponseType.RESOURCE_NOT_FOUND
+ ),
+ GatewayResponseType.AUTHORIZER_CONFIGURATION_ERROR: build_gateway_response(
+ GatewayResponseType.AUTHORIZER_CONFIGURATION_ERROR
+ ),
+ GatewayResponseType.MISSING_AUTHENTICATION_TOKEN: build_gateway_response(
+ GatewayResponseType.MISSING_AUTHENTICATION_TOKEN
+ ),
+ GatewayResponseType.BAD_REQUEST_BODY: build_gateway_response(
+ GatewayResponseType.BAD_REQUEST_BODY
+ ),
+ GatewayResponseType.INVALID_SIGNATURE: build_gateway_response(
+ GatewayResponseType.INVALID_SIGNATURE
+ ),
+ GatewayResponseType.INVALID_API_KEY: build_gateway_response(
+ GatewayResponseType.INVALID_API_KEY
+ ),
+ GatewayResponseType.BAD_REQUEST_PARAMETERS: build_gateway_response(
+ GatewayResponseType.BAD_REQUEST_PARAMETERS
+ ),
+ GatewayResponseType.AUTHORIZER_FAILURE: build_gateway_response(
+ GatewayResponseType.AUTHORIZER_FAILURE
+ ),
+ GatewayResponseType.UNAUTHORIZED: build_gateway_response(GatewayResponseType.UNAUTHORIZED),
+ GatewayResponseType.INTEGRATION_TIMEOUT: build_gateway_response(
+ GatewayResponseType.INTEGRATION_TIMEOUT
+ ),
+ GatewayResponseType.ACCESS_DENIED: build_gateway_response(GatewayResponseType.ACCESS_DENIED),
+ GatewayResponseType.DEFAULT_4XX: build_gateway_response(GatewayResponseType.DEFAULT_4XX),
+ GatewayResponseType.DEFAULT_5XX: build_gateway_response(GatewayResponseType.DEFAULT_5XX),
+ GatewayResponseType.WAF_FILTERED: build_gateway_response(GatewayResponseType.WAF_FILTERED),
+ GatewayResponseType.QUOTA_EXCEEDED: build_gateway_response(GatewayResponseType.QUOTA_EXCEEDED),
+ GatewayResponseType.THROTTLED: build_gateway_response(GatewayResponseType.THROTTLED),
+ GatewayResponseType.API_CONFIGURATION_ERROR: build_gateway_response(
+ GatewayResponseType.API_CONFIGURATION_ERROR
+ ),
+ GatewayResponseType.UNSUPPORTED_MEDIA_TYPE: build_gateway_response(
+ GatewayResponseType.UNSUPPORTED_MEDIA_TYPE
+ ),
+ GatewayResponseType.INTEGRATION_FAILURE: build_gateway_response(
+ GatewayResponseType.INTEGRATION_FAILURE
+ ),
+ GatewayResponseType.EXPIRED_TOKEN: build_gateway_response(GatewayResponseType.EXPIRED_TOKEN),
+}
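A rough, self-contained sketch of the lookup order implemented by `get_gateway_response_or_default` above (plain dicts stand in for the real `GatewayResponse` objects; all names and shapes here are illustrative only):

```python
BUILT_IN_DEFAULTS = {
    "THROTTLED": {
        "responseTemplates": {"application/json": '{"message":$context.error.messageString}'},
    },
}

def resolve_gateway_response(response_type: str, is_5xx: bool, user_responses: dict) -> dict:
    # 1. a response explicitly configured by the user always wins
    if configured := user_responses.get(response_type):
        return configured
    # 2. otherwise a user-configured DEFAULT_4XX / DEFAULT_5XX acts as a template for a new response
    default_key = "DEFAULT_5XX" if is_5xx else "DEFAULT_4XX"
    if default := user_responses.get(default_key):
        return {**BUILT_IN_DEFAULTS.get(response_type, {}), **default}
    # 3. finally, fall back to the built-in defaults
    return BUILT_IN_DEFAULTS.get(response_type, {})

# a user who only configured DEFAULT_4XX still gets its status code applied to THROTTLED
print(resolve_gateway_response("THROTTLED", False, {"DEFAULT_4XX": {"statusCode": "429"}}))
```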
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/__init__.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/__init__.py
new file mode 100644
index 0000000000000..089bf7d6a899b
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/__init__.py
@@ -0,0 +1,29 @@
+from rolo.gateway import CompositeHandler
+
+from .analytics import IntegrationUsageCounter
+from .api_key_validation import ApiKeyValidationHandler
+from .cors import CorsResponseEnricher
+from .gateway_exception import GatewayExceptionHandler
+from .integration import IntegrationHandler
+from .integration_request import IntegrationRequestHandler
+from .integration_response import IntegrationResponseHandler
+from .method_request import MethodRequestHandler
+from .method_response import MethodResponseHandler
+from .parse import InvocationRequestParser
+from .resource_router import InvocationRequestRouter
+from .response_enricher import InvocationResponseEnricher
+
+parse_request = InvocationRequestParser()
+modify_request = CompositeHandler()
+route_request = InvocationRequestRouter()
+preprocess_request = CompositeHandler()
+method_request_handler = MethodRequestHandler()
+integration_request_handler = IntegrationRequestHandler()
+integration_handler = IntegrationHandler()
+integration_response_handler = IntegrationResponseHandler()
+method_response_handler = MethodResponseHandler()
+gateway_exception_handler = GatewayExceptionHandler()
+api_key_validation_handler = ApiKeyValidationHandler()
+response_enricher = InvocationResponseEnricher()
+cors_response_enricher = CorsResponseEnricher()
+usage_counter = IntegrationUsageCounter()
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/analytics.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/analytics.py
new file mode 100644
index 0000000000000..b93a611fed2f6
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/analytics.py
@@ -0,0 +1,44 @@
+import logging
+
+from localstack.http import Response
+from localstack.utils.analytics.usage import UsageSetCounter
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import RestApiInvocationContext
+
+LOG = logging.getLogger(__name__)
+
+
+class IntegrationUsageCounter(RestApiGatewayHandler):
+ counter: UsageSetCounter
+
+ def __init__(self, counter: UsageSetCounter = None):
+ self.counter = counter or UsageSetCounter(namespace="apigateway:invokedrest")
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ if context.integration:
+ invocation_type = context.integration["type"]
+ if invocation_type == "AWS":
+ service_name = self._get_aws_integration_service(context.integration.get("uri"))
+ invocation_type = f"{invocation_type}:{service_name}"
+ else:
+ # if the invocation does not have an integration attached, it probably failed before routing the request,
+ # hence we should count it as a NOT_FOUND invocation
+ invocation_type = "NOT_FOUND"
+
+ self.counter.record(invocation_type)
+
+ @staticmethod
+ def _get_aws_integration_service(integration_uri: str) -> str:
+ if not integration_uri:
+ return "null"
+
+ # we need at least 5 ARN segments so that the service name at index 4 exists
+ if len(split_arn := integration_uri.split(":", maxsplit=5)) < 5:
+ return "null"
+
+ return split_arn[4]
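For illustration only, this is how the `maxsplit`-based split above extracts the targeted AWS service from a typical `AWS` integration URI (the URI below is made up for the example):

```python
uri = (
    "arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/"
    "arn:aws:lambda:us-east-1:000000000000:function:my-fn/invocations"
)
parts = uri.split(":", maxsplit=5)
# ['arn', 'aws', 'apigateway', 'us-east-1', 'lambda', 'path/2015-03-31/...']
print(parts[4])  # -> "lambda", recorded as the "AWS:lambda" invocation type
```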
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/api_key_validation.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/api_key_validation.py
new file mode 100644
index 0000000000000..ba8ada9769f17
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/api_key_validation.py
@@ -0,0 +1,113 @@
+import logging
+from typing import Optional
+
+from localstack.aws.api.apigateway import ApiKey, ApiKeySourceType, RestApi
+from localstack.http import Response
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import InvocationRequest, RestApiInvocationContext
+from ..gateway_response import InvalidAPIKeyError
+from ..moto_helpers import get_api_key, get_usage_plan_keys, get_usage_plans
+from ..variables import ContextVarsIdentity
+
+LOG = logging.getLogger(__name__)
+
+
+class ApiKeyValidationHandler(RestApiGatewayHandler):
+ """
+ Handles Api key validation.
+ If an api key is required, we will validate that a usage plan associated with that stage
+ has a usage plan key with the corresponding value.
+ """
+
+ # TODO We currently do not support rate limiting or quota limits. As such, we are not raising any related exceptions
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ method = context.resource_method
+ request = context.invocation_request
+ rest_api = context.deployment.rest_api.rest_api
+
+ # If api key is not required by the method, we can exit the handler
+ if not method.get("apiKeyRequired"):
+ return
+
+ identity = context.context_variables.get("identity")
+
+ # Look for the api key value in the request. If it is not found, raise an exception
+ if not (api_key_value := self.get_request_api_key(rest_api, request, identity)):
+ LOG.debug("API Key is empty")
+ raise InvalidAPIKeyError("Forbidden")
+
+ # Get the validated key, if no key is found, raise an exception
+ if not (validated_key := self.validate_api_key(api_key_value, context)):
+ LOG.debug("Provided API Key is not valid")
+ raise InvalidAPIKeyError("Forbidden")
+
+ # Update the context's identity with the key value and id
+ if not identity.get("apiKey"):
+ LOG.debug("Updating $context.identity.apiKey='%s'", validated_key["value"])
+ identity["apiKey"] = validated_key["value"]
+
+ LOG.debug("Updating $context.identity.apiKeyId='%s'", validated_key["id"])
+ identity["apiKeyId"] = validated_key["id"]
+
+ def validate_api_key(
+ self, api_key_value, context: RestApiInvocationContext
+ ) -> Optional[ApiKey]:
+ api_id = context.api_id
+ stage = context.stage
+ account_id = context.account_id
+ region = context.region
+
+ # Get usage plans from the store
+ usage_plans = get_usage_plans(account_id=account_id, region_name=region)
+
+ # Loop through usage plans and keep ids of the plans associated with the deployment stage
+ usage_plan_ids = []
+ for usage_plan in usage_plans:
+ api_stages = usage_plan.get("apiStages", [])
+ usage_plan_ids.extend(
+ usage_plan.get("id")
+ for api_stage in api_stages
+ if (api_stage.get("stage") == stage and api_stage.get("apiId") == api_id)
+ )
+ if not usage_plan_ids:
+ LOG.debug("No associated usage plans found stage '%s'", stage)
+ return
+
+ # Loop through the plans associated with the stage to find a key with a matching value
+ for usage_plan_id in usage_plan_ids:
+ usage_plan_keys = get_usage_plan_keys(
+ usage_plan_id=usage_plan_id, account_id=account_id, region_name=region
+ )
+ for key in usage_plan_keys:
+ if key["value"] == api_key_value:
+ api_key = get_api_key(
+ api_key_id=key["id"], account_id=account_id, region_name=region
+ )
+ LOG.debug("Found Api Key '%s'", api_key["id"])
+ return api_key if api_key["enabled"] else None
+
+ def get_request_api_key(
+ self, rest_api: RestApi, request: InvocationRequest, identity: ContextVarsIdentity
+ ) -> Optional[str]:
+ """https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-api-key-source.html
+ The source of the API key for metering requests according to a usage plan.
+ Valid values are:
+ - HEADER to read the API key from the X-API-Key header of a request.
+ - AUTHORIZER to read the API key from the Context Variables.
+ """
+ match api_key_source := rest_api.get("apiKeySource"):
+ case ApiKeySourceType.HEADER:
+ LOG.debug("Looking for api key in header 'X-API-Key'")
+ return request.get("headers", {}).get("X-API-Key")
+ case ApiKeySourceType.AUTHORIZER:
+ LOG.debug("Looking for api key in Identity Context")
+ return identity.get("apiKey")
+ case _:
+ LOG.debug("Api Key Source is not valid: '%s'", api_key_source)
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/cors.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/cors.py
new file mode 100644
index 0000000000000..497a9a273464c
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/cors.py
@@ -0,0 +1,49 @@
+import logging
+from http import HTTPMethod
+
+from localstack import config
+from localstack.aws.handlers.cors import CorsEnforcer
+from localstack.aws.handlers.cors import CorsResponseEnricher as GlobalCorsResponseEnricher
+from localstack.http import Response
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import RestApiInvocationContext
+from ..gateway_response import MissingAuthTokenError
+
+LOG = logging.getLogger(__name__)
+
+
+class CorsResponseEnricher(RestApiGatewayHandler):
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ """
+ This is a LocalStack-only handler that lets users bypass their API Gateway CORS configuration and rely on the
+ default LocalStack CORS configuration instead, easing local usage and reducing changes to production code.
+ """
+ if not config.DISABLE_CUSTOM_CORS_APIGATEWAY:
+ return
+
+ if not context.invocation_request:
+ return
+
+ headers = context.invocation_request["headers"]
+
+ if "Origin" not in headers:
+ return
+
+ if context.request.method == HTTPMethod.OPTIONS:
+ # If the user did not configure an OPTIONS route, we still want LocalStack to properly respond to CORS
+ # requests
+ if context.invocation_exception:
+ if isinstance(context.invocation_exception, MissingAuthTokenError):
+ response.data = b""
+ response.status_code = 204
+ else:
+ return
+
+ if CorsEnforcer.is_cors_origin_allowed(headers):
+ GlobalCorsResponseEnricher.add_cors_headers(headers, response.headers)
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/gateway_exception.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/gateway_exception.py
new file mode 100644
index 0000000000000..174b2cf8c1bc2
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/gateway_exception.py
@@ -0,0 +1,98 @@
+import json
+import logging
+
+from rolo import Response
+from werkzeug.datastructures import Headers
+
+from localstack.constants import APPLICATION_JSON
+from localstack.services.apigateway.next_gen.execute_api.api import (
+ RestApiGatewayExceptionHandler,
+ RestApiGatewayHandlerChain,
+)
+from localstack.services.apigateway.next_gen.execute_api.context import RestApiInvocationContext
+from localstack.services.apigateway.next_gen.execute_api.gateway_response import (
+ AccessDeniedError,
+ BaseGatewayException,
+ get_gateway_response_or_default,
+)
+from localstack.services.apigateway.next_gen.execute_api.variables import (
+ GatewayResponseContextVarsError,
+)
+
+LOG = logging.getLogger(__name__)
+
+
+class GatewayExceptionHandler(RestApiGatewayExceptionHandler):
+ """
+ Exception handler that serializes the Gateway Exceptions into Gateway Responses
+ """
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ exception: Exception,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ if not isinstance(exception, BaseGatewayException):
+ LOG.warning(
+ "Non Gateway Exception raised: %s",
+ exception,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+ response.update_from(
+ Response(response=f"Error in apigateway invocation: {exception}", status="500")
+ )
+ return
+
+ LOG.info("Error raised during invocation: %s", exception.type)
+ self.set_error_context(exception, context)
+ error = self.create_exception_response(exception, context)
+ if error:
+ response.update_from(error)
+
+ @staticmethod
+ def set_error_context(exception: BaseGatewayException, context: RestApiInvocationContext):
+ context.context_variables["error"] = GatewayResponseContextVarsError(
+ message=exception.message,
+ messageString=exception.message,
+ responseType=exception.type,
+ validationErrorString="", # TODO
+ )
+
+ def create_exception_response(
+ self, exception: BaseGatewayException, context: RestApiInvocationContext
+ ):
+ gateway_response = get_gateway_response_or_default(
+ exception.type, context.deployment.rest_api.gateway_responses
+ )
+
+ content = self._build_response_content(exception)
+
+ headers = self._build_response_headers(exception)
+
+ status_code = gateway_response.get("statusCode")
+ if not status_code:
+ status_code = exception.status_code or 500
+
+ response = Response(response=content, headers=headers, status=status_code)
+ return response
+
+ @staticmethod
+ def _build_response_content(exception: BaseGatewayException) -> str:
+ # TODO apply responseTemplates to the content. We should also handle the default simply by managing the default
+ # template body `{"message":$context.error.messageString}`
+
+ # TODO: remove this workaround by properly managing the responseTemplate for AccessDeniedError.
+ # At the CRUD level, it returns the same template as all other errors, but in reality the message field is
+ # capitalized.
+ if isinstance(exception, AccessDeniedError):
+ return json.dumps({"Message": exception.message}, separators=(",", ":"))
+
+ return json.dumps({"message": exception.message})
+
+ @staticmethod
+ def _build_response_headers(exception: BaseGatewayException) -> dict:
+ # TODO apply responseParameters to the headers and get content-type from the gateway_response
+ headers = Headers({"Content-Type": APPLICATION_JSON, "x-amzn-ErrorType": exception.code})
+ return headers
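As a minimal illustration (not the handler itself), the body and headers serialized above for a gateway exception look roughly like this; status-code resolution is elided:

```python
import json

error_code = "MissingAuthenticationTokenException"  # exception.code
message = "Missing Authentication Token"            # exception.message

body = json.dumps({"message": message})
headers = {"Content-Type": "application/json", "x-amzn-ErrorType": error_code}
print(body)     # -> {"message": "Missing Authentication Token"}
print(headers)
```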
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration.py
new file mode 100644
index 0000000000000..d8a9e984de637
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration.py
@@ -0,0 +1,35 @@
+import logging
+
+from localstack.http import Response
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import EndpointResponse, RestApiInvocationContext
+from ..integrations import REST_API_INTEGRATIONS
+
+LOG = logging.getLogger(__name__)
+
+
+# TODO: this will need to use ApiGatewayIntegration class, using Plugin for discoverability and a PluginManager,
+# in order to automatically have access to defined Integrations that we can extend
+class IntegrationHandler(RestApiGatewayHandler):
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ integration_type = context.integration["type"]
+ is_proxy = "PROXY" in integration_type
+
+ integration = REST_API_INTEGRATIONS.get(integration_type)
+
+ if not integration:
+ # TODO: raise proper exception?
+ raise NotImplementedError(
+ f"This integration type is not yet supported: {integration_type}"
+ )
+
+ endpoint_response: EndpointResponse = integration.invoke(context)
+ context.endpoint_response = endpoint_response
+ if is_proxy:
+ context.invocation_response = endpoint_response
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_request.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_request.py
new file mode 100644
index 0000000000000..6b74222a170a4
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_request.py
@@ -0,0 +1,304 @@
+import logging
+from http import HTTPMethod
+
+from werkzeug.datastructures import Headers
+
+from localstack.aws.api.apigateway import Integration, IntegrationType
+from localstack.constants import APPLICATION_JSON
+from localstack.http import Request, Response
+from localstack.utils.collections import merge_recursive
+from localstack.utils.strings import to_bytes, to_str
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import IntegrationRequest, InvocationRequest, RestApiInvocationContext
+from ..gateway_response import InternalServerError, UnsupportedMediaTypeError
+from ..header_utils import drop_headers, set_default_headers
+from ..helpers import render_integration_uri
+from ..parameters_mapping import ParametersMapper, RequestDataMapping
+from ..template_mapping import (
+ ApiGatewayVtlTemplate,
+ MappingTemplateInput,
+ MappingTemplateParams,
+ MappingTemplateVariables,
+)
+from ..variables import ContextVarsRequestOverride
+
+LOG = logging.getLogger(__name__)
+
+# Illegal headers to include in transformation
+ILLEGAL_INTEGRATION_REQUESTS_COMMON = [
+ "content-length",
+ "transfer-encoding",
+ "x-amzn-trace-id",
+ "X-Amzn-Apigateway-Api-Id",
+]
+ILLEGAL_INTEGRATION_REQUESTS_AWS = [
+ *ILLEGAL_INTEGRATION_REQUESTS_COMMON,
+ "authorization",
+ "connection",
+ "expect",
+ "proxy-authenticate",
+ "te",
+]
+
+ # These are dropped after the template overrides are applied; they will never make it to the integration request.
+DROPPED_FROM_INTEGRATION_REQUESTS_COMMON = ["Expect", "Proxy-Authenticate", "TE"]
+DROPPED_FROM_INTEGRATION_REQUESTS_AWS = [*DROPPED_FROM_INTEGRATION_REQUESTS_COMMON, "Referer"]
+DROPPED_FROM_INTEGRATION_REQUESTS_HTTP = [*DROPPED_FROM_INTEGRATION_REQUESTS_COMMON, "Via"]
+
+# Default headers
+DEFAULT_REQUEST_HEADERS = {"Accept": APPLICATION_JSON, "Connection": "keep-alive"}
+
+
+class PassthroughBehavior(str):
+ # TODO maybe this class should be moved where it can also be used for validation in
+ # the provider when we switch out of moto
+ WHEN_NO_MATCH = "WHEN_NO_MATCH"
+ WHEN_NO_TEMPLATES = "WHEN_NO_TEMPLATES"
+ NEVER = "NEVER"
+
+
+class IntegrationRequestHandler(RestApiGatewayHandler):
+ """
+ This class will take care of the Integration Request part, which is mostly linked to template mapping
+ See https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-integration-settings-integration-request.html
+ """
+
+ def __init__(self):
+ self._param_mapper = ParametersMapper()
+ self._vtl_template = ApiGatewayVtlTemplate()
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ integration: Integration = context.integration
+ integration_type = integration["type"]
+
+ integration_request_parameters = integration["requestParameters"] or {}
+ request_data_mapping = self.get_integration_request_data(
+ context, integration_request_parameters
+ )
+ path_parameters = request_data_mapping["path"]
+
+ if integration_type in (IntegrationType.AWS_PROXY, IntegrationType.HTTP_PROXY):
+ # `PROXY` types cannot use integration mapping templates; they pass most of the data straight through.
+ # We make a copy to avoid modifying the invocation headers and to keep a cleaner history
+ headers = context.invocation_request["headers"].copy()
+ query_string_parameters: dict[str, list[str]] = context.invocation_request[
+ "multi_value_query_string_parameters"
+ ]
+ body = context.invocation_request["body"]
+
+ # HTTP_PROXY still makes use of the request data mappings and merges them with the invocation request;
+ # this is undocumented but validated behavior
+ if integration_type == IntegrationType.HTTP_PROXY:
+ # These headers won't be passed through by default from the invocation.
+ # They can however be added through request mappings.
+ drop_headers(headers, ["Host", "Content-Encoding"])
+ headers.update(request_data_mapping["header"])
+
+ query_string_parameters = self._merge_http_proxy_query_string(
+ query_string_parameters, request_data_mapping["querystring"]
+ )
+
+ else:
+ self._set_proxy_headers(headers, context.request)
+ # AWS_PROXY does not allow URI path rendering
+ # TODO: verify this
+ path_parameters = {}
+
+ else:
+ # find request template to raise UnsupportedMediaTypeError early
+ request_template = self.get_request_template(
+ integration=integration, request=context.invocation_request
+ )
+
+ body, request_override = self.render_request_template_mapping(
+ context=context, template=request_template
+ )
+ # mutate the ContextVariables with the requestOverride result, as we copy the context when rendering the
+ # template to avoid mutation on other fields
+ # the VTL responseTemplate can access the requestOverride
+ context.context_variables["requestOverride"] = request_override
+ # TODO: log every override that happens afterwards (in a loop on `request_override`)
+ merge_recursive(request_override, request_data_mapping, overwrite=True)
+
+ headers = Headers(request_data_mapping["header"])
+ query_string_parameters = request_data_mapping["querystring"]
+
+ # Some headers can't be modified by parameter mappings or mapping templates.
+ # AWS will raise an error if those are present, even for AWS_PROXY, where it does not apply them.
+ if header_mappings := request_data_mapping["header"]:
+ self._validate_headers_mapping(header_mappings, integration_type)
+
+ self._apply_header_transforms(headers, integration_type, context)
+
+ # it looks like the stageVariables rendering part is done in the Integration part in AWS,
+ # but we can avoid duplication by doing it here for now
+ # TODO: if the integration is of AWS Lambda type and the Lambda is in another account, we cannot render
+ # stageVariables. Work on that special case later (we can add a quick check for the URI region and set the
+ # stage variables to an empty dict)
+ rendered_integration_uri = render_integration_uri(
+ uri=integration["uri"],
+ path_parameters=path_parameters,
+ stage_variables=context.stage_variables,
+ )
+
+ # if the integration method is defined and is not ANY, we can use it for the integration
+ if not (integration_method := integration["httpMethod"]) or integration_method == "ANY":
+ # otherwise, fallback to the request's method
+ integration_method = context.invocation_request["http_method"]
+
+ integration_request = IntegrationRequest(
+ http_method=integration_method,
+ uri=rendered_integration_uri,
+ query_string_parameters=query_string_parameters,
+ headers=headers,
+ body=body,
+ )
+
+ context.integration_request = integration_request
+
+ def get_integration_request_data(
+ self, context: RestApiInvocationContext, request_parameters: dict[str, str]
+ ) -> RequestDataMapping:
+ return self._param_mapper.map_integration_request(
+ request_parameters=request_parameters,
+ invocation_request=context.invocation_request,
+ context_variables=context.context_variables,
+ stage_variables=context.stage_variables,
+ )
+
+ def render_request_template_mapping(
+ self,
+ context: RestApiInvocationContext,
+ template: str,
+ ) -> tuple[bytes, ContextVarsRequestOverride]:
+ request: InvocationRequest = context.invocation_request
+ body = request["body"]
+
+ if not template:
+ return body, {}
+
+ body, request_override = self._vtl_template.render_request(
+ template=template,
+ variables=MappingTemplateVariables(
+ context=context.context_variables,
+ stageVariables=context.stage_variables or {},
+ input=MappingTemplateInput(
+ body=to_str(body),
+ params=MappingTemplateParams(
+ path=request.get("path_parameters"),
+ querystring=request.get("query_string_parameters", {}),
+ header=request.get("headers"),
+ ),
+ ),
+ ),
+ )
+ return to_bytes(body), request_override
+
+ @staticmethod
+ def get_request_template(integration: Integration, request: InvocationRequest) -> str:
+ """
+ Attempts to return the request template.
+ Raises UnsupportedMediaTypeError if there is no match according to the passthrough behavior.
+ """
+ request_templates = integration.get("requestTemplates") or {}
+ passthrough_behavior = integration.get("passthroughBehavior")
+ # If Content-Type is not provided, AWS assumes application/json
+ content_type = request["headers"].get("Content-Type", APPLICATION_JSON)
+ # first, look for a template associated with the content-type, otherwise look for the $default template
+ request_template = request_templates.get(content_type) or request_templates.get("$default")
+
+ if request_template or passthrough_behavior == PassthroughBehavior.WHEN_NO_MATCH:
+ return request_template
+
+ match passthrough_behavior:
+ case PassthroughBehavior.NEVER:
+ LOG.debug(
+ "No request template found for '%s' and passthrough behavior set to NEVER",
+ content_type,
+ )
+ raise UnsupportedMediaTypeError("Unsupported Media Type")
+ case PassthroughBehavior.WHEN_NO_TEMPLATES:
+ if request_templates:
+ LOG.debug(
+ "No request template found for '%s' and passthrough behavior set to WHEN_NO_TEMPLATES",
+ content_type,
+ )
+ raise UnsupportedMediaTypeError("Unsupported Media Type")
+ case _:
+ LOG.debug("Unknown passthrough behavior: '%s'", passthrough_behavior)
+
+ return request_template
+
+ @staticmethod
+ def _merge_http_proxy_query_string(
+ query_string_parameters: dict[str, list[str]],
+ mapped_query_string: dict[str, str | list[str]],
+ ):
+ new_query_string_parameters = {k: v.copy() for k, v in query_string_parameters.items()}
+ for param, value in mapped_query_string.items():
+ if existing := new_query_string_parameters.get(param):
+ if isinstance(value, list):
+ existing.extend(value)
+ else:
+ existing.append(value)
+ else:
+ new_query_string_parameters[param] = value
+
+ return new_query_string_parameters
+
+ @staticmethod
+ def _set_proxy_headers(headers: Headers, request: Request):
+ headers.set("X-Forwarded-For", request.remote_addr)
+ headers.set("X-Forwarded-Port", request.environ.get("SERVER_PORT"))
+ headers.set(
+ "X-Forwarded-Proto",
+ request.environ.get("SERVER_PROTOCOL", "").split("/")[0],
+ )
+
+ @staticmethod
+ def _apply_header_transforms(
+ headers: Headers, integration_type: IntegrationType, context: RestApiInvocationContext
+ ):
+ # Dropping matching headers for the provided integration type
+ match integration_type:
+ case IntegrationType.AWS:
+ drop_headers(headers, DROPPED_FROM_INTEGRATION_REQUESTS_AWS)
+ case IntegrationType.HTTP | IntegrationType.HTTP_PROXY:
+ drop_headers(headers, DROPPED_FROM_INTEGRATION_REQUESTS_HTTP)
+ case _:
+ drop_headers(headers, DROPPED_FROM_INTEGRATION_REQUESTS_COMMON)
+
+ # Adding default headers to the requests headers
+ default_headers = {
+ **DEFAULT_REQUEST_HEADERS,
+ "User-Agent": f"AmazonAPIGateway_{context.api_id}",
+ }
+ if (
+ content_type := context.request.headers.get("Content-Type")
+ ) and context.request.method not in {HTTPMethod.OPTIONS, HTTPMethod.GET, HTTPMethod.HEAD}:
+ default_headers["Content-Type"] = content_type
+
+ set_default_headers(headers, default_headers)
+ headers.set("X-Amzn-Trace-Id", context.trace_id)
+ if integration_type not in (IntegrationType.AWS_PROXY, IntegrationType.AWS):
+ headers.set("X-Amzn-Apigateway-Api-Id", context.api_id)
+
+ @staticmethod
+ def _validate_headers_mapping(headers: dict[str, str], integration_type: IntegrationType):
+ """Validates and raises an error when attempting to set an illegal header"""
+ to_validate = ILLEGAL_INTEGRATION_REQUESTS_COMMON
+ if integration_type in {IntegrationType.AWS, IntegrationType.AWS_PROXY}:
+ to_validate = ILLEGAL_INTEGRATION_REQUESTS_AWS
+
+ for header in headers:
+ if header.lower() in to_validate:
+ LOG.debug(
+ "Execution failed due to configuration error: %s header already present", header
+ )
+ raise InternalServerError("Internal server error")
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_response.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_response.py
new file mode 100644
index 0000000000000..25df425b5a193
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/integration_response.py
@@ -0,0 +1,245 @@
+import json
+import logging
+import re
+
+from werkzeug.datastructures import Headers
+
+from localstack.aws.api.apigateway import Integration, IntegrationResponse, IntegrationType
+from localstack.constants import APPLICATION_JSON
+from localstack.http import Response
+from localstack.utils.strings import to_bytes, to_str
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import (
+ EndpointResponse,
+ InvocationRequest,
+ InvocationResponse,
+ RestApiInvocationContext,
+)
+from ..gateway_response import ApiConfigurationError, InternalServerError
+from ..parameters_mapping import ParametersMapper, ResponseDataMapping
+from ..template_mapping import (
+ ApiGatewayVtlTemplate,
+ MappingTemplateInput,
+ MappingTemplateParams,
+ MappingTemplateVariables,
+)
+from ..variables import ContextVarsResponseOverride
+
+LOG = logging.getLogger(__name__)
+
+
+class IntegrationResponseHandler(RestApiGatewayHandler):
+ """
+ This class will take care of the Integration Response part, which is mostly linked to template mapping
+ See https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-integration-settings-integration-response.html
+ """
+
+ def __init__(self):
+ self._param_mapper = ParametersMapper()
+ self._vtl_template = ApiGatewayVtlTemplate()
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ # TODO: we should log the response coming in from the Integration, either in Integration or here.
+ # before modification / after?
+ integration: Integration = context.integration
+ integration_type = integration["type"]
+
+ if integration_type in (IntegrationType.AWS_PROXY, IntegrationType.HTTP_PROXY):
+ # `PROXY` types cannot use integration response mapping templates
+ # TODO: verify assumptions against AWS
+ return
+
+ endpoint_response: EndpointResponse = context.endpoint_response
+ status_code = endpoint_response["status_code"]
+ body = endpoint_response["body"]
+
+ # we first need to find the right IntegrationResponse based on their selection template, linked to the status
+ # code of the Response
+ if integration_type == IntegrationType.AWS and "lambda:path/" in integration["uri"]:
+ selection_value = self.parse_error_message_from_lambda(body) or str(status_code)
+ else:
+ selection_value = str(status_code)
+
+ integration_response: IntegrationResponse = self.select_integration_response(
+ selection_value,
+ integration["integrationResponses"],
+ )
+
+ # we then need to apply Integration Response parameters mapping, to only return select headers
+ response_parameters = integration_response.get("responseParameters") or {}
+ response_data_mapping = self.get_method_response_data(
+ context=context,
+ response=endpoint_response,
+ response_parameters=response_parameters,
+ )
+
+ # We then fetch the response template and apply the template mapping
+ response_template = self.get_response_template(
+ integration_response=integration_response, request=context.invocation_request
+ )
+ body, response_override = self.render_response_template_mapping(
+ context=context, template=response_template, body=body
+ )
+
+ # We basically need to remove all headers and replace them with the mapping, then
+ # override them if there are overrides.
+ # The status code is straightforward: by default it is set by the integration response,
+ # unless there was an override
+ response_status_code = int(integration_response["statusCode"])
+ if response_status_override := response_override["status"]:
+ # maybe make a better error message format, same for the overrides for request too
+ LOG.debug("Overriding response status code: '%s'", response_status_override)
+ response_status_code = response_status_override
+
+ # Create a new headers object that we can manipulate before overriding the original response headers
+ response_headers = Headers(response_data_mapping.get("header"))
+ if header_override := response_override["header"]:
+ LOG.debug("Response header overrides: %s", header_override)
+ response_headers.update(header_override)
+
+ LOG.debug("Method response body after transformations: %s", body)
+ context.invocation_response = InvocationResponse(
+ body=body,
+ headers=response_headers,
+ status_code=response_status_code,
+ )
+
+ def get_method_response_data(
+ self,
+ context: RestApiInvocationContext,
+ response: EndpointResponse,
+ response_parameters: dict[str, str],
+ ) -> ResponseDataMapping:
+ return self._param_mapper.map_integration_response(
+ response_parameters=response_parameters,
+ integration_response=response,
+ context_variables=context.context_variables,
+ stage_variables=context.stage_variables,
+ )
+
+ @staticmethod
+ def select_integration_response(
+ selection_value: str, integration_responses: dict[str, IntegrationResponse]
+ ) -> IntegrationResponse:
+ if not integration_responses:
+ LOG.warning(
+ "Configuration error: No match for output mapping and no default output mapping configured. "
+ "Endpoint Response Status Code: %s",
+ selection_value,
+ )
+ raise ApiConfigurationError("Internal server error")
+
+ if select_by_pattern := [
+ response
+ for response in integration_responses.values()
+ if (selection_pattern := response.get("selectionPattern"))
+ and re.match(selection_pattern, selection_value)
+ ]:
+ selected_response = select_by_pattern[0]
+ if len(select_by_pattern) > 1:
+ LOG.warning(
+ "Multiple integration responses matching '%s' statuscode. Choosing '%s' (first).",
+ selection_value,
+ selected_response["statusCode"],
+ )
+ else:
+ # choose default return code
+ # TODO: the provider should check this, as we should only have one default with no value in selectionPattern
+ default_responses = [
+ response
+ for response in integration_responses.values()
+ if not response.get("selectionPattern")
+ ]
+ if not default_responses:
+ # TODO: verify log message when the selection_value is a lambda errorMessage
+ LOG.warning(
+ "Configuration error: No match for output mapping and no default output mapping configured. "
+ "Endpoint Response Status Code: %s",
+ selection_value,
+ )
+ raise ApiConfigurationError("Internal server error")
+
+ selected_response = default_responses[0]
+ if len(default_responses) > 1:
+ LOG.warning(
+ "Multiple default integration responses. Choosing %s (first).",
+ selected_response["statusCode"],
+ )
+ return selected_response
+
+ @staticmethod
+ def get_response_template(
+ integration_response: IntegrationResponse, request: InvocationRequest
+ ) -> str:
+ """The Response Template is selected from the response templates.
+ If there are no templates defined, the body will pass through.
+ Apigateway looks at the integration request `Accept` header and defaults to `application/json`.
+ If no template is matched, Apigateway will use the "first" existing template and use it as default.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html#transforming-request-response-body
+ """
+ if not (response_templates := integration_response["responseTemplates"]):
+ return ""
+
+ # The invocation request's Accept header is used to find the right response template
+ accepts = request["headers"].getlist("accept")
+ if accepts and (template := response_templates.get(accepts[-1])):
+ return template
+ # TODO aws seemed to favor application/json as default when unmatched regardless of "first"
+ if template := response_templates.get(APPLICATION_JSON):
+ return template
+ # TODO What is first? do we need to keep an order as to when they were added/modified?
+ template = next(iter(response_templates.values()))
+ LOG.warning("No templates were matched, Using template: %s", template)
+ return template
+
+ def render_response_template_mapping(
+ self, context: RestApiInvocationContext, template: str, body: bytes | str
+ ) -> tuple[bytes, ContextVarsResponseOverride]:
+ if not template:
+ return body, ContextVarsResponseOverride(status=0, header={})
+
+ body, response_override = self._vtl_template.render_response(
+ template=template,
+ variables=MappingTemplateVariables(
+ context=context.context_variables,
+ stageVariables=context.stage_variables or {},
+ input=MappingTemplateInput(
+ body=to_str(body),
+ params=MappingTemplateParams(
+ path=context.invocation_request.get("path_parameters"),
+ querystring=context.invocation_request.get("query_string_parameters", {}),
+ header=context.invocation_request.get("headers", {}),
+ ),
+ ),
+ ),
+ )
+
+ # AWS ignores the status if the override isn't an integer between 100 and 599
+ if (status := response_override["status"]) and not (
+ isinstance(status, int) and 100 <= status < 600
+ ):
+ response_override["status"] = 0
+ return to_bytes(body), response_override
+
+ @staticmethod
+ def parse_error_message_from_lambda(payload: bytes) -> str:
+ try:
+ lambda_response = json.loads(payload)
+ if not isinstance(lambda_response, dict):
+ return ""
+
+ # very weird case, but AWS will not return the error from Lambda in an AWS integration, whereas it does for
+ # Kinesis and such. The Lambda-only behavior is concentrated in this method
+ if lambda_response.get("__type") == "AccessDeniedException":
+ raise InternalServerError("Internal server error")
+
+ return lambda_response.get("errorMessage", "")
+
+ except json.JSONDecodeError:
+ return ""
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_request.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_request.py
new file mode 100644
index 0000000000000..00a35129225b1
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_request.py
@@ -0,0 +1,147 @@
+import json
+import logging
+
+from jsonschema import ValidationError, validate
+
+from localstack.aws.api.apigateway import Method
+from localstack.constants import APPLICATION_JSON
+from localstack.http import Response
+from localstack.services.apigateway.helpers import EMPTY_MODEL, ModelResolver
+from localstack.services.apigateway.models import RestApiContainer
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import InvocationRequest, RestApiInvocationContext
+from ..gateway_response import BadRequestBodyError, BadRequestParametersError
+
+LOG = logging.getLogger(__name__)
+
+
+class MethodRequestHandler(RestApiGatewayHandler):
+ """
+ This class will mostly take care of Request validation with Models
+ See https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-method-settings-method-request.html
+ """
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ self.validate_request(
+ context.resource_method,
+ context.deployment.rest_api,
+ context.invocation_request,
+ )
+
+ def validate_request(
+ self, method: Method, rest_api: RestApiContainer, request: InvocationRequest
+ ) -> None:
+ """
+ :raises BadRequestParametersError: if the request is missing required parameters
+ :raises BadRequestBodyError: if the request body is validated against a model and does not conform to it
+ :return: None
+ """
+
+ # check if there is validator for the method
+ if not (request_validator_id := method.get("requestValidatorId") or "").strip():
+ return
+
+ # check if there is a validator for this request
+ if not (validator := rest_api.validators.get(request_validator_id)):
+ # TODO Should we raise an exception instead?
+ LOG.exception("No validator were found with matching id: '%s'", request_validator_id)
+ return
+
+ if self.should_validate_request(validator) and (
+ missing_parameters := self._get_missing_required_parameters(method, request)
+ ):
+ message = f"Missing required request parameters: [{', '.join(missing_parameters)}]"
+ raise BadRequestParametersError(message=message)
+
+ if self.should_validate_body(validator) and not self._is_body_valid(
+ method, rest_api, request
+ ):
+ raise BadRequestBodyError(message="Invalid request body")
+
+ return
+
+ @staticmethod
+ def _is_body_valid(
+ method: Method, rest_api: RestApiContainer, request: InvocationRequest
+ ) -> bool:
+ # if there's no model to validate the body, use the Empty model
+ # https://docs.aws.amazon.com/cdk/api/v1/docs/@aws-cdk_aws-apigateway.EmptyModel.html
+ if not (request_models := method.get("requestModels")):
+ model_name = EMPTY_MODEL
+ else:
+ model_name = request_models.get(
+ APPLICATION_JSON, request_models.get("$default", EMPTY_MODEL)
+ )
+
+ model_resolver = ModelResolver(
+ rest_api_container=rest_api,
+ model_name=model_name,
+ )
+
+ # try to get the resolved model first
+ resolved_schema = model_resolver.get_resolved_model()
+ if not resolved_schema:
+ LOG.exception(
+ "An exception occurred while trying to validate the request: could not resolve the model '%s'",
+ model_name,
+ )
+ return False
+
+ try:
+ # if the body is empty, replace it with an empty JSON body
+ validate(
+ instance=json.loads(request.get("body") or "{}"),
+ schema=resolved_schema,
+ )
+ return True
+ except ValidationError as e:
+ LOG.debug("failed to validate request body %s", e)
+ return False
+ except json.JSONDecodeError as e:
+ LOG.debug("failed to validate request body, request data is not valid JSON %s", e)
+ return False
+
+ @staticmethod
+ def _get_missing_required_parameters(method: Method, request: InvocationRequest) -> list[str]:
+ missing_params = []
+ if not (request_parameters := method.get("requestParameters")):
+ return missing_params
+
+ case_sensitive_headers = list(request.get("headers").keys())
+
+ for request_parameter, required in sorted(request_parameters.items()):
+ if not required:
+ continue
+
+ param_type, param_value = request_parameter.removeprefix("method.request.").split(".")
+ match param_type:
+ case "header":
+ is_missing = param_value not in case_sensitive_headers
+ case "path":
+ path = request.get("path_parameters", "")
+ is_missing = param_value not in path
+ case "querystring":
+ is_missing = param_value not in request.get("query_string_parameters", [])
+ case _:
+ # This shouldn't happen
+ LOG.debug("Found an invalid request parameter: %s", request_parameter)
+ is_missing = False
+
+ if is_missing:
+ missing_params.append(param_value)
+
+ return missing_params
+
+ @staticmethod
+ def should_validate_body(validator):
+ return validator.get("validateRequestBody")
+
+ @staticmethod
+ def should_validate_request(validator):
+ return validator.get("validateRequestParameters")
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_response.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_response.py
new file mode 100644
index 0000000000000..004f99b98a4da
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/method_response.py
@@ -0,0 +1,96 @@
+import logging
+
+from werkzeug.datastructures import Headers
+
+from localstack.aws.api.apigateway import IntegrationType
+from localstack.http import Response
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import InvocationResponse, RestApiInvocationContext
+from ..header_utils import drop_headers
+
+LOG = logging.getLogger(__name__)
+
+ # These are dropped after the template overrides are applied; they will never make it to the method response.
+DROPPED_FROM_INTEGRATION_RESPONSES_COMMON = ["Transfer-Encoding"]
+DROPPED_FROM_INTEGRATION_RESPONSES_HTTP_PROXY = [
+ *DROPPED_FROM_INTEGRATION_RESPONSES_COMMON,
+ "Content-Encoding",
+ "Via",
+]
+
+
+# Headers that will receive a remap
+REMAPPED_FROM_INTEGRATION_RESPONSE_COMMON = [
+ "Connection",
+ "Content-Length",
+ "Date",
+ "Server",
+]
+REMAPPED_FROM_INTEGRATION_RESPONSE_NON_PROXY = [
+ *REMAPPED_FROM_INTEGRATION_RESPONSE_COMMON,
+ "Authorization",
+ "Content-MD5",
+ "Expect",
+ "Host",
+ "Max-Forwards",
+ "Proxy-Authenticate",
+ "Trailer",
+ "Upgrade",
+ "User-Agent",
+ "WWW-Authenticate",
+]
+
+
+class MethodResponseHandler(RestApiGatewayHandler):
+ """
+ Last handler of the chain, responsible for serializing the Response object
+ """
+
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ invocation_response = context.invocation_response
+ integration_type = context.integration["type"]
+ headers = invocation_response["headers"]
+
+ self._transform_headers(headers, integration_type)
+
+ method_response = self.serialize_invocation_response(invocation_response)
+ response.update_from(method_response)
+
+ @staticmethod
+ def serialize_invocation_response(invocation_response: InvocationResponse) -> Response:
+ is_content_type_set = invocation_response["headers"].get("content-type") is not None
+ response = Response(
+ response=invocation_response["body"],
+ headers=invocation_response["headers"],
+ status=invocation_response["status_code"],
+ )
+ if not is_content_type_set:
+ # Werkzeug's Response sets a content-type by default; remove it when the invocation response did not set one.
+ response.headers.remove("content-type")
+ return response
+
+ @staticmethod
+ def _transform_headers(headers: Headers, integration_type: IntegrationType):
+ """Remaps the provided headers in-place. Adding new `x-amzn-Remapped-` headers and dropping the original headers"""
+ to_remap = REMAPPED_FROM_INTEGRATION_RESPONSE_COMMON
+ to_drop = DROPPED_FROM_INTEGRATION_RESPONSES_COMMON
+
+ match integration_type:
+ case IntegrationType.HTTP | IntegrationType.AWS:
+ to_remap = REMAPPED_FROM_INTEGRATION_RESPONSE_NON_PROXY
+ case IntegrationType.HTTP_PROXY:
+ to_drop = DROPPED_FROM_INTEGRATION_RESPONSES_HTTP_PROXY
+
+ for header in to_remap:
+ if headers.get(header):
+ LOG.debug("Remapping header: %s", header)
+ remapped = headers.pop(header)
+ headers[f"x-amzn-Remapped-{header}"] = remapped
+
+ drop_headers(headers, to_drop)
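For illustration, here is the remap-and-drop behavior of `_transform_headers` reproduced on a couple of sample headers (werkzeug `Headers` only, no LocalStack imports):

```python
from werkzeug.datastructures import Headers

headers = Headers({"Date": "Tue, 01 Jan 2030 00:00:00 GMT", "Transfer-Encoding": "chunked"})

# remapped headers re-appear under the x-amzn-Remapped- prefix
for header in ("Connection", "Content-Length", "Date", "Server"):
    if headers.get(header):
        headers[f"x-amzn-Remapped-{header}"] = headers.pop(header)

# dropped headers disappear entirely
headers.remove("Transfer-Encoding")

print(dict(headers))  # -> {'x-amzn-Remapped-Date': 'Tue, 01 Jan 2030 00:00:00 GMT'}
```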
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/parse.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/parse.py
new file mode 100644
index 0000000000000..f4201ec2dc26f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/parse.py
@@ -0,0 +1,193 @@
+import datetime
+import logging
+import re
+from collections import defaultdict
+from typing import Optional
+from urllib.parse import urlparse
+
+from rolo.request import restore_payload
+from werkzeug.datastructures import Headers, MultiDict
+
+from localstack.http import Response
+from localstack.services.apigateway.helpers import REQUEST_TIME_DATE_FORMAT
+from localstack.utils.strings import long_uid, short_uid
+from localstack.utils.time import timestamp
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import InvocationRequest, RestApiInvocationContext
+from ..header_utils import should_drop_header_from_invocation
+from ..helpers import generate_trace_id, generate_trace_parent, parse_trace_id
+from ..moto_helpers import get_stage_variables
+from ..variables import ContextVariables, ContextVarsIdentity
+
+LOG = logging.getLogger(__name__)
+
+
+class InvocationRequestParser(RestApiGatewayHandler):
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ context.account_id = context.deployment.account_id
+ context.region = context.deployment.region
+ self.parse_and_enrich(context)
+
+ def parse_and_enrich(self, context: RestApiInvocationContext):
+ # first, create the InvocationRequest with the incoming request
+ context.invocation_request = self.create_invocation_request(context)
+ # then we can create the ContextVariables, used throughout the invocation as payload and to render authorizer
+ # payload, mapping templates and such.
+ context.context_variables = self.create_context_variables(context)
+ # TODO: maybe adjust the logging
+ LOG.debug("Initializing $context='%s'", context.context_variables)
+ # then populate the stage variables
+ context.stage_variables = self.fetch_stage_variables(context)
+ LOG.debug("Initializing $stageVariables='%s'", context.stage_variables)
+
+ context.trace_id = self.populate_trace_id(context.request.headers)
+
+ def create_invocation_request(self, context: RestApiInvocationContext) -> InvocationRequest:
+ request = context.request
+ params, multi_value_params = self._get_single_and_multi_values_from_multidict(request.args)
+ headers = self._get_invocation_headers(request.headers)
+ invocation_request = InvocationRequest(
+ http_method=request.method,
+ query_string_parameters=params,
+ multi_value_query_string_parameters=multi_value_params,
+ headers=headers,
+ body=restore_payload(request),
+ )
+ self._enrich_with_raw_path(context, invocation_request)
+
+ return invocation_request
+
+ @staticmethod
+ def _enrich_with_raw_path(
+ context: RestApiInvocationContext, invocation_request: InvocationRequest
+ ):
+ # Base path is not URL-decoded, so we need to get the `RAW_URI` from the request
+ request = context.request
+ raw_uri = request.environ.get("RAW_URI") or request.path
+
+ # if the request comes from the LocalStack only `_user_request_` route, we need to remove this prefix from the
+ # path, in order to properly route the request
+ if "_user_request_" in raw_uri:
+ # in this format, the stage is before `_user_request_`, so we don't need to remove it
+ raw_uri = raw_uri.partition("_user_request_")[2]
+ else:
+ if raw_uri.startswith("/_aws/execute-api"):
+ # the API ID in the path can have any casing, so we need to ignore case when removing it
+ raw_uri = re.sub(
+ f"^/_aws/execute-api/{context.api_id}",
+ "",
+ raw_uri,
+ flags=re.IGNORECASE,
+ )
+
+ # remove the stage from the path, only replace the first occurrence
+ raw_uri = raw_uri.replace(f"/{context.stage}", "", 1)
+
+ if raw_uri.startswith("//"):
+ # TODO: validate this assumption against AWS
+ # if the RAW_URI starts with double slashes, `urlparse` will fail to decode it as path only
+ # it also means that we already only have the path, so we just need to remove the query string
+ raw_uri = raw_uri.split("?")[0]
+ raw_path = "/" + raw_uri.lstrip("/")
+
+ else:
+ # we need to make sure we have a path here, sometimes RAW_URI can be a full URI (when proxied)
+ raw_path = raw_uri = urlparse(raw_uri).path
+
+ invocation_request["path"] = raw_path
+ invocation_request["raw_path"] = raw_uri
+
+ @staticmethod
+ def _get_single_and_multi_values_from_multidict(
+ multi_dict: MultiDict,
+ ) -> tuple[dict[str, str], dict[str, list[str]]]:
+ single_values = {}
+ multi_values = defaultdict(list)
+
+ for key, value in multi_dict.items(multi=True):
+ multi_values[key].append(value)
+ # for the single value parameters, AWS only keeps the last value of the list
+ single_values[key] = value
+
+ return single_values, dict(multi_values)
+
+ @staticmethod
+ def _get_invocation_headers(headers: Headers) -> Headers:
+ invocation_headers = Headers()
+ for key, value in headers:
+ if should_drop_header_from_invocation(key):
+ LOG.debug("Dropping header from invocation request: '%s'", key)
+ continue
+ invocation_headers.add(key, value)
+ return invocation_headers
+
+ @staticmethod
+ def create_context_variables(context: RestApiInvocationContext) -> ContextVariables:
+ invocation_request: InvocationRequest = context.invocation_request
+ domain_name = invocation_request["headers"].get("Host", "")
+ domain_prefix = domain_name.split(".")[0]
+ now = datetime.datetime.now()
+
+ # TODO: verify which values needs to explicitly have None set
+ context_variables = ContextVariables(
+ accountId=context.account_id,
+ apiId=context.api_id,
+ deploymentId=context.deployment_id,
+ domainName=domain_name,
+ domainPrefix=domain_prefix,
+ extendedRequestId=short_uid(), # TODO: use snapshot tests to verify format
+ httpMethod=invocation_request["http_method"],
+ identity=ContextVarsIdentity(
+ accountId=None,
+ accessKey=None,
+ caller=None,
+ cognitoAuthenticationProvider=None,
+ cognitoAuthenticationType=None,
+ cognitoIdentityId=None,
+ cognitoIdentityPoolId=None,
+ principalOrgId=None,
+ sourceIp="127.0.0.1", # TODO: get the sourceIp from the Request
+ user=None,
+ userAgent=invocation_request["headers"].get("User-Agent"),
+ userArn=None,
+ ),
+ path=f"/{context.stage}{invocation_request['raw_path']}",
+ protocol="HTTP/1.1",
+ requestId=long_uid(),
+ requestTime=timestamp(time=now, format=REQUEST_TIME_DATE_FORMAT),
+ requestTimeEpoch=int(now.timestamp() * 1000),
+ stage=context.stage,
+ )
+ return context_variables
+
+ @staticmethod
+ def fetch_stage_variables(context: RestApiInvocationContext) -> Optional[dict[str, str]]:
+ stage_variables = get_stage_variables(
+ account_id=context.account_id,
+ region=context.region,
+ api_id=context.api_id,
+ stage_name=context.stage,
+ )
+ if not stage_variables:
+ # we need to set the stage variables to None in the context if we don't have at least one
+ return None
+
+ return stage_variables
+
+ @staticmethod
+ def populate_trace_id(headers: Headers) -> str:
+ incoming_trace = parse_trace_id(headers.get("x-amzn-trace-id", ""))
+ # parse_trace_id always returns capitalized keys
+
+ trace = incoming_trace.get("Root", generate_trace_id())
+ incoming_parent = incoming_trace.get("Parent")
+ parent = incoming_parent or generate_trace_parent()
+ sampled = incoming_trace.get("Sampled", "1" if incoming_parent else "0")
+ # TODO: lineage? not sure what it relates to
+ return f"Root={trace};Parent={parent};Sampled={sampled}"
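+ # Illustration (hypothetical values): an incoming header
+ #   X-Amzn-Trace-Id: Root=1-3152b799-8954dae64eda91bc9a23a7e8;Parent=7fa8c0f79203be72;Sampled=1
+ # is passed through unchanged, while a request without the header gets a freshly generated Root and
+ # Parent with Sampled=0, e.g. `Root=1-<epoch hex>-<24 hex chars>;Parent=<16 hex chars>;Sampled=0`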
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/resource_router.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/resource_router.py
new file mode 100644
index 0000000000000..c957e24fb00bd
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/resource_router.py
@@ -0,0 +1,170 @@
+import logging
+from functools import cache
+from http import HTTPMethod
+from typing import Iterable
+
+from werkzeug.exceptions import MethodNotAllowed, NotFound
+from werkzeug.routing import Map, MapAdapter, Rule
+
+from localstack.aws.api.apigateway import Resource
+from localstack.aws.protocol.routing import (
+ path_param_regex,
+ post_process_arg_name,
+ transform_path_params_to_rule_vars,
+)
+from localstack.http import Response
+from localstack.http.router import GreedyPathConverter
+from localstack.services.apigateway.models import RestApiDeployment
+
+from ..api import RestApiGatewayHandler, RestApiGatewayHandlerChain
+from ..context import RestApiInvocationContext
+from ..gateway_response import MissingAuthTokenError
+from ..variables import ContextVariables
+
+LOG = logging.getLogger(__name__)
+
+
+class ApiGatewayMethodRule(Rule):
+ """
+ Small extension to Werkzeug's Rule class which reverts unwanted assumptions made by Werkzeug.
+ Reverted assumptions:
+ - Werkzeug automatically matches HEAD requests to the corresponding GET request (i.e. Werkzeug's rule
+ automatically adds the HEAD HTTP method to a rule which should only match GET requests).
+ Added behavior:
+ - ANY is equivalent to the 7 HTTP methods listed below. We manually set them as the rule's methods
+ """
+
+ def __init__(self, string: str, method: str, **kwargs) -> None:
+ super().__init__(string=string, methods=[method], **kwargs)
+
+ if method == "ANY":
+ self.methods = {
+ HTTPMethod.DELETE,
+ HTTPMethod.GET,
+ HTTPMethod.HEAD,
+ HTTPMethod.OPTIONS,
+ HTTPMethod.PATCH,
+ HTTPMethod.POST,
+ HTTPMethod.PUT,
+ }
+ else:
+ # Make sure Werkzeug's Rule does not add any other methods
+ # (f.e. the HEAD method even though the rule should only match GET)
+ self.methods = {method.upper()}
+
+
+class RestAPIResourceRouter:
+ """
+ A router implementation which abstracts the routing of an incoming REST API invocation context to a specific
+ resource of the deployment.
+ """
+
+ _map: Map
+
+ def __init__(self, deployment: RestApiDeployment):
+ self._resources = deployment.rest_api.resources
+ self._map = get_rule_map_for_resources(self._resources.values())
+
+ def match(self, context: RestApiInvocationContext) -> tuple[Resource, dict[str, str]]:
+ """
+ Matches the given request to the resource it targets (or raises an exception if no resource matches).
+
+ :param context:
+ :return: A tuple with the matched resource and the (already parsed) path params
+ :raises: TODO: Gateway exception in case the given request does not match any operation
+ """
+
+ request = context.request
+ # bind the map to get the actual matcher
+ matcher: MapAdapter = self._map.bind(context.request.host)
+
+ # perform the matching
+ # trailing slashes are ignored in APIGW
+ path = context.invocation_request["path"].rstrip("/")
+ try:
+ rule, args = matcher.match(path, method=request.method, return_rule=True)
+ except (MethodNotAllowed, NotFound) as e:
+ # MethodNotAllowed (405) exception is raised if a path is matching, but the method does not.
+ # Our router might handle this as a 404, validate with AWS.
+ LOG.warning(
+ "API Gateway: No resource or method was found for: %s %s",
+ request.method,
+ path,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+ raise MissingAuthTokenError("Missing Authentication Token") from e
+
+ # post process the arg keys and values
+ # - the path param keys need to be "un-sanitized", i.e. sanitized rule variable names need to be reverted
+ # - the path param values might still be url-encoded
+ args = {post_process_arg_name(k): v for k, v in args.items()}
+
+ # extract the operation model from the rule
+ resource_id: str = rule.endpoint
+ resource = self._resources[resource_id]
+
+ return resource, args
+
+
+class InvocationRequestRouter(RestApiGatewayHandler):
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ self.route_and_enrich(context)
+
+ def route_and_enrich(self, context: RestApiInvocationContext):
+ router = self.get_router_for_deployment(context.deployment)
+
+ resource, path_parameters = router.match(context)
+ resource: Resource
+
+ context.invocation_request["path_parameters"] = path_parameters
+ context.resource = resource
+
+ method = (
+ resource["resourceMethods"].get(context.request.method)
+ or resource["resourceMethods"]["ANY"]
+ )
+ context.resource_method = method
+ context.integration = method["methodIntegration"]
+
+ self.update_context_variables_with_resource(context.context_variables, resource)
+
+ @staticmethod
+ def update_context_variables_with_resource(
+ context_variables: ContextVariables, resource: Resource
+ ):
+ LOG.debug("Updating $context.resourcePath='%s'", resource["path"])
+ context_variables["resourcePath"] = resource["path"]
+ LOG.debug("Updating $context.resourceId='%s'", resource["id"])
+ context_variables["resourceId"] = resource["id"]
+
+ @staticmethod
+ @cache
+ def get_router_for_deployment(deployment: RestApiDeployment) -> RestAPIResourceRouter:
+ return RestAPIResourceRouter(deployment)
+
+
+def get_rule_map_for_resources(resources: Iterable[Resource]) -> Map:
+ rules = []
+ for resource in resources:
+ for method, resource_method in resource.get("resourceMethods", {}).items():
+ path = resource["path"]
+ # translate the requestUri to a Werkzeug rule string
+ rule_string = path_param_regex.sub(transform_path_params_to_rule_vars, path)
+ rules.append(
+ ApiGatewayMethodRule(string=rule_string, method=method, endpoint=resource["id"])
+ ) # type: ignore
+
+ return Map(
+ rules=rules,
+ # don't be strict about trailing slashes when matching
+ strict_slashes=False,
+ # we can't really use werkzeug's merge-slashes since it uses HTTP redirects to solve it
+ merge_slashes=False,
+ # get service-specific converters
+ converters={"path": GreedyPathConverter},
+ )
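+ # Rough behavior sketch (hypothetical resources): a resource with path `/pets/{petId}` and a GET method
+ # yields a rule matching `GET /pets/123`, and `RestAPIResourceRouter.match` then returns that resource
+ # together with the parsed path parameters, e.g. {"petId": "123"}. Greedy parameters like `{proxy+}`
+ # rely on the registered `GreedyPathConverter`, so they can match the remaining path across slashes.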
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/response_enricher.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/response_enricher.py
new file mode 100644
index 0000000000000..8b6308e7e3d2c
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/handlers/response_enricher.py
@@ -0,0 +1,30 @@
+from localstack.aws.api.apigateway import IntegrationType
+from localstack.http import Response
+from localstack.services.apigateway.next_gen.execute_api.api import (
+ RestApiGatewayHandler,
+ RestApiGatewayHandlerChain,
+)
+from localstack.services.apigateway.next_gen.execute_api.context import RestApiInvocationContext
+from localstack.utils.strings import short_uid
+
+
+class InvocationResponseEnricher(RestApiGatewayHandler):
+ def __call__(
+ self,
+ chain: RestApiGatewayHandlerChain,
+ context: RestApiInvocationContext,
+ response: Response,
+ ):
+ headers = response.headers
+
+ headers.set("x-amzn-RequestId", context.context_variables["requestId"])
+
+ # TODO: as we go into monitoring, we will want to have these values come from the context?
+ headers.set("x-amz-apigw-id", short_uid() + "=")
+ if (
+ context.integration
+ and context.integration["type"]
+ not in (IntegrationType.HTTP_PROXY, IntegrationType.MOCK)
+ and not context.context_variables.get("error")
+ ):
+ headers.set("X-Amzn-Trace-Id", context.trace_id)
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/header_utils.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/header_utils.py
new file mode 100644
index 0000000000000..1b1fcbfa3f35a
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/header_utils.py
@@ -0,0 +1,56 @@
+import logging
+from collections import defaultdict
+from typing import Iterable
+
+from werkzeug.datastructures.headers import Headers
+
+LOG = logging.getLogger(__name__)
+
+# Headers dropped at the request parsing. They will never make it to the invocation requests.
+# And won't be available for request mapping.
+DROPPED_FROM_REQUEST_COMMON = [
+ "Connection",
+ "Content-Length",
+ "Content-MD5",
+ "Expect",
+ "Max-Forwards",
+ "Proxy-Authenticate",
+ "Server",
+ "TE",
+ "Transfer-Encoding",
+ "Trailer",
+ "Upgrade",
+ "WWW-Authenticate",
+]
+DROPPED_FROM_REQUEST_COMMON_LOWER = [header.lower() for header in DROPPED_FROM_REQUEST_COMMON]
+
+
+def should_drop_header_from_invocation(header: str) -> bool:
+ """These headers do not make it to the invocation requests; even proxy integrations do not send them."""
+ return header.lower() in DROPPED_FROM_REQUEST_COMMON_LOWER
+
+
+def build_multi_value_headers(headers: Headers) -> dict[str, list[str]]:
+ multi_value_headers = defaultdict(list)
+ for key, value in headers:
+ multi_value_headers[key].append(value)
+
+ return multi_value_headers
+
+
+def drop_headers(headers: Headers, to_drop: Iterable[str]):
+ """Modifies the provided headers in place, dropping the headers that match the provided list"""
+ dropped_headers = []
+
+ for header in to_drop:
+ if headers.get(header):
+ headers.remove(header)
+ dropped_headers.append(header)
+
+ LOG.debug("Dropping headers: %s", dropped_headers)
+
+
+def set_default_headers(headers: Headers, default_headers: dict[str, str]):
+ for header, value in default_headers.items():
+ if not headers.get(header):
+ headers.set(header, value)
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/helpers.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/helpers.py
new file mode 100644
index 0000000000000..117fbd9f9078c
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/helpers.py
@@ -0,0 +1,150 @@
+import copy
+import logging
+import re
+import time
+from secrets import token_hex
+from typing import Type, TypedDict
+
+from moto.apigateway.models import RestAPI as MotoRestAPI
+
+from localstack.services.apigateway.models import MergedRestApi, RestApiContainer, RestApiDeployment
+from localstack.utils.aws.arns import get_partition
+
+from .context import RestApiInvocationContext
+from .moto_helpers import get_resources_from_moto_rest_api
+
+LOG = logging.getLogger(__name__)
+
+_stage_variable_pattern = re.compile(r"\${stageVariables\.(?P<varName>.*?)}")
+
+
+def freeze_rest_api(
+ account_id: str, region: str, moto_rest_api: MotoRestAPI, localstack_rest_api: RestApiContainer
+) -> RestApiDeployment:
+ """
+ Snapshot a REST API in time to create a deployment
+ This will merge the Moto and LocalStack data into one `MergedRestApi`
+ """
+ moto_resources = get_resources_from_moto_rest_api(moto_rest_api)
+
+ rest_api = MergedRestApi.from_rest_api_container(
+ rest_api_container=localstack_rest_api,
+ resources=moto_resources,
+ )
+
+ return RestApiDeployment(
+ account_id=account_id,
+ region=region,
+ rest_api=copy.deepcopy(rest_api),
+ )
+
+
+def render_uri_with_stage_variables(
+ uri: str | None, stage_variables: dict[str, str] | None
+) -> str | None:
+ """
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/aws-api-gateway-stage-variables-reference.html#stage-variables-in-integration-HTTP-uris
+ URI=https://${stageVariables.<variable_name>}
+ This format is the same as VTL, but we're using a simplified version that only replaces `${stageVariables.<variable_name>}`
+ values, as AWS will ignore `${path}`, for example
+ """
+ if not uri:
+ return uri
+ stage_vars = stage_variables or {}
+
+ def replace_match(match_obj: re.Match) -> str:
+ return stage_vars.get(match_obj.group("varName"), "")
+
+ return _stage_variable_pattern.sub(replace_match, uri)
+
+
+def render_uri_with_path_parameters(uri: str | None, path_parameters: dict[str, str]) -> str | None:
+ if not uri:
+ return uri
+
+ for key, value in path_parameters.items():
+ uri = uri.replace(f"{{{key}}}", value)
+
+ return uri
+
+
+def render_integration_uri(
+ uri: str | None, path_parameters: dict[str, str], stage_variables: dict[str, str]
+) -> str:
+ """
+ A URI can contain different values to interpolate / render
+ It will have path parameters substitutions with this shape (can also add a querystring).
+ URI=http://myhost.test/rootpath/{path}
+
+ It can also have another format, for stage variables, documented here:
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/aws-api-gateway-stage-variables-reference.html#stage-variables-in-integration-HTTP-uris
+ URI=https://${stageVariables.<variable_name>}
+ This format is the same as VTL.
+
+ :param uri: the integration URI
+ :param path_parameters: the list of path parameters, coming from the parameter mappings and overrides
+ :param stage_variables: -
+ :return: the rendered URI
+ """
+ if not uri:
+ return ""
+
+ uri_with_path = render_uri_with_path_parameters(uri, path_parameters)
+ return render_uri_with_stage_variables(uri_with_path, stage_variables)
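+ # Usage sketch (hypothetical URI and variables):
+ #   render_integration_uri(
+ #       "https://${stageVariables.backendHost}/pets/{petId}",
+ #       path_parameters={"petId": "123"},
+ #       stage_variables={"backendHost": "example.com"},
+ #   ) == "https://example.com/pets/123"
+ # Path parameters are substituted first, then `${stageVariables.<varName>}` expressions; unknown stage
+ # variables render as an empty string.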
+
+
+def get_source_arn(context: RestApiInvocationContext):
+ method = context.resource_method["httpMethod"]
+ path = context.resource["path"]
+ return (
+ f"arn:{get_partition(context.region)}:execute-api"
+ f":{context.region}"
+ f":{context.account_id}"
+ f":{context.api_id}"
+ f"/{context.stage}/{method}{path}"
+ )
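+ # For example (hypothetical identifiers): a GET on resource `/pets` in stage `dev` of API `abc123`
+ # in us-east-1 for account 000000000000 yields
+ # `arn:aws:execute-api:us-east-1:000000000000:abc123/dev/GET/pets`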
+
+
+def get_lambda_function_arn_from_invocation_uri(uri: str) -> str:
+ """
+ Example URI: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:SimpleLambda4ProxyResource/invocations
+ :param uri: the integration URI value for a lambda function
+ :return: the lambda function ARN
+ """
+ return uri.split("functions/")[1].removesuffix("/invocations")
+
+
+def validate_sub_dict_of_typed_dict(typed_dict: Type[TypedDict], obj: dict) -> bool:
+ """
+ Validate that the object is a subset of the keys of a given `TypedDict`.
+ :param typed_dict: the `TypedDict` blueprint
+ :param obj: the object to validate
+ :return: True if it is a subset, False otherwise
+ """
+ typed_dict_keys = {*typed_dict.__required_keys__, *typed_dict.__optional_keys__}
+
+ return not bool(set(obj) - typed_dict_keys)
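+ # For illustration: for a TypedDict with keys `a` and `b`, {"a": 1} is a valid subset, while
+ # {"a": 1, "c": 2} is not, since `c` is not a declared key.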
+
+
+def generate_trace_id():
+ """https://docs.aws.amazon.com/xray/latest/devguide/xray-api-sendingdata.html#xray-api-traceids"""
+ original_request_epoch = int(time.time())
+ timestamp_hex = hex(original_request_epoch)[2:]
+ version_number = "1"
+ unique_id = token_hex(12)
+ return f"{version_number}-{timestamp_hex}-{unique_id}"
+
+
+def generate_trace_parent():
+ return token_hex(8)
+
+
+def parse_trace_id(trace_id: str) -> dict[str, str]:
+ split_trace = trace_id.split(";")
+ trace_values = {}
+ for trace_part in split_trace:
+ key_value = trace_part.split("=")
+ if len(key_value) == 2:
+ trace_values[key_value[0].capitalize()] = key_value[1]
+
+ return trace_values
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/__init__.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/__init__.py
new file mode 100644
index 0000000000000..7900965784631
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/__init__.py
@@ -0,0 +1,15 @@
+from .aws import RestApiAwsIntegration, RestApiAwsProxyIntegration
+from .http import RestApiHttpIntegration, RestApiHttpProxyIntegration
+from .mock import RestApiMockIntegration
+
+REST_API_INTEGRATIONS = {
+ RestApiAwsIntegration.name: RestApiAwsIntegration(),
+ RestApiAwsProxyIntegration.name: RestApiAwsProxyIntegration(),
+ RestApiHttpIntegration.name: RestApiHttpIntegration(),
+ RestApiHttpProxyIntegration.name: RestApiHttpProxyIntegration(),
+ RestApiMockIntegration.name: RestApiMockIntegration(),
+}
+
+__all__ = [
+ "REST_API_INTEGRATIONS",
+]
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py
new file mode 100644
index 0000000000000..5bc2474d386ca
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/aws.py
@@ -0,0 +1,569 @@
+import base64
+import json
+import logging
+from functools import lru_cache
+from http import HTTPMethod
+from typing import Literal, Optional, TypedDict
+from urllib.parse import urlparse
+
+import requests
+from botocore.exceptions import ClientError
+from werkzeug.datastructures import Headers
+
+from localstack import config
+from localstack.aws.connect import (
+ INTERNAL_REQUEST_PARAMS_HEADER,
+ InternalRequestParameters,
+ connect_to,
+ dump_dto,
+)
+from localstack.aws.protocol.service_router import get_service_catalog
+from localstack.constants import APPLICATION_JSON, INTERNAL_AWS_ACCESS_KEY_ID
+from localstack.utils.aws.arns import extract_region_from_arn
+from localstack.utils.aws.client_types import ServicePrincipal
+from localstack.utils.strings import to_bytes, to_str
+
+from ..context import (
+ EndpointResponse,
+ IntegrationRequest,
+ InvocationRequest,
+ RestApiInvocationContext,
+)
+from ..gateway_response import IntegrationFailureError, InternalServerError
+from ..header_utils import build_multi_value_headers
+from ..helpers import (
+ get_lambda_function_arn_from_invocation_uri,
+ get_source_arn,
+ render_uri_with_stage_variables,
+ validate_sub_dict_of_typed_dict,
+)
+from ..variables import ContextVariables
+from .core import RestApiIntegration
+
+LOG = logging.getLogger(__name__)
+
+NO_BODY_METHODS = {
+ HTTPMethod.OPTIONS,
+ HTTPMethod.GET,
+ HTTPMethod.HEAD,
+}
+
+
+class LambdaProxyResponse(TypedDict, total=False):
+ body: Optional[str]
+ statusCode: Optional[int | str]
+ headers: Optional[dict[str, str]]
+ isBase64Encoded: Optional[bool]
+ multiValueHeaders: Optional[dict[str, list[str]]]
+
+
+class LambdaInputEvent(TypedDict, total=False):
+ body: str
+ isBase64Encoded: bool
+ httpMethod: str | HTTPMethod
+ resource: str
+ path: str
+ headers: dict[str, str]
+ multiValueHeaders: dict[str, list[str]]
+ queryStringParameters: dict[str, str]
+ multiValueQueryStringParameters: dict[str, list[str]]
+ requestContext: ContextVariables
+ pathParameters: dict[str, str]
+ stageVariables: dict[str, str]
+
+
+class ParsedAwsIntegrationUri(TypedDict):
+ service_name: str
+ region_name: str
+ action_type: Literal["path", "action"]
+ path: str
+
+
+@lru_cache(maxsize=64)
+def get_service_factory(region_name: str, role_arn: str):
+ if role_arn:
+ return connect_to.with_assumed_role(
+ role_arn=role_arn,
+ region_name=region_name,
+ service_principal=ServicePrincipal.apigateway,
+ session_name="BackplaneAssumeRoleSession",
+ )
+ else:
+ return connect_to(region_name=region_name)
+
+
+@lru_cache(maxsize=64)
+def get_internal_mocked_headers(
+ service_name: str,
+ region_name: str,
+ source_arn: str,
+ role_arn: str | None,
+) -> dict[str, str]:
+ if role_arn:
+ access_key_id = (
+ connect_to()
+ .sts.request_metadata(service_principal=ServicePrincipal.apigateway)
+ .assume_role(RoleArn=role_arn, RoleSessionName="BackplaneAssumeRoleSession")[
+ "Credentials"
+ ]["AccessKeyId"]
+ )
+ else:
+ access_key_id = INTERNAL_AWS_ACCESS_KEY_ID
+
+ dto = InternalRequestParameters(
+ service_principal=ServicePrincipal.apigateway, source_arn=source_arn
+ )
+ # TODO: maybe use the localstack.utils.aws.client.SigningHttpClient instead of directly mocking the Authorization
+ # header (but will need to select the right signer depending on the service?)
+ headers = {
+ "Authorization": (
+ "AWS4-HMAC-SHA256 "
+ + f"Credential={access_key_id}/20160623/{region_name}/{service_name}/aws4_request, "
+ + "SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=1234"
+ ),
+ INTERNAL_REQUEST_PARAMS_HEADER: dump_dto(dto),
+ }
+
+ return headers
+
+
+@lru_cache(maxsize=64)
+def get_target_prefix_for_service(service_name: str) -> str | None:
+ return get_service_catalog().get(service_name).metadata.get("targetPrefix")
+
+
+class RestApiAwsIntegration(RestApiIntegration):
+ """
+ This is a REST API integration responsible for interacting directly with AWS services. It uses the `uri` to
+ map the incoming request to the targeted AWS service, and can have 2 types.
+ - `path`: the request is targeting the direct URI of the AWS service, like you would with an HTTP client
+ example: for an S3 GetObject call: arn:aws:apigateway:us-west-2:s3:path/{bucket}/{key}
+ - `action`: this is a simpler way, where you can pass the request parameters like you would do with an SDK, and you
+ can specify the service action (for ex. here S3 `GetObject`). It seems the request parameters can be passed as query
+ string parameters, a JSON body and maybe more. TODO: verify; 2 documentation pages indicate divergent information
+ (one indicates parameters through the query string, the other through the request body)
+ example: arn:aws:apigateway:us-west-2:s3:action/GetObject&Bucket={bucket}&Key={key}
+
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/integration-request-basic-setup.html
+
+
+ TODO: it seems we can have a single, global AWS integration type; we should not need to subclass for each service,
+ we just need to separate usage between the `path` URI type and the `action` URI type.
+ - `path`: we can simply pass along the fully rendered request along with specific `mocked` AWS headers
+ that are dependent on the service (retrieved from the ARN in the URI)
+ - `action`: we might need either a full Boto call or the Boto request serializer, as it seems the request
+ parameters are expected as call parameters
+ """
+
+ name = "AWS"
+
+ # TODO: it seems in AWS, you don't need to manually set the `X-Amz-Target` header when using the `action` type.
+ # for now, we know `events` needs the user to manually add the header, but Kinesis and DynamoDB don't.
+ # Maybe reverse the list to exclude instead of include.
+ SERVICES_AUTO_TARGET = ["dynamodb", "kinesis", "ssm", "stepfunctions"]
+
+ # TODO: some services still target the Query protocol (validated with AWS), even though SSM, for example, has been
+ # JSON for as long as the Boto SDK has existed. We will need to emulate the Query protocol and translate it to JSON
+ SERVICES_LEGACY_QUERY_PROTOCOL = ["ssm"]
+
+ SERVICE_MAP = {
+ "states": "stepfunctions",
+ }
+
+ def __init__(self):
+ self._base_domain = config.internal_service_url()
+ self._base_host = ""
+ self._service_names = get_service_catalog().service_names
+
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ integration_req: IntegrationRequest = context.integration_request
+ method = integration_req["http_method"]
+ parsed_uri = self.parse_aws_integration_uri(integration_req["uri"])
+ service_name = parsed_uri["service_name"]
+ integration_region = parsed_uri["region_name"]
+
+ if credentials := context.integration.get("credentials"):
+ credentials = render_uri_with_stage_variables(credentials, context.stage_variables)
+
+ headers = integration_req["headers"]
+ # Some integrations use a special format for the service in the URI, like AppSync, and so those requests
+ # are not directed to a service directly, so we don't need to add the Authorization header. They would fail
+ # parsing by our service name parser anyway
+ if service_name in self._service_names:
+ headers.update(
+ get_internal_mocked_headers(
+ service_name=service_name,
+ region_name=integration_region,
+ source_arn=get_source_arn(context),
+ role_arn=credentials,
+ )
+ )
+ query_params = integration_req["query_string_parameters"].copy()
+ data = integration_req["body"]
+
+ if parsed_uri["action_type"] == "path":
+ # the Path action type allows you to override the path the request is sent to, like you would when sending it to AWS directly
+ path = f"/{parsed_uri['path']}"
+ else:
+ # Action passes the `Action` query string parameter
+ path = ""
+ action = parsed_uri["path"]
+
+ if target := self.get_action_service_target(service_name, action):
+ headers["X-Amz-Target"] = target
+
+ query_params["Action"] = action
+
+ if service_name in self.SERVICES_LEGACY_QUERY_PROTOCOL:
+ # this has been tested in AWS: for `ssm`, it fully overrides the body because SSM uses the Query
+ # protocol, so we simulate it that way
+ data = self.get_payload_from_query_string(query_params)
+
+ url = f"{self._base_domain}{path}"
+ headers["Host"] = self.get_internal_host_for_service(
+ service_name=service_name, region_name=integration_region
+ )
+
+ request_parameters = {
+ "method": method,
+ "url": url,
+ "params": query_params,
+ "headers": headers,
+ }
+
+ if method not in NO_BODY_METHODS:
+ request_parameters["data"] = data
+
+ request_response = requests.request(**request_parameters)
+ response_content = request_response.content
+
+ if (
+ parsed_uri["action_type"] == "action"
+ and service_name in self.SERVICES_LEGACY_QUERY_PROTOCOL
+ ):
+ response_content = self.format_response_content_legacy(
+ payload=response_content,
+ service_name=service_name,
+ action=parsed_uri["path"],
+ request_id=context.context_variables["requestId"],
+ )
+
+ return EndpointResponse(
+ body=response_content,
+ status_code=request_response.status_code,
+ headers=Headers(dict(request_response.headers)),
+ )
+
+ def parse_aws_integration_uri(self, uri: str) -> ParsedAwsIntegrationUri:
+ """
+ The URI can be of 2 shapes: Path or Action.
+ Path : arn:aws:apigateway:us-west-2:s3:path/{bucket}/{key}
+ Action: arn:aws:apigateway:us-east-1:kinesis:action/PutRecord
+ :param uri: the URI of the AWS integration
+ :return: a ParsedAwsIntegrationUri containing the service name, the region and the type of action
+ """
+ arn, _, path = uri.partition("/")
+ split_arn = arn.split(":", maxsplit=5)
+ *_, region_name, service_name, action_type = split_arn
+ boto_service_name = self.SERVICE_MAP.get(service_name, service_name)
+ return ParsedAwsIntegrationUri(
+ region_name=region_name,
+ service_name=boto_service_name,
+ action_type=action_type,
+ path=path,
+ )
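+ # For example, "arn:aws:apigateway:us-east-1:kinesis:action/PutRecord" parses to
+ #   {"region_name": "us-east-1", "service_name": "kinesis", "action_type": "action", "path": "PutRecord"}
+ # and "arn:aws:apigateway:us-west-2:s3:path/{bucket}/{key}" parses to
+ #   {"region_name": "us-west-2", "service_name": "s3", "action_type": "path", "path": "{bucket}/{key}"}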
+
+ def get_action_service_target(self, service_name: str, action: str) -> str | None:
+ if service_name not in self.SERVICES_AUTO_TARGET:
+ return None
+
+ target_prefix = get_target_prefix_for_service(service_name)
+ if not target_prefix:
+ return None
+
+ return f"{target_prefix}.{action}"
+
+ def get_internal_host_for_service(self, service_name: str, region_name: str):
+ url = self._base_domain
+ if service_name == "sqs":
+ # This follows the new SQS_ENDPOINT_STRATEGY=standard
+ url = config.external_service_url(subdomains=f"sqs.{region_name}")
+ elif "-api" in service_name:
+ # this could be a `<service>-api` style service name, used by some services
+ url = config.external_service_url(subdomains=service_name)
+
+ return urlparse(url).netloc
+
+ @staticmethod
+ def get_payload_from_query_string(query_string_parameters: dict) -> str:
+ return json.dumps(query_string_parameters)
+
+ @staticmethod
+ def format_response_content_legacy(
+ service_name: str, action: str, payload: bytes, request_id: str
+ ) -> bytes:
+ # TODO: not sure how much we need to support this; it covers SSM for now. Once we write more tests for the
+ # `action` type, see if we can generalize more
+ data = json.loads(payload)
+ try:
+ # we try to populate the missing fields from the OperationModel of the operation
+ operation_model = get_service_catalog().get(service_name).operation_model(action)
+ for key in operation_model.output_shape.members:
+ if key not in data:
+ data[key] = None
+
+ except Exception:
+ # the operation above is only for parity reason, skips if it fails
+ pass
+
+ wrapped = {
+ f"{action}Response": {
+ f"{action}Result": data,
+ "ResponseMetadata": {
+ "RequestId": request_id,
+ },
+ }
+ }
+ return to_bytes(json.dumps(wrapped))
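+ # Illustration (hypothetical payload): for action "GetParameter" and a payload of {"Parameter": {...}},
+ # the result is wrapped in the legacy Query-protocol style envelope:
+ #   {"GetParameterResponse": {"GetParameterResult": {"Parameter": {...}, ...},
+ #    "ResponseMetadata": {"RequestId": "<requestId>"}}}
+ # with any output-shape members missing from the payload filled in as None when the spec lookup succeeds.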
+
+
+class RestApiAwsProxyIntegration(RestApiIntegration):
+ """
+ This is a custom, simplified REST API integration focused only on the Lambda service, with minimal modification from
+ API Gateway. It passes the incoming request almost as is, in a custom created event payload, to the configured
+ Lambda function.
+
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html
+ """
+
+ name = "AWS_PROXY"
+
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ integration_req: IntegrationRequest = context.integration_request
+ method = integration_req["http_method"]
+
+ if method != HTTPMethod.POST:
+ LOG.warning(
+ "The 'AWS_PROXY' integration can only be used with the POST integration method.",
+ )
+ raise IntegrationFailureError("Internal server error")
+
+ input_event = self.create_lambda_input_event(context)
+
+ # TODO: verify stage variables rendering in AWS_PROXY
+ integration_uri = integration_req["uri"]
+
+ function_arn = get_lambda_function_arn_from_invocation_uri(integration_uri)
+ source_arn = get_source_arn(context)
+
+ # TODO: write test for credentials rendering
+ if credentials := context.integration.get("credentials"):
+ credentials = render_uri_with_stage_variables(credentials, context.stage_variables)
+
+ try:
+ lambda_payload = self.call_lambda(
+ function_arn=function_arn,
+ event=to_bytes(json.dumps(input_event)),
+ source_arn=source_arn,
+ credentials=credentials,
+ )
+
+ except ClientError as e:
+ LOG.warning(
+ "Exception during integration invocation: '%s'",
+ e,
+ )
+ status_code = 502
+ if e.response["Error"]["Code"] == "AccessDeniedException":
+ status_code = 500
+ raise IntegrationFailureError("Internal server error", status_code=status_code) from e
+
+ except Exception as e:
+ LOG.warning(
+ "Unexpected exception during integration invocation: '%s'",
+ e,
+ )
+ raise IntegrationFailureError("Internal server error", status_code=502) from e
+
+ lambda_response = self.parse_lambda_response(lambda_payload)
+
+ headers = Headers({"Content-Type": APPLICATION_JSON})
+
+ response_headers = self._merge_lambda_response_headers(lambda_response)
+ headers.update(response_headers)
+
+ return EndpointResponse(
+ headers=headers,
+ body=to_bytes(lambda_response.get("body") or ""),
+ status_code=int(lambda_response.get("statusCode") or 200),
+ )
+
+ @staticmethod
+ def call_lambda(
+ function_arn: str,
+ event: bytes,
+ source_arn: str,
+ credentials: str = None,
+ ) -> bytes:
+ lambda_client = get_service_factory(
+ region_name=extract_region_from_arn(function_arn),
+ role_arn=credentials,
+ ).lambda_
+ inv_result = lambda_client.request_metadata(
+ service_principal=ServicePrincipal.apigateway,
+ source_arn=source_arn,
+ ).invoke(
+ FunctionName=function_arn,
+ Payload=event,
+ InvocationType="RequestResponse",
+ )
+ if payload := inv_result.get("Payload"):
+ return payload.read()
+ return b""
+
+ def parse_lambda_response(self, payload: bytes) -> LambdaProxyResponse:
+ try:
+ lambda_response = json.loads(payload)
+ except json.JSONDecodeError:
+ LOG.warning(
+ 'Lambda output should follow the next JSON format: { "isBase64Encoded": true|false, "statusCode": httpStatusCode, "headers": { "headerName": "headerValue", ... },"body": "..."} but was: %s',
+ payload,
+ )
+ LOG.debug(
+ "Execution failed due to configuration error: Malformed Lambda proxy response"
+ )
+ raise InternalServerError("Internal server error", status_code=502)
+
+ # none of the lambda response fields are mandatory, but you cannot return any other fields
+ if not self._is_lambda_response_valid(lambda_response):
+ if "errorMessage" in lambda_response:
+ LOG.debug(
+ "Lambda execution failed with status 200 due to customer function error: %s. Lambda request id: %s",
+ lambda_response["errorMessage"],
+ lambda_response.get("requestId", ""),
+ )
+ else:
+ LOG.warning(
+ 'Lambda output should follow the next JSON format: { "isBase64Encoded": true|false, "statusCode": httpStatusCode, "headers": { "headerName": "headerValue", ... },"body": "..."} but was: %s',
+ payload,
+ )
+ LOG.debug(
+ "Execution failed due to configuration error: Malformed Lambda proxy response"
+ )
+ raise InternalServerError("Internal server error", status_code=502)
+
+ def serialize_header(value: bool | str) -> str:
+ if isinstance(value, bool):
+ return "true" if value else "false"
+ return value
+
+ if headers := lambda_response.get("headers"):
+ lambda_response["headers"] = {k: serialize_header(v) for k, v in headers.items()}
+
+ if multi_value_headers := lambda_response.get("multiValueHeaders"):
+ lambda_response["multiValueHeaders"] = {
+ k: [serialize_header(v) for v in values]
+ for k, values in multi_value_headers.items()
+ }
+
+ return lambda_response
+
+ @staticmethod
+ def _is_lambda_response_valid(lambda_response: dict) -> bool:
+ if not isinstance(lambda_response, dict):
+ return False
+
+ if not validate_sub_dict_of_typed_dict(LambdaProxyResponse, lambda_response):
+ return False
+
+ if (headers := lambda_response.get("headers")) is not None:
+ if not isinstance(headers, dict):
+ return False
+ if any(not isinstance(header_value, (str, bool)) for header_value in headers.values()):
+ return False
+
+ if (multi_value_headers := lambda_response.get("multiValueHeaders")) is not None:
+ if not isinstance(multi_value_headers, dict):
+ return False
+ if any(
+ not isinstance(header_value, list) for header_value in multi_value_headers.values()
+ ):
+ return False
+
+ if "statusCode" in lambda_response:
+ try:
+ int(lambda_response["statusCode"])
+ except ValueError:
+ return False
+
+ # TODO: add more validations of the values' type
+ return True
+
+ def create_lambda_input_event(self, context: RestApiInvocationContext) -> LambdaInputEvent:
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format
+ # for building the Lambda Payload, we need access to the Invocation Request, as some data is not available in
+ # the integration request and does not make sense for it
+ invocation_req: InvocationRequest = context.invocation_request
+ integration_req: IntegrationRequest = context.integration_request
+
+ # TODO: binary support of APIGW
+ body, is_b64_encoded = self._format_body(integration_req["body"])
+
+ input_event = LambdaInputEvent(
+ headers=self._format_headers(dict(integration_req["headers"])),
+ multiValueHeaders=self._format_headers(
+ build_multi_value_headers(integration_req["headers"])
+ ),
+ body=body or None,
+ isBase64Encoded=is_b64_encoded,
+ requestContext=context.context_variables,
+ stageVariables=context.stage_variables,
+ # still using the InvocationRequest query string parameters as the logic is the same, maybe refactor?
+ queryStringParameters=invocation_req["query_string_parameters"] or None,
+ multiValueQueryStringParameters=invocation_req["multi_value_query_string_parameters"]
+ or None,
+ pathParameters=invocation_req["path_parameters"] or None,
+ httpMethod=invocation_req["http_method"],
+ path=invocation_req["path"],
+ resource=context.resource["path"],
+ )
+
+ return input_event
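+ # Rough shape of the resulting event (hypothetical request `GET /pets/123?type=dog` on resource `/pets/{petId}`):
+ #   {"resource": "/pets/{petId}", "path": "/pets/123", "httpMethod": "GET",
+ #    "headers": {...}, "multiValueHeaders": {...},
+ #    "queryStringParameters": {"type": "dog"}, "multiValueQueryStringParameters": {"type": ["dog"]},
+ #    "pathParameters": {"petId": "123"}, "stageVariables": None,
+ #    "requestContext": {... $context variables ...}, "body": None, "isBase64Encoded": False}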
+
+ @staticmethod
+ def _format_headers(headers: dict[str, str | list[str]]) -> dict[str, str | list[str]]:
+ # Some headers get capitalized like in CloudFront, see
+ # https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/add-origin-custom-headers.html#add-origin-custom-headers-forward-authorization
+ # It seems AWS_PROXY lambda integrations are behind CloudFront, as seen by the returned headers in AWS
+ to_capitalize: list[str] = ["authorization", "user-agent"] # some headers get capitalized
+ to_filter: list[str] = ["content-length", "connection"]
+ headers = {
+ k.title() if k.lower() in to_capitalize else k: v
+ for k, v in headers.items()
+ if k.lower() not in to_filter
+ }
+
+ return headers
+
+ @staticmethod
+ def _format_body(body: bytes) -> tuple[str, bool]:
+ try:
+ return body.decode("utf-8"), False
+ except UnicodeDecodeError:
+ return to_str(base64.b64encode(body)), True
+
+ @staticmethod
+ def _merge_lambda_response_headers(lambda_response: LambdaProxyResponse) -> dict:
+ headers = lambda_response.get("headers") or {}
+
+ if multi_value_headers := lambda_response.get("multiValueHeaders"):
+ # multiValueHeaders has the priority and will decide the casing of the final headers, as they are merged
+ headers_low_keys = {k.lower(): v for k, v in headers.items()}
+
+ for k, values in multi_value_headers.items():
+ if (k_lower := k.lower()) in headers_low_keys:
+ headers[k] = [*values, headers_low_keys[k_lower]]
+ else:
+ headers[k] = values
+
+ return headers
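+ # For illustration (hypothetical values): with headers={"x-custom": "a"} and
+ # multiValueHeaders={"x-custom": ["b", "c"], "y": ["1"]}, the merged result is
+ # {"x-custom": ["b", "c", "a"], "y": ["1"]}; the multi-value entries come first and keep their casing.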
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/core.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/core.py
new file mode 100644
index 0000000000000..c65b1a9539d7f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/core.py
@@ -0,0 +1,19 @@
+from abc import abstractmethod
+
+from ..api import RestApiInvocationContext
+from ..context import EndpointResponse
+
+
+class RestApiIntegration:
+ """
+ This REST API Integration base class exposes a common interface to invoke a specific integration.
+
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/how-to-integration-settings.html
+ TODO: Add more abstractmethods when starting to work on the Integration handler
+ """
+
+ name: str
+
+ @abstractmethod
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ pass
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/http.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/http.py
new file mode 100644
index 0000000000000..fa0511072c9d1
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/http.py
@@ -0,0 +1,147 @@
+import logging
+from http import HTTPMethod
+from typing import Optional, TypedDict
+
+import requests
+from werkzeug.datastructures import Headers
+
+from localstack.aws.api.apigateway import Integration
+
+from ..context import EndpointResponse, IntegrationRequest, RestApiInvocationContext
+from ..gateway_response import ApiConfigurationError, IntegrationFailureError
+from ..header_utils import build_multi_value_headers
+from .core import RestApiIntegration
+
+LOG = logging.getLogger(__name__)
+
+NO_BODY_METHODS = {HTTPMethod.OPTIONS, HTTPMethod.GET, HTTPMethod.HEAD}
+
+
+class SimpleHttpRequest(TypedDict, total=False):
+ method: HTTPMethod | str
+ url: str
+ params: Optional[dict[str, str | list[str]]]
+ data: bytes
+ headers: Optional[dict[str, str]]
+ cookies: Optional[dict[str, str]]
+ timeout: Optional[int]
+ allow_redirects: Optional[bool]
+ stream: Optional[bool]
+ verify: Optional[bool]
+ # TODO: check if there was a situation where we'd pass certs?
+ cert: Optional[str | tuple[str, str]]
+
+
+class BaseRestApiHttpIntegration(RestApiIntegration):
+ @staticmethod
+ def _get_integration_timeout(integration: Integration) -> float:
+ return int(integration.get("timeoutInMillis", 29000)) / 1000
+
+
+class RestApiHttpIntegration(BaseRestApiHttpIntegration):
+ """
+ This is a REST API integration responsible for sending a request to another HTTP API.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/setup-http-integrations.html#api-gateway-set-up-http-proxy-integration-on-proxy-resource
+ """
+
+ name = "HTTP"
+
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ integration_req: IntegrationRequest = context.integration_request
+ method = integration_req["http_method"]
+ uri = integration_req["uri"]
+
+ request_parameters: SimpleHttpRequest = {
+ "method": method,
+ "url": uri,
+ "params": integration_req["query_string_parameters"],
+ "headers": integration_req["headers"],
+ }
+
+ if method not in NO_BODY_METHODS:
+ request_parameters["data"] = integration_req["body"]
+
+ # TODO: configurable timeout (29 by default) (check type and default value in provider)
+ # integration: Integration = context.resource_method["methodIntegration"]
+ # request_parameters["timeout"] = self._get_integration_timeout(integration)
+ # TODO: check for redirects
+ # request_parameters["allow_redirects"] = False
+ try:
+ request_response = requests.request(**request_parameters)
+
+ except (requests.exceptions.InvalidURL, requests.exceptions.InvalidSchema) as e:
+ LOG.warning("Execution failed due to configuration error: Invalid endpoint address")
+ LOG.debug("The URI specified for the HTTP/HTTP_PROXY integration is invalid: %s", uri)
+ raise ApiConfigurationError("Internal server error") from e
+
+ except (requests.exceptions.Timeout, requests.exceptions.SSLError) as e:
+ # TODO make the exception catching more fine grained
+ # this can be reproduced in AWS if you try to hit an HTTP endpoint which is HTTPS only like lambda URL
+ LOG.warning("Execution failed due to a network error communicating with endpoint")
+ raise IntegrationFailureError("Network error communicating with endpoint") from e
+
+ except requests.exceptions.ConnectionError as e:
+ raise ApiConfigurationError("Internal server error") from e
+
+ return EndpointResponse(
+ body=request_response.content,
+ status_code=request_response.status_code,
+ headers=Headers(dict(request_response.headers)),
+ )
+
+
+class RestApiHttpProxyIntegration(BaseRestApiHttpIntegration):
+ """
+ This is a simplified REST API integration responsible for sending a request to another HTTP API by proxying it
+ almost directly.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/setup-http-integrations.html#api-gateway-set-up-http-proxy-integration-on-proxy-resource
+ """
+
+ name = "HTTP_PROXY"
+
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ integration_req: IntegrationRequest = context.integration_request
+ method = integration_req["http_method"]
+ uri = integration_req["uri"]
+
+ multi_value_headers = build_multi_value_headers(integration_req["headers"])
+ request_headers = {key: ",".join(value) for key, value in multi_value_headers.items()}
+
+ request_parameters: SimpleHttpRequest = {
+ "method": method,
+ "url": uri,
+ "params": integration_req["query_string_parameters"],
+ "headers": request_headers,
+ }
+
+ # TODO: validate this for HTTP_PROXY
+ if method not in NO_BODY_METHODS:
+ request_parameters["data"] = integration_req["body"]
+
+ # TODO: configurable timeout (29 by default) (check type and default value in provider)
+ # integration: Integration = context.resource_method["methodIntegration"]
+ # request_parameters["timeout"] = self._get_integration_timeout(integration)
+ try:
+ request_response = requests.request(**request_parameters)
+
+ except (requests.exceptions.InvalidURL, requests.exceptions.InvalidSchema) as e:
+ LOG.warning("Execution failed due to configuration error: Invalid endpoint address")
+ LOG.debug("The URI specified for the HTTP/HTTP_PROXY integration is invalid: %s", uri)
+ raise ApiConfigurationError("Internal server error") from e
+
+ except (requests.exceptions.Timeout, requests.exceptions.SSLError):
+ # TODO make the exception catching more fine grained
+ # this can be reproduced in AWS if you try to hit an HTTP endpoint which is HTTPS only like lambda URL
+ LOG.warning("Execution failed due to a network error communicating with endpoint")
+ raise IntegrationFailureError("Network error communicating with endpoint")
+
+ except requests.exceptions.ConnectionError:
+ raise ApiConfigurationError("Internal server error")
+
+ response_headers = Headers(dict(request_response.headers))
+
+ return EndpointResponse(
+ body=request_response.content,
+ status_code=request_response.status_code,
+ headers=response_headers,
+ )
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/mock.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/mock.py
new file mode 100644
index 0000000000000..84ddecc05862e
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/integrations/mock.py
@@ -0,0 +1,108 @@
+import json
+import logging
+import re
+from json import JSONDecodeError
+
+from werkzeug.datastructures import Headers
+
+from localstack.utils.strings import to_str
+
+from ..context import EndpointResponse, IntegrationRequest, RestApiInvocationContext
+from ..gateway_response import InternalServerError
+from .core import RestApiIntegration
+
+LOG = logging.getLogger(__name__)
+
+
+class RestApiMockIntegration(RestApiIntegration):
+ """
+ This is a simple but quite limited REST API integration, allowing you to quickly test your APIs or return
+ hardcoded responses to the client.
+ This integration never calls a backend, so all the work is done with the integration request and response
+ mappings.
+ This can be used, for example, to set up a CORS response for `OPTIONS` requests.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/how-to-mock-integration.html
+ """
+
+ name = "MOCK"
+
+ def invoke(self, context: RestApiInvocationContext) -> EndpointResponse:
+ integration_req: IntegrationRequest = context.integration_request
+
+ status_code = self.get_status_code(integration_req)
+
+ if status_code is None:
+ LOG.debug(
+ "Execution failed due to configuration error: Unable to parse statusCode. "
+ "It should be an integer that is defined in the request template."
+ )
+ raise InternalServerError("Internal server error")
+
+ return EndpointResponse(status_code=status_code, body=b"", headers=Headers())
+
+ def get_status_code(self, integration_req: IntegrationRequest) -> int | None:
+ try:
+ body = json.loads(integration_req["body"])
+ except JSONDecodeError as e:
+ LOG.debug(
+ "Exception while JSON parsing integration request body: %s. "
+ "Falling back to custom parser",
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+ body = self.parse_invalid_json(to_str(integration_req["body"]))
+
+ status_code = body.get("statusCode")
+ if not isinstance(status_code, int):
+ return
+
+ return status_code
+
+ def parse_invalid_json(self, body: str) -> dict:
+ """This is a quick fix to unblock CDK users setting a CORS policy for REST APIs.
+ CDK creates a MOCK OPTIONS route with invalid JSON: `{statusCode: 200}`.
+ AWS probably has a custom token parser. We can implement one
+ at some point if we have user requests for it"""
+
+ def convert_null_value(value) -> str:
+ if (value := value.strip()) in ("null", ""):
+ return '""'
+ return value
+
+ try:
+ statuscode = ""
+ matched = re.match(r"^\s*{(.+)}\s*$", body).group(1)
+ pairs = [m.strip() for m in matched.split(",")]
+ # TODO this is not right, but nested object would otherwise break the parsing
+ key_values = [s.split(":", maxsplit=1) for s in pairs if s]
+ for key_value in key_values:
+ assert len(key_value) == 2
+ key, value = [convert_null_value(el) for el in key_value]
+
+ if key in ("statusCode", "'statusCode'", '"statusCode"'):
+ statuscode = int(value)
+ continue
+
+ assert (leading_key_char := key[0]) not in "[{"
+ if leading_key_char in "'\"":
+ assert len(key) >= 2
+ assert key[-1] == leading_key_char
+
+ if (leading_value_char := value[0]) in "[{'\"":
+ assert len(value) >= 2
+ if leading_value_char == "{":
+ # TODO reparse objects
+ assert value[-1] == "}"
+ elif leading_value_char == "[":
+ # TODO validate arrays
+ assert value[-1] == "]"
+ else:
+ assert value[-1] == leading_value_char
+
+ return {"statusCode": statuscode}
+
+ except Exception as e:
+ LOG.debug(
+ "Error Parsing an invalid json, %s", e, exc_info=LOG.isEnabledFor(logging.DEBUG)
+ )
+ return {"statusCode": ""}
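+ # For example, the CDK-generated body `{statusCode: 200}` (not valid JSON) is parsed into
+ # {"statusCode": 200}, while a body that cannot be recognized falls back to {"statusCode": ""},
+ # which later fails the integer check in `get_status_code` and raises an InternalServerError.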
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/moto_helpers.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/moto_helpers.py
new file mode 100644
index 0000000000000..ae9e9ddc6a7a2
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/moto_helpers.py
@@ -0,0 +1,69 @@
+from moto.apigateway.models import APIGatewayBackend, apigateway_backends
+from moto.apigateway.models import RestAPI as MotoRestAPI
+
+from localstack.aws.api.apigateway import ApiKey, ListOfUsagePlan, ListOfUsagePlanKey, Resource
+
+
+def get_resources_from_moto_rest_api(moto_rest_api: MotoRestAPI) -> dict[str, Resource]:
+ """
+ This returns the `Resources` from a Moto REST API
+ This allows us to decouple the underlying split of resources between Moto and LocalStack, and always return the right
+ format.
+ """
+ moto_resources = moto_rest_api.resources
+
+ resources: dict[str, Resource] = {}
+ for moto_resource in moto_resources.values():
+ resource = Resource(
+ id=moto_resource.id,
+ parentId=moto_resource.parent_id,
+ pathPart=moto_resource.path_part,
+ path=moto_resource.get_path(),
+ resourceMethods={
+ # TODO: check if resource_methods.to_json() returns everything we need/want
+ k: v.to_json()
+ for k, v in moto_resource.resource_methods.items()
+ },
+ )
+
+ resources[moto_resource.id] = resource
+
+ return resources
+
+
+def get_stage_variables(
+ account_id: str, region: str, api_id: str, stage_name: str
+) -> dict[str, str]:
+ apigateway_backend: APIGatewayBackend = apigateway_backends[account_id][region]
+ moto_rest_api = apigateway_backend.get_rest_api(api_id)
+ stage = moto_rest_api.stages[stage_name]
+ return stage.variables
+
+
+def get_usage_plans(account_id: str, region_name: str) -> ListOfUsagePlan:
+ """
+ Will return a list of usage plans from the moto store.
+ """
+ apigateway_backend: APIGatewayBackend = apigateway_backends[account_id][region_name]
+ return [usage_plan.to_json() for usage_plan in apigateway_backend.usage_plans.values()]
+
+
+def get_api_key(api_key_id: str, account_id: str, region_name: str) -> ApiKey:
+ """
+ Will return an api key from the moto store.
+ """
+ apigateway_backend: APIGatewayBackend = apigateway_backends[account_id][region_name]
+ return apigateway_backend.keys[api_key_id].to_json()
+
+
+def get_usage_plan_keys(
+ usage_plan_id: str, account_id: str, region_name: str
+) -> ListOfUsagePlanKey:
+ """
+ Will return a list of usage plan keys from the moto store.
+ """
+ apigateway_backend: APIGatewayBackend = apigateway_backends[account_id][region_name]
+ return [
+ usage_plan_key.to_json()
+ for usage_plan_key in apigateway_backend.usage_plan_keys.get(usage_plan_id, {}).values()
+ ]
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/parameters_mapping.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/parameters_mapping.py
new file mode 100644
index 0000000000000..bb723e58ea4ef
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/parameters_mapping.py
@@ -0,0 +1,298 @@
+# > This section explains how to set up data mappings from an API's method request data, including other data
+# stored in context, stage, or util variables, to the corresponding integration request parameters and from an
+# integration response data, including the other data, to the method response parameters. The method request
+# data includes request parameters (path, query string, headers) and the body. The integration response data
+# includes response parameters (headers) and the body. For more information about using the stage variables,
+# see API Gateway stage variables reference.
+#
+# https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html
+import json
+import logging
+from typing import Any, TypedDict
+
+from localstack.utils.json import extract_jsonpath
+from localstack.utils.strings import to_str
+
+from .context import EndpointResponse, InvocationRequest
+from .gateway_response import BadRequestException, InternalFailureException
+from .header_utils import build_multi_value_headers
+from .variables import ContextVariables
+
+LOG = logging.getLogger(__name__)
+
+
+class RequestDataMapping(TypedDict):
+ # Integration request parameters, in the form of path variables, query strings or headers, can be mapped from any
+ # defined method request parameters and the payload.
+ header: dict[str, str]
+ path: dict[str, str]
+ querystring: dict[str, str | list[str]]
+
+
+class ResponseDataMapping(TypedDict):
+ # Method response header parameters can be mapped from any integration response header or integration response body,
+ # $context variables, or static values.
+ header: dict[str, str]
+
+
+class ParametersMapper:
+ def map_integration_request(
+ self,
+ request_parameters: dict[str, str],
+ invocation_request: InvocationRequest,
+ context_variables: ContextVariables,
+ stage_variables: dict[str, str],
+ ) -> RequestDataMapping:
+ request_data_mapping = RequestDataMapping(
+ header={},
+ path={},
+ querystring={},
+ )
+ # storing the case-sensitive headers once, the mapping is strict
+ case_sensitive_headers = build_multi_value_headers(invocation_request["headers"])
+
+ for integration_mapping, request_mapping in request_parameters.items():
+ # TODO: remove this once the validation has been added to the provider, to avoid breaking
+ if not isinstance(integration_mapping, str) or not isinstance(request_mapping, str):
+ LOG.warning(
+ "Wrong parameter mapping value type: %s: %s. They should both be string. Skipping this mapping.",
+ integration_mapping,
+ request_mapping,
+ )
+ continue
+
+ integration_param_location, param_name = integration_mapping.removeprefix(
+ "integration.request."
+ ).split(".")
+
+ if request_mapping.startswith("method.request."):
+ method_req_expr = request_mapping.removeprefix("method.request.")
+ value = self._retrieve_parameter_from_invocation_request(
+ method_req_expr, invocation_request, case_sensitive_headers
+ )
+
+ else:
+ value = self._retrieve_parameter_from_variables_and_static(
+ mapping_value=request_mapping,
+ context_variables=context_variables,
+ stage_variables=stage_variables,
+ )
+
+ if value:
+ request_data_mapping[integration_param_location][param_name] = value
+
+ return request_data_mapping
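+ # Illustration (hypothetical mapping): request_parameters like
+ #   {"integration.request.header.x-request-id": "context.requestId",
+ #    "integration.request.querystring.type": "method.request.querystring.type",
+ #    "integration.request.path.petId": "method.request.path.petId"}
+ # produce a RequestDataMapping roughly of the form
+ #   {"header": {"x-request-id": "<requestId>"}, "querystring": {"type": "dog"}, "path": {"petId": "123"}},
+ # assuming the invocation request carried `?type=dog` and the path parameter petId=123.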
+
+ def map_integration_response(
+ self,
+ response_parameters: dict[str, str],
+ integration_response: EndpointResponse,
+ context_variables: ContextVariables,
+ stage_variables: dict[str, str],
+ ) -> ResponseDataMapping:
+ response_data_mapping = ResponseDataMapping(header={})
+
+ # storing the case-sensitive headers once, the mapping is strict
+ case_sensitive_headers = build_multi_value_headers(integration_response["headers"])
+
+ for response_mapping, integration_mapping in response_parameters.items():
+ header_name = response_mapping.removeprefix("method.response.header.")
+
+ if integration_mapping.startswith("integration.response."):
+ method_req_expr = integration_mapping.removeprefix("integration.response.")
+ value = self._retrieve_parameter_from_integration_response(
+ method_req_expr, integration_response, case_sensitive_headers
+ )
+ else:
+ value = self._retrieve_parameter_from_variables_and_static(
+ mapping_value=integration_mapping,
+ context_variables=context_variables,
+ stage_variables=stage_variables,
+ )
+
+ if value:
+ response_data_mapping["header"][header_name] = value
+
+ return response_data_mapping
+
+ def _retrieve_parameter_from_variables_and_static(
+ self,
+ mapping_value: str,
+ context_variables: dict[str, Any],
+ stage_variables: dict[str, str],
+ ) -> str | None:
+ if mapping_value.startswith("context."):
+ context_var_expr = mapping_value.removeprefix("context.")
+ return self._retrieve_parameter_from_context_variables(
+ context_var_expr, context_variables
+ )
+
+ elif mapping_value.startswith("stageVariables."):
+ stage_var_name = mapping_value.removeprefix("stageVariables.")
+ return self._retrieve_parameter_from_stage_variables(stage_var_name, stage_variables)
+
+ elif mapping_value.startswith("'") and mapping_value.endswith("'"):
+ return mapping_value.strip("'")
+
+ else:
+ LOG.warning(
+ "Unrecognized parameter mapping value: '%s'. Skipping this mapping.",
+ mapping_value,
+ )
+ return None
+
+ def _retrieve_parameter_from_integration_response(
+ self,
+ expr: str,
+ integration_response: EndpointResponse,
+ case_sensitive_headers: dict[str, list[str]],
+ ) -> str | None:
+ """
+ See https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html#mapping-response-parameters
+ :param expr: mapping expression stripped of the `integration.response.` prefix.
+ Can be one of the following: `header.<param_name>`, `multivalueheader.<param_name>`,
+ `body` or `body.<JSONPath_expression>`
+ :param integration_response: the Response to map parameters from
+ :return: the value to map in the ResponseDataMapping
+ """
+ if expr.startswith("body"):
+ body = integration_response.get("body") or b"{}"
+ body = body.strip()
+ try:
+ decoded_body = self._json_load(body)
+ except ValueError:
+ raise InternalFailureException(message="Internal server error")
+
+ if expr == "body":
+ return to_str(body)
+
+ elif expr.startswith("body."):
+ json_path = expr.removeprefix("body.")
+ return self._get_json_path_from_dict(decoded_body, json_path)
+ else:
+ LOG.warning(
+ "Unrecognized integration.response parameter: '%s'. Skipping the parameter mapping.",
+ expr,
+ )
+ return None
+
+ param_type, param_name = expr.split(".")
+
+ if param_type == "header":
+ if header := case_sensitive_headers.get(param_name):
+ return header[-1]
+
+ elif param_type == "multivalueheader":
+ if header := case_sensitive_headers.get(param_name):
+ return ",".join(header)
+
+ else:
+ LOG.warning(
+ "Unrecognized integration.response parameter: '%s'. Skipping the parameter mapping.",
+ expr,
+ )
+
+ def _retrieve_parameter_from_invocation_request(
+ self,
+ expr: str,
+ invocation_request: InvocationRequest,
+ case_sensitive_headers: dict[str, list[str]],
+ ) -> str | list[str] | None:
+ """
+ See https://docs.aws.amazon.com/apigateway/latest/developerguide/request-response-data-mappings.html#mapping-response-parameters
+ :param expr: mapping expression stripped of the `method.request.` prefix.
+ Can be one of the following: `path.<param_name>`, `querystring.<param_name>`,
+ `multivaluequerystring.<param_name>`, `header.<param_name>`, `multivalueheader.<param_name>`,
+ `body` or `body.<JSONPath_expression>`
+ :param invocation_request: the InvocationRequest to map parameters from
+ :return: the value to map in the RequestDataMapping
+ """
+ if expr.startswith("body"):
+ body = invocation_request["body"] or b"{}"
+ body = body.strip()
+ try:
+ decoded_body = self._json_load(body)
+ except ValueError:
+ raise BadRequestException(message="Invalid JSON in request body")
+
+ if expr == "body":
+ return to_str(body)
+
+ elif expr.startswith("body."):
+ json_path = expr.removeprefix("body.")
+ return self._get_json_path_from_dict(decoded_body, json_path)
+ else:
+ LOG.warning(
+ "Unrecognized method.request parameter: '%s'. Skipping the parameter mapping.",
+ expr,
+ )
+ return None
+
+ param_type, param_name = expr.split(".")
+ if param_type == "path":
+ return invocation_request["path_parameters"].get(param_name)
+
+ elif param_type == "querystring":
+ multi_qs_params = invocation_request["multi_value_query_string_parameters"].get(
+ param_name
+ )
+ if multi_qs_params:
+ return multi_qs_params[-1]
+
+ elif param_type == "multivaluequerystring":
+ multi_qs_params = invocation_request["multi_value_query_string_parameters"].get(
+ param_name
+ )
+ if not multi_qs_params:
+ return None
+ if len(multi_qs_params) == 1:
+ return multi_qs_params[0]
+ return multi_qs_params
+
+ elif param_type == "header":
+ if header := case_sensitive_headers.get(param_name):
+ return header[-1]
+
+ elif param_type == "multivalueheader":
+ if header := case_sensitive_headers.get(param_name):
+ return ",".join(header)
+
+ else:
+ LOG.warning(
+ "Unrecognized method.request parameter: '%s'. Skipping the parameter mapping.",
+ expr,
+ )
+
+ def _retrieve_parameter_from_context_variables(
+ self, expr: str, context_variables: dict[str, Any]
+ ) -> str | None:
+ # we use a JSONPath lookup here because the expression can reference nested properties like `context.identity.sourceIp`
+ if (value := self._get_json_path_from_dict(context_variables, expr)) and isinstance(
+ value, str
+ ):
+ return value
+
+ @staticmethod
+ def _retrieve_parameter_from_stage_variables(
+ stage_var_name: str, stage_variables: dict[str, str]
+ ) -> str | None:
+ return stage_variables.get(stage_var_name)
+
+ @staticmethod
+ def _get_json_path_from_dict(body: dict, path: str) -> str | None:
+ # TODO: verify we don't have special cases
+ try:
+ return extract_jsonpath(body, f"$.{path}")
+ except KeyError:
+ return None
+
+ @staticmethod
+ def _json_load(body: bytes) -> dict | list:
+ """
+ AWS only tries to JSON-decode the body if it starts with one of the leading characters ({, [, ", ');
+ otherwise, the body is ignored and an empty object is returned
+ :param body:
+ :return:
+ """
+ if any(body.startswith(c) for c in (b"{", b"[", b"'", b'"')):
+ return json.loads(body)
+
+ return {}
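The mapping logic above is easier to follow with a concrete resolution example. The following is a minimal, self-contained sketch (the helper function and sample data are made up for illustration and are not part of this PR) of the rules `ParametersMapper` applies to `method.request.*` expressions and to static single-quoted values:

```python
from typing import Optional


def resolve_method_request_param(
    expr: str,
    headers: dict[str, list[str]],
    multi_value_query_string: dict[str, list[str]],
) -> Optional[str]:
    # static values are wrapped in single quotes in the mapping definition
    if expr.startswith("'") and expr.endswith("'"):
        return expr.strip("'")
    param_type, _, name = expr.removeprefix("method.request.").partition(".")
    if param_type == "header" and (values := headers.get(name)):
        # headers are matched case-sensitively; the last value wins
        return values[-1]
    if param_type == "multivalueheader" and (values := headers.get(name)):
        # all values are joined with commas
        return ",".join(values)
    if param_type == "querystring" and (values := multi_value_query_string.get(name)):
        # a single-value mapping of a multi-value parameter returns the last value
        return values[-1]
    return None


headers = {"X-Request-Id": ["abc-123"]}
query = {"filter": ["a", "b"]}
print(resolve_method_request_param("method.request.header.X-Request-Id", headers, query))  # abc-123
print(resolve_method_request_param("method.request.querystring.filter", headers, query))   # b
print(resolve_method_request_param("'static-value'", headers, query))                      # static-value
```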
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/router.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/router.py
new file mode 100644
index 0000000000000..7e84967df5004
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/router.py
@@ -0,0 +1,189 @@
+import logging
+from typing import TypedDict, Unpack
+
+from rolo import Request, Router
+from rolo.routing.handler import Handler
+from werkzeug.routing import Rule
+
+from localstack.constants import APPLICATION_JSON, AWS_REGION_US_EAST_1, DEFAULT_AWS_ACCOUNT_ID
+from localstack.deprecations import deprecated_endpoint
+from localstack.http import Response
+from localstack.services.apigateway.models import ApiGatewayStore, apigateway_stores
+from localstack.services.edge import ROUTER
+from localstack.services.stores import AccountRegionBundle
+
+from .context import RestApiInvocationContext
+from .gateway import RestApiGateway
+
+LOG = logging.getLogger(__name__)
+
+
+class RouteHostPathParameters(TypedDict, total=False):
+ """
+ Represents the kwargs typing for calling APIGatewayEndpoint.
+ Each field might be populated from the route host and path parameters, defined when registering a route.
+ """
+
+ api_id: str
+ path: str
+ port: int | None
+ server: str | None
+ stage: str | None
+ vpce_suffix: str | None
+
+
+class ApiGatewayEndpoint:
+ """
+ This class is the endpoint for API Gateway invocations of the `execute-api` route. It will take the incoming
+ invocation request, create a context from the API matching the route parameters, and dispatch the request to the
+ Gateway to be processed by the handler chain.
+ """
+
+ def __init__(self, rest_gateway: RestApiGateway = None, store: AccountRegionBundle = None):
+ self.rest_gateway = rest_gateway or RestApiGateway()
+ # the handler only accesses CrossAccount attributes, so we use a global store in the default account and region
+ self._store = store or apigateway_stores
+
+ @property
+ def _global_store(self) -> ApiGatewayStore:
+ return self._store[DEFAULT_AWS_ACCOUNT_ID][AWS_REGION_US_EAST_1]
+
+ def __call__(self, request: Request, **kwargs: Unpack[RouteHostPathParameters]) -> Response:
+ """
+ :param request: the incoming Request object
+ :param kwargs: can contain all the fields of RouteHostPathParameters. Those values are defined on the registered
+ routes in ApiGatewayRouter, through host and path placeholders such as `<api_id>` and `<stage>`.
+ :return: the Response object to return to the client
+ """
+ # api_id can be mixed-case when a custom ID was set via tags, so normalize it to lowercase
+ api_id, stage = kwargs.get("api_id", "").lower(), kwargs.get("stage")
+ if self.is_rest_api(api_id, stage):
+ context, response = self.prepare_rest_api_invocation(request, api_id, stage)
+ self.rest_gateway.process_with_context(context, response)
+ return response
+ else:
+ return self.create_not_found_response(api_id)
+
+ def prepare_rest_api_invocation(
+ self, request: Request, api_id: str, stage: str
+ ) -> tuple[RestApiInvocationContext, Response]:
+ LOG.debug("APIGW v1 Endpoint called")
+ response = self.create_response(request)
+ context = RestApiInvocationContext(request)
+ self.populate_rest_api_invocation_context(context, api_id, stage)
+
+ return context, response
+
+ def is_rest_api(self, api_id: str, stage: str):
+ return stage in self._global_store.active_deployments.get(api_id, {})
+
+ def populate_rest_api_invocation_context(
+ self, context: RestApiInvocationContext, api_id: str, stage: str
+ ):
+ try:
+ deployment_id = self._global_store.active_deployments[api_id][stage]
+ frozen_deployment = self._global_store.internal_deployments[api_id][deployment_id]
+
+ except KeyError:
+ # TODO: find proper error when trying to hit an API with no deployment/stage linked
+ return
+
+ context.deployment = frozen_deployment
+ context.api_id = api_id
+ context.stage = stage
+ context.deployment_id = deployment_id
+
+ @staticmethod
+ def create_response(request: Request) -> Response:
+ # Creates a default apigw response.
+ response = Response(headers={"Content-Type": APPLICATION_JSON})
+ if not (connection := request.headers.get("Connection")) or connection != "close":
+ # We only set the Connection header if the incoming value isn't "close".
+ # There appears to be an issue in LocalStack where setting "close" results in "close, close"
+ response.headers.set("Connection", "keep-alive")
+ return response
+
+ @staticmethod
+ def create_not_found_response(api_id: str) -> Response:
+ not_found = Response(status=404)
+ not_found.set_json(
+ {"message": f"The API id '{api_id}' does not correspond to a deployed API Gateway API"}
+ )
+ return not_found
+
+
+class ApiGatewayRouter:
+ router: Router[Handler]
+ handler: ApiGatewayEndpoint
+ EXECUTE_API_INTERNAL_PATH = "/_aws/execute-api"
+
+ def __init__(self, router: Router[Handler] = None, handler: ApiGatewayEndpoint = None):
+ self.router = router or ROUTER
+ self.handler = handler or ApiGatewayEndpoint()
+ self.registered_rules: list[Rule] = []
+
+ def register_routes(self) -> None:
+ LOG.debug("Registering API Gateway routes.")
+ host_pattern = "<api_id>.execute-api.<regex('.*'):server>"
+ deprecated_route_endpoint = deprecated_endpoint(
+ endpoint=self.handler,
+ previous_path="/restapis/<api_id>/<stage>/_user_request_",
+ deprecation_version="3.8.0",
+ new_path=f"{self.EXECUTE_API_INTERNAL_PATH}/<api_id>/<stage>",
+ )
+ rules = [
+ self.router.add(
+ path="/",
+ host=host_pattern,
+ endpoint=self.handler,
+ defaults={"path": "", "stage": None},
+ strict_slashes=True,
+ ),
+ self.router.add(
+ path="//",
+ host=host_pattern,
+ endpoint=self.handler,
+ defaults={"path": ""},
+ strict_slashes=False,
+ ),
+ self.router.add(
+ path="//",
+ host=host_pattern,
+ endpoint=self.handler,
+ strict_slashes=True,
+ ),
+ # add the deprecated localstack-specific _user_request_ routes
+ self.router.add(
+ path="/restapis///_user_request_",
+ endpoint=deprecated_route_endpoint,
+ defaults={"path": "", "random": "?"},
+ ),
+ self.router.add(
+ path="/restapis///_user_request_/",
+ endpoint=deprecated_route_endpoint,
+ strict_slashes=True,
+ ),
+ # add the localstack-specific so-called "path-style" routes when DNS resolving is not possible
+ self.router.add(
+ path=f"{self.EXECUTE_API_INTERNAL_PATH}//",
+ endpoint=self.handler,
+ defaults={"path": "", "stage": None},
+ strict_slashes=True,
+ ),
+ self.router.add(
+ path=f"{self.EXECUTE_API_INTERNAL_PATH}///",
+ endpoint=self.handler,
+ defaults={"path": ""},
+ strict_slashes=False,
+ ),
+ self.router.add(
+ path=f"{self.EXECUTE_API_INTERNAL_PATH}///",
+ endpoint=self.handler,
+ strict_slashes=True,
+ ),
+ ]
+ for rule in rules:
+ self.registered_rules.append(rule)
+
+ def unregister_routes(self):
+ self.router.remove(self.registered_rules)
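The routes above rely on werkzeug-style placeholders in both the host and the path; every placeholder is passed to `ApiGatewayEndpoint.__call__` as a keyword argument (see `RouteHostPathParameters`). Below is a standalone sketch of that mechanism using plain werkzeug; the host and path patterns are simplified for illustration and are not the exact rules registered above:

```python
from werkzeug.routing import Map, Rule

# `<api_id>` is captured from the Host header, `<stage>` and the greedy `<path:path>`
# from the URL path; all captured values end up in the match kwargs.
url_map = Map(
    [Rule("/<stage>/<path:path>", host="<api_id>.execute-api.<string:server>", endpoint="invoke")],
    host_matching=True,
)

adapter = url_map.bind(server_name="abc123.execute-api.localhost.localstack.cloud")
endpoint, kwargs = adapter.match("/dev/pets/1")
print(endpoint, kwargs)
# invoke {'api_id': 'abc123', 'server': 'localhost.localstack.cloud', 'stage': 'dev', 'path': 'pets/1'}
# (key order may vary)
```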
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/template_mapping.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/template_mapping.py
new file mode 100644
index 0000000000000..6f55f17adb834
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/template_mapping.py
@@ -0,0 +1,204 @@
+# > In API Gateway, an API's method request or response can take a payload in a different format from the integration
+# request or response.
+#
+# You can transform your data to:
+# - Match the payload to an API-specified format.
+# - Override an API's request and response parameters and status codes.
+# - Return client selected response headers.
+# - Associate path parameters, query string parameters, or header parameters in the method request of HTTP proxy
+ # or AWS service proxy. TODO: this is from the documentation. Can we use requestOverrides for proxy integrations?
+# - Select which data to send using integration with AWS services, such as Amazon DynamoDB or Lambda functions,
+# or HTTP endpoints.
+#
+# You can use mapping templates to transform your data. A mapping template is a script expressed in Velocity Template
+ # Language (VTL) and applied to the payload using JSONPath expressions.
+#
+# https://docs.aws.amazon.com/apigateway/latest/developerguide/models-mappings.html
+import base64
+import copy
+import json
+import logging
+from typing import Any, TypedDict
+from urllib.parse import quote_plus, unquote_plus
+
+from localstack import config
+from localstack.services.apigateway.next_gen.execute_api.variables import (
+ ContextVariables,
+ ContextVarsRequestOverride,
+ ContextVarsResponseOverride,
+)
+from localstack.utils.aws.templating import APIGW_SOURCE, VelocityUtil, VtlTemplate
+from localstack.utils.json import extract_jsonpath, json_safe
+
+LOG = logging.getLogger(__name__)
+
+
+class MappingTemplateParams(TypedDict, total=False):
+ path: dict[str, str]
+ querystring: dict[str, str]
+ header: dict[str, str]
+
+
+class MappingTemplateInput(TypedDict, total=False):
+ body: str
+ params: MappingTemplateParams
+
+
+class MappingTemplateVariables(TypedDict, total=False):
+ context: ContextVariables
+ input: MappingTemplateInput
+ stageVariables: dict[str, str]
+
+
+class AttributeDict(dict):
+ """
+ Wrapper returned by VelocityUtilApiGateway.parseJson to allow access to dict values as attributes (dot notation),
+ e.g.: $util.parseJson('$.foo').bar
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(AttributeDict, self).__init__(*args, **kwargs)
+ for key, value in self.items():
+ if isinstance(value, dict):
+ self[key] = AttributeDict(value)
+
+ def __getattr__(self, name):
+ if name in self:
+ return self[name]
+ raise AttributeError(f"'AttributeDict' object has no attribute '{name}'")
+
+ def __setattr__(self, name, value):
+ self[name] = value
+
+ def __delattr__(self, name):
+ if name in self:
+ del self[name]
+ else:
+ raise AttributeError(f"'AttributeDict' object has no attribute '{name}'")
+
+
+class VelocityUtilApiGateway(VelocityUtil):
+ """
+ Simple class to mimic the behavior of variable '$util' in AWS API Gateway integration
+ velocity templates.
+ See: https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+
+ def base64Encode(self, s):
+ if not isinstance(s, str):
+ s = json.dumps(s)
+ encoded_str = s.encode(config.DEFAULT_ENCODING)
+ encoded_b64_str = base64.b64encode(encoded_str)
+ return encoded_b64_str.decode(config.DEFAULT_ENCODING)
+
+ def base64Decode(self, s):
+ if not isinstance(s, str):
+ s = json.dumps(s)
+ return base64.b64decode(s)
+
+ def toJson(self, obj):
+ return obj and json.dumps(obj)
+
+ def urlEncode(self, s):
+ return quote_plus(s)
+
+ def urlDecode(self, s):
+ return unquote_plus(s)
+
+ def escapeJavaScript(self, obj: Any) -> str:
+ """
+ Converts the given object to a string and escapes any regular single quotes (') into escaped ones (\').
+ JSON dumps will escape the single quotes.
+ https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+ if obj is None:
+ return "null"
+ if isinstance(obj, str):
+ # empty string escapes to empty object
+ if len(obj.strip()) == 0:
+ return "{}"
+ return json.dumps(obj)[1:-1]
+ if obj in (True, False):
+ return str(obj).lower()
+ return str(obj)
+
+ def parseJson(self, s: str):
+ obj = json.loads(s)
+ return AttributeDict(obj) if isinstance(obj, dict) else obj
+
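A short usage sketch of the `$util` helpers mimicked by `VelocityUtilApiGateway` (the module path is the one introduced by this diff; the expected outputs reflect the implementation above, which may not cover every corner of AWS's own escaping rules):

```python
from localstack.services.apigateway.next_gen.execute_api.template_mapping import (
    VelocityUtilApiGateway,
)

util = VelocityUtilApiGateway()
print(util.escapeJavaScript(None))                    # null
print(util.escapeJavaScript(True))                    # true
print(util.escapeJavaScript('say "hi"'))              # say \"hi\"
print(util.parseJson('{"foo": {"bar": 1}}').foo.bar)  # 1
print(util.base64Encode("hello"))                     # aGVsbG8=
```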
+
+class VelocityInput:
+ """
+ Simple class to mimic the behavior of variable '$input' in AWS API Gateway integration
+ velocity templates.
+ See: http://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ """
+
+ def __init__(self, body, params):
+ self.parameters = params or {}
+ self.value = body
+
+ def path(self, path):
+ if not self.value:
+ return {}
+ value = self.value if isinstance(self.value, dict) else json.loads(self.value)
+ return extract_jsonpath(value, path)
+
+ def json(self, path):
+ path = path or "$"
+ matching = self.path(path)
+ if isinstance(matching, (list, dict)):
+ matching = json_safe(matching)
+ return json.dumps(matching)
+
+ @property
+ def body(self):
+ return self.value
+
+ def params(self, name=None):
+ if not name:
+ return self.parameters
+ for k in ["path", "querystring", "header"]:
+ if val := self.parameters.get(k, {}).get(name):
+ return val
+ return ""
+
+ def __getattr__(self, name):
+ return self.value.get(name)
+
+ def __repr__(self):
+ return "$input"
+
+
+class ApiGatewayVtlTemplate(VtlTemplate):
+ """Util class for rendering VTL templates with API Gateway specific extensions"""
+
+ def prepare_namespace(self, variables, source: str = APIGW_SOURCE) -> dict[str, Any]:
+ namespace = super().prepare_namespace(variables, source)
+ input_var = variables.get("input") or {}
+ variables = {
+ "input": VelocityInput(input_var.get("body"), input_var.get("params")),
+ "util": VelocityUtilApiGateway(),
+ }
+ namespace.update(variables)
+ return namespace
+
+ def render_request(
+ self, template: str, variables: MappingTemplateVariables
+ ) -> tuple[str, ContextVarsRequestOverride]:
+ variables_copy: MappingTemplateVariables = copy.deepcopy(variables)
+ variables_copy["context"]["requestOverride"] = ContextVarsRequestOverride(
+ querystring={}, header={}, path={}
+ )
+ result = self.render_vtl(template=template.strip(), variables=variables_copy)
+ return result, variables_copy["context"]["requestOverride"]
+
+ def render_response(
+ self, template: str, variables: MappingTemplateVariables
+ ) -> tuple[str, ContextVarsResponseOverride]:
+ variables_copy: MappingTemplateVariables = copy.deepcopy(variables)
+ variables_copy["context"]["responseOverride"] = ContextVarsResponseOverride(
+ header={}, status=0
+ )
+ result = self.render_vtl(template=template.strip(), variables=variables_copy)
+ return result, variables_copy["context"]["responseOverride"]
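A minimal rendering sketch for the class above (it assumes a LocalStack environment where the modules from this diff are importable; the template, stage, and body values are made up for illustration):

```python
from localstack.services.apigateway.next_gen.execute_api.template_mapping import (
    ApiGatewayVtlTemplate,
    MappingTemplateInput,
    MappingTemplateVariables,
)
from localstack.services.apigateway.next_gen.execute_api.variables import ContextVariables

template = """{
  "user": $input.json('$.name'),
  "stage": "$context.stage"
}"""

variables = MappingTemplateVariables(
    context=ContextVariables(stage="dev", requestId="test-request-id"),
    input=MappingTemplateInput(body='{"name": "Alice"}', params={}),
    stageVariables={},
)

rendered, request_override = ApiGatewayVtlTemplate().render_request(template, variables)
print(rendered)          # the transformed payload
print(request_override)  # header/path/querystring overrides set by the template (empty here)
```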
diff --git a/localstack-core/localstack/services/apigateway/next_gen/execute_api/variables.py b/localstack-core/localstack/services/apigateway/next_gen/execute_api/variables.py
new file mode 100644
index 0000000000000..6403f01852752
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/execute_api/variables.py
@@ -0,0 +1,190 @@
+from typing import Optional, TypedDict
+
+
+class ContextVarsAuthorizer(TypedDict, total=False):
+ # this is merged with the Context returned by the Authorizer, which can attach any property to this dict in string
+ # format
+
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ claims: Optional[dict[str, str]]
+ """Claims returned from the Amazon Cognito user pool after the method caller is successfully authenticated"""
+ principalId: Optional[str]
+ """The principal user identification associated with the token sent by the client and returned from an API Gateway Lambda authorizer"""
+
+
+class ContextVarsIdentityClientCertValidity(TypedDict, total=False):
+ notBefore: str
+ notAfter: str
+
+
+class ContextVarsIdentityClientCert(TypedDict, total=False):
+ """Certificate that a client presents. Present only in access logs if mutual TLS authentication fails."""
+
+ clientCertPem: str
+ subjectDN: str
+ issuerDN: str
+ serialNumber: str
+ validity: ContextVarsIdentityClientCertValidity
+
+
+class ContextVarsIdentity(TypedDict, total=False):
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html
+ accountId: Optional[str]
+ """The AWS account ID associated with the request."""
+ accessKey: Optional[str]
+ """The AWS access key associated with the request."""
+ apiKey: Optional[str]
+ """For API methods that require an API key, this variable is the API key associated with the method request."""
+ apiKeyId: Optional[str]
+ """The API key ID associated with an API request that requires an API key."""
+ caller: Optional[str]
+ """The principal identifier of the caller that signed the request. Supported for resources that use IAM authorization."""
+ cognitoAuthenticationProvider: Optional[str]
+ """A comma-separated list of the Amazon Cognito authentication providers used by the caller making the request"""
+ cognitoAuthenticationType: Optional[str]
+ """The Amazon Cognito authentication type of the caller making the request"""
+ cognitoIdentityId: Optional[str]
+ """The Amazon Cognito identity ID of the caller making the request"""
+ cognitoIdentityPoolId: Optional[str]
+ """The Amazon Cognito identity pool ID of the caller making the request"""
+ principalOrgId: Optional[str]
+ """The AWS organization ID."""
+ sourceIp: Optional[str]
+ """The source IP address of the immediate TCP connection making the request to the API Gateway endpoint"""
+ clientCert: ContextVarsIdentityClientCert
+ vpcId: Optional[str]
+ """The VPC ID of the VPC making the request to the API Gateway endpoint."""
+ vpceId: Optional[str]
+ """The VPC endpoint ID of the VPC endpoint making the request to the API Gateway endpoint."""
+ user: Optional[str]
+ """The principal identifier of the user that will be authorized against resource access for resources that use IAM authorization."""
+ userAgent: Optional[str]
+ """The User-Agent header of the API caller."""
+ userArn: Optional[str]
+ """The Amazon Resource Name (ARN) of the effective user identified after authentication."""
+
+
+class ContextVarsRequestOverride(TypedDict, total=False):
+ header: dict[str, str]
+ path: dict[str, str]
+ querystring: dict[str, str]
+
+
+class ContextVarsResponseOverride(TypedDict):
+ header: dict[str, str]
+ status: int
+
+
+class GatewayResponseContextVarsError(TypedDict, total=False):
+ # This variable can only be used for simple variable substitution in a GatewayResponse body-mapping template,
+ # which is not processed by the Velocity Template Language engine, and in access logging.
+ message: str
+ messageString: str
+ responseType: str
+ validationErrorString: str
+
+
+class ContextVariables(TypedDict, total=False):
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html#context-variable-reference
+ accountId: str
+ """The API owner's AWS account ID."""
+ apiId: str
+ """The identifier API Gateway assigns to your API."""
+ authorizer: Optional[ContextVarsAuthorizer]
+ """The principal user identification associated with the token."""
+ awsEndpointRequestId: Optional[str]
+ """The AWS endpoint's request ID."""
+ deploymentId: str
+ """The ID of the API deployment."""
+ domainName: str
+ """The full domain name used to invoke the API. This should be the same as the incoming Host header."""
+ domainPrefix: str
+ """The first label of the $context.domainName."""
+ error: GatewayResponseContextVarsError
+ """The error context variables."""
+ extendedRequestId: str
+ """The extended ID that API Gateway generates and assigns to the API request. """
+ httpMethod: str
+ """The HTTP method used"""
+ identity: Optional[ContextVarsIdentity]
+ isCanaryRequest: Optional[bool | str] # TODO: verify type
+ """Indicates if the request was directed to the canary"""
+ path: str
+ """The request path."""
+ protocol: str
+ """The request protocol"""
+ requestId: str
+ """An ID for the request. Clients can override this request ID. """
+ requestOverride: Optional[ContextVarsRequestOverride]
+ """Request override. Only exists for request mapping template"""
+ requestTime: str
+ """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm)."""
+ requestTimeEpoch: int
+ """The Epoch-formatted request time, in milliseconds."""
+ resourceId: Optional[str]
+ """The identifier that API Gateway assigns to your resource."""
+ resourcePath: Optional[str]
+ """The path to your resource"""
+ responseOverride: Optional[ContextVarsResponseOverride]
+ """Response override. Only exists for response mapping template"""
+ stage: str
+ """The deployment stage of the API request """
+ wafResponseCode: Optional[str]
+ """The response received from AWS WAF: WAF_ALLOW or WAF_BLOCK. Will not be set if the stage is not associated with a web ACL"""
+ webaclArn: Optional[str]
+ """The complete ARN of the web ACL that is used to decide whether to allow or block the request. Will not be set if the stage is not associated with a web ACL."""
+
+
+class LoggingContextVarsAuthorize(TypedDict, total=False):
+ error: Optional[str]
+ latency: Optional[str]
+ status: Optional[str]
+
+
+class LoggingContextVarsAuthorizer(TypedDict, total=False):
+ error: Optional[str]
+ integrationLatency: Optional[str]
+ integrationStatus: Optional[str]
+ latency: Optional[str]
+ requestId: Optional[str]
+ status: Optional[str]
+
+
+class LoggingContextVarsAuthenticate(TypedDict, total=False):
+ error: Optional[str]
+ latency: Optional[str]
+ status: Optional[str]
+
+
+class LoggingContextVarsCustomDomain(TypedDict, total=False):
+ basePathMatched: Optional[str]
+
+
+class LoggingContextVarsIntegration(TypedDict, total=False):
+ error: Optional[str]
+ integrationStatus: Optional[str]
+ latency: Optional[str]
+ requestId: Optional[str]
+ status: Optional[str]
+
+
+class LoggingContextVarsWaf(TypedDict, total=False):
+ error: Optional[str]
+ latency: Optional[str]
+ status: Optional[str]
+
+
+class LoggingContextVariables(TypedDict, total=False):
+ authorize: Optional[LoggingContextVarsAuthorize]
+ authorizer: Optional[LoggingContextVarsAuthorizer]
+ authenticate: Optional[LoggingContextVarsAuthenticate]
+ customDomain: Optional[LoggingContextVarsCustomDomain]
+ endpointType: Optional[str]
+ integration: Optional[LoggingContextVarsIntegration]
+ integrationLatency: Optional[str]
+ integrationStatus: Optional[str]
+ responseLatency: Optional[str]
+ responseLength: Optional[str]
+ status: Optional[str]
+ waf: Optional[LoggingContextVarsWaf]
+ xrayTraceId: Optional[str]
diff --git a/localstack-core/localstack/services/apigateway/next_gen/provider.py b/localstack-core/localstack/services/apigateway/next_gen/provider.py
new file mode 100644
index 0000000000000..9361e08ae94fd
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/next_gen/provider.py
@@ -0,0 +1,262 @@
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.apigateway import (
+ CacheClusterSize,
+ CreateStageRequest,
+ Deployment,
+ DeploymentCanarySettings,
+ GatewayResponse,
+ GatewayResponses,
+ GatewayResponseType,
+ ListOfPatchOperation,
+ MapOfStringToString,
+ NotFoundException,
+ NullableBoolean,
+ NullableInteger,
+ Stage,
+ StatusCode,
+ String,
+ TestInvokeMethodRequest,
+ TestInvokeMethodResponse,
+)
+from localstack.services.apigateway.helpers import (
+ get_apigateway_store,
+ get_moto_rest_api,
+ get_rest_api_container,
+)
+from localstack.services.apigateway.legacy.provider import ApigatewayProvider
+from localstack.services.apigateway.patches import apply_patches
+from localstack.services.edge import ROUTER
+from localstack.services.moto import call_moto
+
+from ..models import apigateway_stores
+from .execute_api.gateway_response import (
+ DEFAULT_GATEWAY_RESPONSES,
+ GatewayResponseCode,
+ build_gateway_response,
+ get_gateway_response_or_default,
+)
+from .execute_api.helpers import freeze_rest_api
+from .execute_api.router import ApiGatewayEndpoint, ApiGatewayRouter
+
+
+class ApigatewayNextGenProvider(ApigatewayProvider):
+ router: ApiGatewayRouter
+
+ def __init__(self, router: ApiGatewayRouter = None):
+ # we initialize the route handler with a global store in the default account and region, because it only ever
+ # accesses values with CrossAccount attributes
+ if not router:
+ route_handler = ApiGatewayEndpoint(store=apigateway_stores)
+ router = ApiGatewayRouter(ROUTER, handler=route_handler)
+
+ super().__init__(router=router)
+
+ def on_after_init(self):
+ apply_patches()
+ self.router.register_routes()
+
+ @handler("DeleteRestApi")
+ def delete_rest_api(self, context: RequestContext, rest_api_id: String, **kwargs) -> None:
+ super().delete_rest_api(context, rest_api_id, **kwargs)
+ store = get_apigateway_store(context=context)
+ api_id_lower = rest_api_id.lower()
+ store.active_deployments.pop(api_id_lower, None)
+ store.internal_deployments.pop(api_id_lower, None)
+
+ @handler("CreateStage", expand=False)
+ def create_stage(self, context: RequestContext, request: CreateStageRequest) -> Stage:
+ response = super().create_stage(context, request)
+ store = get_apigateway_store(context=context)
+
+ rest_api_id = request["restApiId"].lower()
+ store.active_deployments.setdefault(rest_api_id, {})
+ store.active_deployments[rest_api_id][request["stageName"]] = request["deploymentId"]
+
+ return response
+
+ @handler("UpdateStage")
+ def update_stage(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ stage_name: String,
+ patch_operations: ListOfPatchOperation = None,
+ **kwargs,
+ ) -> Stage:
+ response = super().update_stage(
+ context, rest_api_id, stage_name, patch_operations, **kwargs
+ )
+
+ for patch_operation in patch_operations:
+ patch_path = patch_operation["path"]
+
+ if patch_path == "/deploymentId" and patch_operation["op"] == "replace":
+ if deployment_id := patch_operation.get("value"):
+ store = get_apigateway_store(context=context)
+ store.active_deployments.setdefault(rest_api_id.lower(), {})[stage_name] = (
+ deployment_id
+ )
+
+ return response
+
+ def delete_stage(
+ self, context: RequestContext, rest_api_id: String, stage_name: String, **kwargs
+ ) -> None:
+ call_moto(context)
+ store = get_apigateway_store(context=context)
+ store.active_deployments[rest_api_id.lower()].pop(stage_name, None)
+
+ def create_deployment(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ stage_name: String = None,
+ stage_description: String = None,
+ description: String = None,
+ cache_cluster_enabled: NullableBoolean = None,
+ cache_cluster_size: CacheClusterSize = None,
+ variables: MapOfStringToString = None,
+ canary_settings: DeploymentCanarySettings = None,
+ tracing_enabled: NullableBoolean = None,
+ **kwargs,
+ ) -> Deployment:
+ # TODO: if the REST API does not contain any method, we should raise an exception
+ deployment: Deployment = call_moto(context)
+ # https://docs.aws.amazon.com/apigateway/latest/developerguide/updating-api.html
+ # TODO: the deployment is not accessible until it is linked to a stage
+ # you can create it together with a stage, or later update a stage to point to this deployment
+ store = get_apigateway_store(context=context)
+ moto_rest_api = get_moto_rest_api(context, rest_api_id)
+ rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
+ frozen_deployment = freeze_rest_api(
+ account_id=context.account_id,
+ region=context.region,
+ moto_rest_api=moto_rest_api,
+ localstack_rest_api=rest_api_container,
+ )
+ router_api_id = rest_api_id.lower()
+ store.internal_deployments.setdefault(router_api_id, {})[deployment["id"]] = (
+ frozen_deployment
+ )
+
+ if stage_name:
+ store.active_deployments.setdefault(router_api_id, {})[stage_name] = deployment["id"]
+
+ return deployment
+
+ def delete_deployment(
+ self, context: RequestContext, rest_api_id: String, deployment_id: String, **kwargs
+ ) -> None:
+ call_moto(context)
+ store = get_apigateway_store(context=context)
+ store.internal_deployments.get(rest_api_id.lower(), {}).pop(deployment_id, None)
+
+ def put_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ status_code: StatusCode = None,
+ response_parameters: MapOfStringToString = None,
+ response_templates: MapOfStringToString = None,
+ **kwargs,
+ ) -> GatewayResponse:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ gateway_response = build_gateway_response(
+ status_code=status_code,
+ response_parameters=response_parameters,
+ response_templates=response_templates,
+ response_type=response_type,
+ default_response=False,
+ )
+
+ rest_api_container.gateway_responses[response_type] = gateway_response
+
+ # The CRUD provider has a weird behavior: for some responses (for now, INTEGRATION_FAILURE), it sets the default
+ # status code to `504`. However, in the actual invocation logic, it returns 500. To deal with the inconsistency,
+ # we need to set the value to None if not provided by the user, so that the invocation logic can properly return
+ # 500, and the CRUD layer can still return 504 even though it is technically wrong.
+ response = gateway_response.copy()
+ if response.get("statusCode") is None:
+ response["statusCode"] = GatewayResponseCode[response_type]
+
+ return response
+
+ def get_gateway_response(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ response_type: GatewayResponseType,
+ **kwargs,
+ ) -> GatewayResponse:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ if response_type not in DEFAULT_GATEWAY_RESPONSES:
+ raise CommonServiceException(
+ code="ValidationException",
+ message=f"1 validation error detected: Value '{response_type}' at 'responseType' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(DEFAULT_GATEWAY_RESPONSES)}]",
+ )
+
+ gateway_response = _get_gateway_response_or_default(
+ response_type, rest_api_container.gateway_responses
+ )
+ # TODO: add validation for the parameters? It seems to be validated client-side; how can we verify this?
+ return gateway_response
+
+ def get_gateway_responses(
+ self,
+ context: RequestContext,
+ rest_api_id: String,
+ position: String = None,
+ limit: NullableInteger = None,
+ **kwargs,
+ ) -> GatewayResponses:
+ store = get_apigateway_store(context=context)
+ if not (rest_api_container := store.rest_apis.get(rest_api_id)):
+ raise NotFoundException(
+ f"Invalid API identifier specified {context.account_id}:{rest_api_id}"
+ )
+
+ user_gateway_resp = rest_api_container.gateway_responses
+ gateway_responses = [
+ _get_gateway_response_or_default(response_type, user_gateway_resp)
+ for response_type in DEFAULT_GATEWAY_RESPONSES
+ ]
+ return GatewayResponses(items=gateway_responses)
+
+ def test_invoke_method(
+ self, context: RequestContext, request: TestInvokeMethodRequest
+ ) -> TestInvokeMethodResponse:
+ # TODO: rewrite and migrate to NextGen
+ return super().test_invoke_method(context, request)
+
+
+def _get_gateway_response_or_default(
+ response_type: GatewayResponseType,
+ gateway_responses: dict[GatewayResponseType, GatewayResponse],
+) -> GatewayResponse:
+ """
+ Utility function that overrides the behavior of `get_gateway_response_or_default` by setting a default status code
+ from the `GatewayResponseCode` values. In reality, some default values in the invocation layer are different from
+ what the CRUD layer of API Gateway is returning.
+ """
+ response = get_gateway_response_or_default(response_type, gateway_responses)
+ if response.get("statusCode") is None and (status_code := GatewayResponseCode[response_type]):
+ response["statusCode"] = status_code
+
+ return response
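A usage sketch of the gateway-response behavior described above, run against a local endpoint (the endpoint URL and dummy credentials are the usual LocalStack defaults and are assumptions, not part of this PR):

```python
import boto3

apigw = boto3.client(
    "apigateway",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

api_id = apigw.create_rest_api(name="demo-api")["id"]

# a response that was never configured comes back with its default status code
default_resp = apigw.get_gateway_response(
    restApiId=api_id, responseType="MISSING_AUTHENTICATION_TOKEN"
)
print(default_resp["defaultResponse"], default_resp["statusCode"])  # expected: True 403

# override the status code and read it back
apigw.put_gateway_response(
    restApiId=api_id, responseType="MISSING_AUTHENTICATION_TOKEN", statusCode="404"
)
updated = apigw.get_gateway_response(restApiId=api_id, responseType="MISSING_AUTHENTICATION_TOKEN")
print(updated["statusCode"])  # expected: 404
```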
diff --git a/localstack-core/localstack/services/apigateway/patches.py b/localstack-core/localstack/services/apigateway/patches.py
new file mode 100644
index 0000000000000..253a5f54e8fd4
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/patches.py
@@ -0,0 +1,183 @@
+import json
+import logging
+
+from moto.apigateway import models as apigateway_models
+from moto.apigateway.exceptions import (
+ DeploymentNotFoundException,
+ NoIntegrationDefined,
+ RestAPINotFound,
+ StageStillActive,
+)
+from moto.apigateway.responses import APIGatewayResponse
+from moto.core.utils import camelcase_to_underscores
+
+from localstack.constants import TAG_KEY_CUSTOM_ID
+from localstack.services.apigateway.helpers import apply_json_patch_safe
+from localstack.utils.common import str_to_bool
+from localstack.utils.patch import patch
+
+LOG = logging.getLogger(__name__)
+
+
+def apply_patches():
+ # TODO refactor patches in this module (e.g., use @patch decorator, simplify, ...)
+
+ def apigateway_models_Stage_init(
+ self, cacheClusterEnabled=False, cacheClusterSize=None, **kwargs
+ ):
+ apigateway_models_Stage_init_orig(
+ self,
+ cacheClusterEnabled=cacheClusterEnabled,
+ cacheClusterSize=cacheClusterSize,
+ **kwargs,
+ )
+
+ if (cacheClusterSize or cacheClusterEnabled) and not self.cache_cluster_status:
+ self.cache_cluster_status = "AVAILABLE"
+
+ apigateway_models_Stage_init_orig = apigateway_models.Stage.__init__
+ apigateway_models.Stage.__init__ = apigateway_models_Stage_init
+
+ @patch(APIGatewayResponse.put_integration)
+ def apigateway_put_integration(fn, self, *args, **kwargs):
+ # TODO: verify if this patch is still necessary, this might have been fixed upstream
+ fn(self, *args, **kwargs)
+
+ url_path_parts = self.path.split("/")
+ function_id = url_path_parts[2]
+ resource_id = url_path_parts[4]
+ method_type = url_path_parts[6]
+ integration = self.backend.get_integration(function_id, resource_id, method_type)
+
+ timeout_milliseconds = self._get_param("timeoutInMillis")
+ cache_key_parameters = self._get_param("cacheKeyParameters") or []
+ content_handling = self._get_param("contentHandling")
+ integration.cache_namespace = resource_id
+ integration.timeout_in_millis = timeout_milliseconds
+ integration.cache_key_parameters = cache_key_parameters
+ integration.content_handling = content_handling
+ return 201, {}, json.dumps(integration.to_json())
+
+ # define json-patch operations for backend models
+
+ def backend_model_apply_operations(self, patch_operations):
+ # run pre-actions
+ if isinstance(self, apigateway_models.Stage) and [
+ op for op in patch_operations if "/accessLogSettings" in op.get("path", "")
+ ]:
+ self.access_log_settings = self.access_log_settings or {}
+ # apply patches
+ apply_json_patch_safe(self, patch_operations, in_place=True)
+ # run post-actions
+ if isinstance(self, apigateway_models.Stage):
+ bool_params = ["cacheClusterEnabled", "tracingEnabled"]
+ for bool_param in bool_params:
+ if getattr(self, camelcase_to_underscores(bool_param), None):
+ value = getattr(self, camelcase_to_underscores(bool_param), None)
+ setattr(self, camelcase_to_underscores(bool_param), str_to_bool(value))
+ return self
+
+ model_classes = [
+ apigateway_models.Authorizer,
+ apigateway_models.DomainName,
+ apigateway_models.MethodResponse,
+ ]
+ for model_class in model_classes:
+ model_class.apply_operations = model_class.apply_patch_operations = (
+ backend_model_apply_operations
+ )
+
+ # fix data types for some json-patch operation values
+
+ @patch(apigateway_models.Stage._get_default_method_settings)
+ def _get_default_method_settings(fn, self):
+ result = fn(self)
+ default_settings = self.method_settings.get("*/*", {})
+ result["cacheDataEncrypted"] = default_settings.get("cacheDataEncrypted", False)
+ result["throttlingRateLimit"] = default_settings.get("throttlingRateLimit", 10000.0)
+ result["throttlingBurstLimit"] = default_settings.get("throttlingBurstLimit", 5000)
+ result["metricsEnabled"] = default_settings.get("metricsEnabled", False)
+ result["dataTraceEnabled"] = default_settings.get("dataTraceEnabled", False)
+ result["unauthorizedCacheControlHeaderStrategy"] = default_settings.get(
+ "unauthorizedCacheControlHeaderStrategy", "SUCCEED_WITH_RESPONSE_HEADER"
+ )
+ result["cacheTtlInSeconds"] = default_settings.get("cacheTtlInSeconds", 300)
+ result["cachingEnabled"] = default_settings.get("cachingEnabled", False)
+ result["requireAuthorizationForCacheControl"] = default_settings.get(
+ "requireAuthorizationForCacheControl", True
+ )
+ return result
+
+ # patch integration error responses
+ @patch(apigateway_models.Resource.get_integration)
+ def apigateway_models_resource_get_integration(fn, self, method_type):
+ resource_method = self.resource_methods.get(method_type, {})
+ if not resource_method.method_integration:
+ raise NoIntegrationDefined()
+ return resource_method.method_integration
+
+ @patch(apigateway_models.RestAPI.to_dict)
+ def apigateway_models_rest_api_to_dict(fn, self):
+ resp = fn(self)
+ resp["policy"] = None
+ if self.policy:
+ # Strip whitespaces for TF compatibility (not entirely sure why we need double-dumps,
+ # but otherwise: "error normalizing policy JSON: invalid character 'V' after top-level value")
+ resp["policy"] = json.dumps(json.dumps(json.loads(self.policy), separators=(",", ":")))[
+ 1:-1
+ ]
+
+ if not self.tags:
+ resp["tags"] = None
+
+ resp["disableExecuteApiEndpoint"] = (
+ str(resp.get("disableExecuteApiEndpoint")).lower() == "true"
+ )
+
+ return resp
+
+ @patch(apigateway_models.Stage.to_json)
+ def apigateway_models_stage_to_json(fn, self):
+ result = fn(self)
+
+ if "documentationVersion" not in result:
+ result["documentationVersion"] = getattr(self, "documentation_version", None)
+
+ return result
+
+ # TODO remove this patch when the behavior is implemented in moto
+ @patch(apigateway_models.APIGatewayBackend.create_rest_api)
+ def create_rest_api(fn, self, *args, tags=None, **kwargs):
+ """
+ https://github.com/localstack/localstack/pull/4413/files
+ Add ability to specify custom IDs for API GW REST APIs via tags
+ """
+ tags = tags or {}
+ result = fn(self, *args, tags=tags, **kwargs)
+ # TODO: lower the custom_id when getting it from the tags, as AWS is case insensitive
+ if custom_id := tags.get(TAG_KEY_CUSTOM_ID):
+ self.apis.pop(result.id)
+ result.id = custom_id
+ self.apis[custom_id] = result
+ return result
+
+ @patch(apigateway_models.APIGatewayBackend.get_rest_api, pass_target=False)
+ def get_rest_api(self, function_id):
+ for key in self.apis.keys():
+ if key.lower() == function_id.lower():
+ return self.apis[key]
+ raise RestAPINotFound()
+
+ @patch(apigateway_models.RestAPI.delete_deployment, pass_target=False)
+ def patch_delete_deployment(self, deployment_id: str) -> apigateway_models.Deployment:
+ if deployment_id not in self.deployments:
+ raise DeploymentNotFoundException()
+ deployment = self.deployments[deployment_id]
+ if deployment.stage_name and (
+ (stage := self.stages.get(deployment.stage_name))
+ and stage.deployment_id == deployment.id
+ ):
+ # Stage is still active
+ raise StageStillActive()
+
+ return self.deployments.pop(deployment_id)
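A usage sketch of the custom-ID patch on `create_rest_api` (the tag key is assumed to be LocalStack's `TAG_KEY_CUSTOM_ID` constant, i.e. `_custom_id_`; the endpoint and credentials are the usual local defaults):

```python
import boto3

apigw = boto3.client(
    "apigateway",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

# the custom-ID tag makes the patched backend re-register the API under this ID
api = apigw.create_rest_api(name="custom-id-api", tags={"_custom_id_": "myid123"})
print(api["id"])  # myid123
```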
diff --git a/localstack/utils/cloudwatch/__init__.py b/localstack-core/localstack/services/apigateway/resource_providers/__init__.py
similarity index 100%
rename from localstack/utils/cloudwatch/__init__.py
rename to localstack-core/localstack/services/apigateway/resource_providers/__init__.py
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.py
new file mode 100644
index 0000000000000..8c78925a5a8b8
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.py
@@ -0,0 +1,110 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayAccountProperties(TypedDict):
+ CloudWatchRoleArn: Optional[str]
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayAccountProvider(ResourceProvider[ApiGatewayAccountProperties]):
+ TYPE = "AWS::ApiGateway::Account" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayAccountProperties],
+ ) -> ProgressEvent[ApiGatewayAccountProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+
+
+
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - apigateway:PATCH
+ - iam:GetRole
+ - iam:PassRole
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ role_arn = model["CloudWatchRoleArn"]
+ apigw.update_account(
+ patchOperations=[{"op": "replace", "path": "/cloudwatchRoleArn", "value": role_arn}]
+ )
+
+ model["Id"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayAccountProperties],
+ ) -> ProgressEvent[ApiGatewayAccountProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayAccountProperties],
+ ) -> ProgressEvent[ApiGatewayAccountProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+
+ # note: deletion of accounts is currently a no-op
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayAccountProperties],
+ ) -> ProgressEvent[ApiGatewayAccountProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:PATCH
+ - iam:GetRole
+ - iam:PassRole
+ """
+ raise NotImplementedError
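A deployment sketch for the resource type handled by this provider, using CloudFormation against a local endpoint (the stack name, role ARN, and endpoint are placeholders, not part of this PR):

```python
import json

import boto3

template = {
    "Resources": {
        "ApiGwAccount": {
            "Type": "AWS::ApiGateway::Account",
            "Properties": {"CloudWatchRoleArn": "arn:aws:iam::000000000000:role/apigw-logs"},
        }
    }
}

cfn = boto3.client(
    "cloudformation",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

cfn.create_stack(StackName="apigw-account-demo", TemplateBody=json.dumps(template))
cfn.get_waiter("stack_create_complete").wait(StackName="apigw-account-demo")
```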
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.schema.json
new file mode 100644
index 0000000000000..3192ca8c3b443
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account.schema.json
@@ -0,0 +1,46 @@
+{
+ "typeName": "AWS::ApiGateway::Account",
+ "description": "Resource Type definition for AWS::ApiGateway::Account",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "description": "Primary identifier which is manually generated.",
+ "type": "string"
+ },
+ "CloudWatchRoleArn": {
+ "description": "The Amazon Resource Name (ARN) of an IAM role that has write access to CloudWatch Logs in your account.",
+ "type": "string"
+ }
+ },
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:PATCH",
+ "iam:GetRole",
+ "iam:PassRole"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:PATCH",
+ "iam:GetRole",
+ "iam:PassRole"
+ ]
+ },
+ "delete": {
+ "permissions": []
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account_plugin.py
new file mode 100644
index 0000000000000..d7dc5c91ce0d1
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_account_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayAccountProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Account"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_account import (
+ ApiGatewayAccountProvider,
+ )
+
+ self.factory = ApiGatewayAccountProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.py
new file mode 100644
index 0000000000000..1385cd6c5d01c
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.py
@@ -0,0 +1,136 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+
+
+class ApiGatewayApiKeyProperties(TypedDict):
+ APIKeyId: Optional[str]
+ CustomerId: Optional[str]
+ Description: Optional[str]
+ Enabled: Optional[bool]
+ GenerateDistinctId: Optional[bool]
+ Name: Optional[str]
+ StageKeys: Optional[list[StageKey]]
+ Tags: Optional[list[Tag]]
+ Value: Optional[str]
+
+
+class StageKey(TypedDict):
+ RestApiId: Optional[str]
+ StageName: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayApiKeyProvider(ResourceProvider[ApiGatewayApiKeyProperties]):
+ TYPE = "AWS::ApiGateway::ApiKey" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayApiKeyProperties],
+ ) -> ProgressEvent[ApiGatewayApiKeyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/APIKeyId
+
+
+ Create-only properties:
+ - /properties/GenerateDistinctId
+ - /properties/Name
+ - /properties/Value
+
+ Read-only properties:
+ - /properties/APIKeyId
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ params = util.select_attributes(
+ model, ["Description", "CustomerId", "Name", "Value", "Enabled", "StageKeys"]
+ )
+ params = keys_to_lower(params.copy())
+ if "enabled" in params:
+ params["enabled"] = bool(params["enabled"])
+
+ if model.get("Tags"):
+ params["tags"] = {tag["Key"]: tag["Value"] for tag in model["Tags"]}
+
+ response = apigw.create_api_key(**params)
+ model["APIKeyId"] = response["id"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayApiKeyProperties],
+ ) -> ProgressEvent[ApiGatewayApiKeyProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayApiKeyProperties],
+ ) -> ProgressEvent[ApiGatewayApiKeyProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ apigw.delete_api_key(apiKey=model["APIKeyId"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayApiKeyProperties],
+ ) -> ProgressEvent[ApiGatewayApiKeyProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:PATCH
+ - apigateway:PUT
+ - apigateway:DELETE
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.schema.json
new file mode 100644
index 0000000000000..4d58557451ff8
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey.schema.json
@@ -0,0 +1,135 @@
+{
+ "typeName": "AWS::ApiGateway::ApiKey",
+ "description": "Resource Type definition for AWS::ApiGateway::ApiKey",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "APIKeyId": {
+ "description": "A Unique Key ID which identifies the API Key. Generated by the Create API and returned by the Read and List APIs ",
+ "type": "string"
+ },
+ "CustomerId": {
+ "description": "An AWS Marketplace customer identifier to use when integrating with the AWS SaaS Marketplace.",
+ "type": "string"
+ },
+ "Description": {
+ "description": "A description of the purpose of the API key.",
+ "type": "string"
+ },
+ "Enabled": {
+ "description": "Indicates whether the API key can be used by clients.",
+ "default": false,
+ "type": "boolean"
+ },
+ "GenerateDistinctId": {
+ "description": "Specifies whether the key identifier is distinct from the created API key value. This parameter is deprecated and should not be used.",
+ "type": "boolean"
+ },
+ "Name": {
+ "description": "A name for the API key. If you don't specify a name, AWS CloudFormation generates a unique physical ID and uses that ID for the API key name.",
+ "type": "string"
+ },
+ "StageKeys": {
+ "description": "A list of stages to associate with this API key.",
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/StageKey"
+ }
+ },
+ "Tags": {
+ "description": "An array of arbitrary tags (key-value pairs) to associate with the API key.",
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Value": {
+ "description": "The value of the API key. Must be at least 20 characters long.",
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "StageKey": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "RestApiId": {
+ "description": "The ID of a RestApi resource that includes the stage with which you want to associate the API key.",
+ "type": "string"
+ },
+ "StageName": {
+ "description": "The name of the stage with which to associate the API key. The stage must be included in the RestApi resource that you specified in the RestApiId property. ",
+ "type": "string"
+ }
+ }
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -. ",
+ "type": "string",
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/GenerateDistinctId",
+ "/properties/Name",
+ "/properties/Value"
+ ],
+ "writeOnlyProperties": [
+ "/properties/GenerateDistinctId"
+ ],
+ "primaryIdentifier": [
+ "/properties/APIKeyId"
+ ],
+ "readOnlyProperties": [
+ "/properties/APIKeyId"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:PATCH",
+ "apigateway:PUT",
+ "apigateway:DELETE"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey_plugin.py
new file mode 100644
index 0000000000000..352ec19eec4d3
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_apikey_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayApiKeyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::ApiKey"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_apikey import (
+ ApiGatewayApiKeyProvider,
+ )
+
+ self.factory = ApiGatewayApiKeyProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.py
new file mode 100644
index 0000000000000..51debd7811631
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.py
@@ -0,0 +1,122 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayBasePathMappingProperties(TypedDict):
+ DomainName: Optional[str]
+ BasePath: Optional[str]
+ RestApiId: Optional[str]
+ Stage: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayBasePathMappingProvider(ResourceProvider[ApiGatewayBasePathMappingProperties]):
+ TYPE = "AWS::ApiGateway::BasePathMapping" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayBasePathMappingProperties],
+ ) -> ProgressEvent[ApiGatewayBasePathMappingProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/DomainName
+ - /properties/BasePath
+
+ Required properties:
+ - DomainName
+
+ Create-only properties:
+ - /properties/DomainName
+ - /properties/BasePath
+
+
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+
+ # TODO: we are using restApiId as the PhysicalResourceId;
+ # check whether this needs to change
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ params = {
+ "domainName": model.get("DomainName"),
+ "restApiId": model.get("RestApiId"),
+ **({"basePath": model.get("BasePath")} if model.get("BasePath") else {}),
+ **({"stage": model.get("Stage")} if model.get("Stage") else {}),
+ }
+ response = apigw.create_base_path_mapping(**params)
+ model["RestApiId"] = response["restApiId"]
+ # TODO: validations
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayBasePathMappingProperties],
+ ) -> ProgressEvent[ApiGatewayBasePathMappingProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayBasePathMappingProperties],
+ ) -> ProgressEvent[ApiGatewayBasePathMappingProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ apigw.delete_base_path_mapping(domainName=model["DomainName"], basePath=model["BasePath"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayBasePathMappingProperties],
+ ) -> ProgressEvent[ApiGatewayBasePathMappingProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ - apigateway:PATCH
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.schema.json
new file mode 100644
index 0000000000000..ded5541adedac
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping.schema.json
@@ -0,0 +1,81 @@
+{
+ "typeName": "AWS::ApiGateway::BasePathMapping",
+ "description": "Resource Type definition for AWS::ApiGateway::BasePathMapping",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "BasePath": {
+ "type": "string",
+ "description": "The base path name that callers of the API must provide in the URL after the domain name."
+ },
+ "DomainName": {
+ "type": "string",
+ "description": "The DomainName of an AWS::ApiGateway::DomainName resource."
+ },
+ "RestApiId": {
+ "type": "string",
+ "description": "The ID of the API."
+ },
+ "Stage": {
+ "type": "string",
+ "description": "The name of the API's stage."
+ }
+ },
+ "required": [
+ "DomainName"
+ ],
+ "createOnlyProperties": [
+ "/properties/DomainName",
+ "/properties/BasePath"
+ ],
+ "primaryIdentifier": [
+ "/properties/DomainName",
+ "/properties/BasePath"
+ ],
+ "tagging": {
+ "taggable": false,
+ "tagOnCreate": false,
+ "tagUpdatable": false,
+ "cloudFormationSystemTags": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE",
+ "apigateway:PATCH"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "DomainName": {
+ "$ref": "resource-schema.json#/properties/DomainName"
+ }
+ },
+ "required": [
+ "DomainName"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping_plugin.py
new file mode 100644
index 0000000000000..2dcb4b036e9ef
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_basepathmapping_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayBasePathMappingProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::BasePathMapping"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_basepathmapping import (
+ ApiGatewayBasePathMappingProvider,
+ )
+
+ self.factory = ApiGatewayBasePathMappingProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.py
new file mode 100644
index 0000000000000..68bae12d2af24
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.py
@@ -0,0 +1,196 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayDeploymentProperties(TypedDict):
+ RestApiId: Optional[str]
+ DeploymentCanarySettings: Optional[DeploymentCanarySettings]
+ DeploymentId: Optional[str]
+ Description: Optional[str]
+ StageDescription: Optional[StageDescription]
+ StageName: Optional[str]
+
+
+class DeploymentCanarySettings(TypedDict):
+ PercentTraffic: Optional[float]
+ StageVariableOverrides: Optional[dict]
+ UseStageCache: Optional[bool]
+
+
+class AccessLogSetting(TypedDict):
+ DestinationArn: Optional[str]
+ Format: Optional[str]
+
+
+class CanarySetting(TypedDict):
+ PercentTraffic: Optional[float]
+ StageVariableOverrides: Optional[dict]
+ UseStageCache: Optional[bool]
+
+
+class MethodSetting(TypedDict):
+ CacheDataEncrypted: Optional[bool]
+ CacheTtlInSeconds: Optional[int]
+ CachingEnabled: Optional[bool]
+ DataTraceEnabled: Optional[bool]
+ HttpMethod: Optional[str]
+ LoggingLevel: Optional[str]
+ MetricsEnabled: Optional[bool]
+ ResourcePath: Optional[str]
+ ThrottlingBurstLimit: Optional[int]
+ ThrottlingRateLimit: Optional[float]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class StageDescription(TypedDict):
+ AccessLogSetting: Optional[AccessLogSetting]
+ CacheClusterEnabled: Optional[bool]
+ CacheClusterSize: Optional[str]
+ CacheDataEncrypted: Optional[bool]
+ CacheTtlInSeconds: Optional[int]
+ CachingEnabled: Optional[bool]
+ CanarySetting: Optional[CanarySetting]
+ ClientCertificateId: Optional[str]
+ DataTraceEnabled: Optional[bool]
+ Description: Optional[str]
+ DocumentationVersion: Optional[str]
+ LoggingLevel: Optional[str]
+ MethodSettings: Optional[list[MethodSetting]]
+ MetricsEnabled: Optional[bool]
+ Tags: Optional[list[Tag]]
+ ThrottlingBurstLimit: Optional[int]
+ ThrottlingRateLimit: Optional[float]
+ TracingEnabled: Optional[bool]
+ Variables: Optional[dict]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayDeploymentProvider(ResourceProvider[ApiGatewayDeploymentProperties]):
+ TYPE = "AWS::ApiGateway::Deployment" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayDeploymentProperties],
+ ) -> ProgressEvent[ApiGatewayDeploymentProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/DeploymentId
+ - /properties/RestApiId
+
+ Required properties:
+ - RestApiId
+
+ Create-only properties:
+ - /properties/DeploymentCanarySettings
+ - /properties/RestApiId
+
+ Read-only properties:
+ - /properties/DeploymentId
+
+ IAM permissions required:
+ - apigateway:POST
+
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
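+ # RestApiId is the only required parameter; the optional stage name, stage description,
+ # and description fields are forwarded only when the template provides them.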
+ params = {"restApiId": model["RestApiId"]}
+
+ if model.get("StageName"):
+ params["stageName"] = model["StageName"]
+
+ if model.get("StageDescription"):
+ params["stageDescription"] = json.dumps(model["StageDescription"])
+
+ if model.get("Description"):
+ params["description"] = model["Description"]
+
+ response = api.create_deployment(**params)
+
+ model["DeploymentId"] = response["id"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayDeploymentProperties],
+ ) -> ProgressEvent[ApiGatewayDeploymentProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayDeploymentProperties],
+ ) -> ProgressEvent[ApiGatewayDeploymentProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ try:
+ # TODO: verify whether AWS behaves the same way
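+ # Remove any stages still referencing this deployment first; delete_deployment would
+ # otherwise be rejected while an active stage points to it.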
+ get_stages = api.get_stages(
+ restApiId=model["RestApiId"], deploymentId=model["DeploymentId"]
+ )
+ if stages := get_stages["item"]:
+ for stage in stages:
+ api.delete_stage(restApiId=model["RestApiId"], stageName=stage["stageName"])
+
+ api.delete_deployment(restApiId=model["RestApiId"], deploymentId=model["DeploymentId"])
+ except api.exceptions.NotFoundException:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayDeploymentProperties],
+ ) -> ProgressEvent[ApiGatewayDeploymentProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:PATCH
+ - apigateway:GET
+ - apigateway:PUT
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.schema.json
new file mode 100644
index 0000000000000..ab10bbf5e2a7a
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment.schema.json
@@ -0,0 +1,318 @@
+{
+ "typeName": "AWS::ApiGateway::Deployment",
+ "description": "Resource Type definition for AWS::ApiGateway::Deployment",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "DeploymentId": {
+ "type": "string",
+ "description": "Primary Id for this resource"
+ },
+ "DeploymentCanarySettings": {
+ "$ref": "#/definitions/DeploymentCanarySettings",
+ "description": "Specifies settings for the canary deployment."
+ },
+ "Description": {
+ "type": "string",
+ "description": "A description of the purpose of the API Gateway deployment."
+ },
+ "RestApiId": {
+ "type": "string",
+ "description": "The ID of the RestApi resource to deploy. "
+ },
+ "StageDescription": {
+ "$ref": "#/definitions/StageDescription",
+ "description": "Configures the stage that API Gateway creates with this deployment."
+ },
+ "StageName": {
+ "type": "string",
+ "description": "A name for the stage that API Gateway creates with this deployment. Use only alphanumeric characters."
+ }
+ },
+ "definitions": {
+ "StageDescription": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AccessLogSetting": {
+ "description": "Specifies settings for logging access in this stage.",
+ "$ref": "#/definitions/AccessLogSetting"
+ },
+ "CacheClusterEnabled": {
+ "description": "Indicates whether cache clustering is enabled for the stage.",
+ "type": "boolean"
+ },
+ "CacheClusterSize": {
+ "description": "The size of the stage's cache cluster.",
+ "type": "string"
+ },
+ "CacheDataEncrypted": {
+ "description": "The time-to-live (TTL) period, in seconds, that specifies how long API Gateway caches responses. ",
+ "type": "boolean"
+ },
+ "CacheTtlInSeconds": {
+ "description": "The time-to-live (TTL) period, in seconds, that specifies how long API Gateway caches responses. ",
+ "type": "integer"
+ },
+ "CachingEnabled": {
+ "description": "Indicates whether responses are cached and returned for requests. You must enable a cache cluster on the stage to cache responses.",
+ "type": "boolean"
+ },
+ "CanarySetting": {
+ "description": "Specifies settings for the canary deployment in this stage.",
+ "$ref": "#/definitions/CanarySetting"
+ },
+ "ClientCertificateId": {
+ "description": "The identifier of the client certificate that API Gateway uses to call your integration endpoints in the stage. ",
+ "type": "string"
+ },
+ "DataTraceEnabled": {
+ "description": "Indicates whether data trace logging is enabled for methods in the stage. API Gateway pushes these logs to Amazon CloudWatch Logs. ",
+ "type": "boolean"
+ },
+ "Description": {
+ "description": "A description of the purpose of the stage.",
+ "type": "string"
+ },
+ "DocumentationVersion": {
+ "description": "The version identifier of the API documentation snapshot.",
+ "type": "string"
+ },
+ "LoggingLevel": {
+ "description": "The logging level for this method. For valid values, see the loggingLevel property of the Stage resource in the Amazon API Gateway API Reference. ",
+ "type": "string"
+ },
+ "MethodSettings": {
+ "description": "Configures settings for all of the stage's methods.",
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/MethodSetting"
+ }
+ },
+ "MetricsEnabled": {
+ "description": "Indicates whether Amazon CloudWatch metrics are enabled for methods in the stage.",
+ "type": "boolean"
+ },
+ "Tags": {
+ "description": "An array of arbitrary tags (key-value pairs) to associate with the stage.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "ThrottlingBurstLimit": {
+ "description": "The number of burst requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "integer"
+ },
+ "ThrottlingRateLimit": {
+ "description": "The number of steady-state requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "number"
+ },
+ "TracingEnabled": {
+ "description": "Specifies whether active tracing with X-ray is enabled for this stage.",
+ "type": "boolean"
+ },
+ "Variables": {
+ "description": "A map that defines the stage variables. Variable names must consist of alphanumeric characters, and the values must match the following regular expression: [A-Za-z0-9-._~:/?#&=,]+. ",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ },
+ "DeploymentCanarySettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PercentTraffic": {
+ "description": "The percentage (0-100) of traffic diverted to a canary deployment.",
+ "type": "number"
+ },
+ "StageVariableOverrides": {
+ "description": "Stage variables overridden for a canary release deployment, including new stage variables introduced in the canary. These stage variables are represented as a string-to-string map between stage variable names and their values. Duplicates are not allowed.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "UseStageCache": {
+ "description": "Whether the canary deployment uses the stage cache.",
+ "type": "boolean"
+ }
+ }
+ },
+ "AccessLogSetting": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DestinationArn": {
+ "description": "The Amazon Resource Name (ARN) of the CloudWatch Logs log group or Kinesis Data Firehose delivery stream to receive access logs. If you specify a Kinesis Data Firehose delivery stream, the stream name must begin with amazon-apigateway-. ",
+ "type": "string"
+ },
+ "Format": {
+ "description": "A single line format of the access logs of data, as specified by selected $context variables. The format must include at least $context.requestId. ",
+ "type": "string"
+ }
+ }
+ },
+ "CanarySetting": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PercentTraffic": {
+ "description": "The percent (0-100) of traffic diverted to a canary deployment.",
+ "type": "number"
+ },
+ "StageVariableOverrides": {
+ "description": "Stage variables overridden for a canary release deployment, including new stage variables introduced in the canary. These stage variables are represented as a string-to-string map between stage variable names and their values. ",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "UseStageCache": {
+ "description": "Whether the canary deployment uses the stage cache or not.",
+ "type": "boolean"
+ }
+ }
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "description": "The key name of the tag",
+ "type": "string"
+ },
+ "Value": {
+ "description": "The value for the tag",
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "MethodSetting": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CacheDataEncrypted": {
+ "description": "Indicates whether the cached responses are encrypted",
+ "type": "boolean"
+ },
+ "CacheTtlInSeconds": {
+ "description": "The time-to-live (TTL) period, in seconds, that specifies how long API Gateway caches responses. ",
+ "type": "integer"
+ },
+ "CachingEnabled": {
+ "description": "Indicates whether responses are cached and returned for requests. You must enable a cache cluster on the stage to cache responses.",
+ "type": "boolean"
+ },
+ "DataTraceEnabled": {
+ "description": "Indicates whether data trace logging is enabled for methods in the stage. API Gateway pushes these logs to Amazon CloudWatch Logs. ",
+ "type": "boolean"
+ },
+ "HttpMethod": {
+ "description": "The HTTP method.",
+ "type": "string"
+ },
+ "LoggingLevel": {
+ "description": "The logging level for this method. For valid values, see the loggingLevel property of the Stage resource in the Amazon API Gateway API Reference. ",
+ "type": "string"
+ },
+ "MetricsEnabled": {
+ "description": "Indicates whether Amazon CloudWatch metrics are enabled for methods in the stage.",
+ "type": "boolean"
+ },
+ "ResourcePath": {
+ "description": "The resource path for this method. Forward slashes (/) are encoded as ~1 and the initial slash must include a forward slash. ",
+ "type": "string"
+ },
+ "ThrottlingBurstLimit": {
+ "description": "The number of burst requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "integer"
+ },
+ "ThrottlingRateLimit": {
+ "description": "The number of steady-state requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "number"
+ }
+ }
+ }
+ },
+ "taggable": true,
+ "required": [
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/DeploymentCanarySettings",
+ "/properties/RestApiId"
+ ],
+ "primaryIdentifier": [
+ "/properties/DeploymentId",
+ "/properties/RestApiId"
+ ],
+ "readOnlyProperties": [
+ "/properties/DeploymentId"
+ ],
+ "writeOnlyProperties": [
+ "/properties/StageName",
+ "/properties/StageDescription",
+ "/properties/DeploymentCanarySettings"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:PATCH",
+ "apigateway:GET",
+ "apigateway:PUT"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "RestApiId": {
+ "$ref": "resource-schema.json#/properties/RestApiId"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment_plugin.py
new file mode 100644
index 0000000000000..80ff9801a1ed5
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_deployment_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayDeploymentProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Deployment"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_deployment import (
+ ApiGatewayDeploymentProvider,
+ )
+
+ self.factory = ApiGatewayDeploymentProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.py
new file mode 100644
index 0000000000000..37a37946f91ce
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.py
@@ -0,0 +1,158 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+
+
+class ApiGatewayDomainNameProperties(TypedDict):
+ CertificateArn: Optional[str]
+ DistributionDomainName: Optional[str]
+ DistributionHostedZoneId: Optional[str]
+ DomainName: Optional[str]
+ EndpointConfiguration: Optional[EndpointConfiguration]
+ MutualTlsAuthentication: Optional[MutualTlsAuthentication]
+ OwnershipVerificationCertificateArn: Optional[str]
+ RegionalCertificateArn: Optional[str]
+ RegionalDomainName: Optional[str]
+ RegionalHostedZoneId: Optional[str]
+ SecurityPolicy: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class EndpointConfiguration(TypedDict):
+ Types: Optional[list[str]]
+
+
+class MutualTlsAuthentication(TypedDict):
+ TruststoreUri: Optional[str]
+ TruststoreVersion: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayDomainNameProvider(ResourceProvider[ApiGatewayDomainNameProperties]):
+ TYPE = "AWS::ApiGateway::DomainName" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayDomainNameProperties],
+ ) -> ProgressEvent[ApiGatewayDomainNameProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/DomainName
+
+ Create-only properties:
+ - /properties/DomainName
+
+ Read-only properties:
+ - /properties/RegionalHostedZoneId
+ - /properties/DistributionDomainName
+ - /properties/RegionalDomainName
+ - /properties/DistributionHostedZoneId
+
+ IAM permissions required:
+ - apigateway:*
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
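+ # Convert the PascalCase CloudFormation properties into the camelCase parameters
+ # expected by the API, keeping only the attributes supported by create_domain_name.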
+ params = keys_to_lower(model.copy())
+ param_names = [
+ "certificateArn",
+ "domainName",
+ "endpointConfiguration",
+ "mutualTlsAuthentication",
+ "ownershipVerificationCertificateArn",
+ "regionalCertificateArn",
+ "securityPolicy",
+ ]
+ params = util.select_attributes(params, param_names)
+ if model.get("Tags"):
+ params["tags"] = {tag["key"]: tag["value"] for tag in model["Tags"]}
+
+ result = apigw.create_domain_name(**params)
+
+ hosted_zones = request.aws_client_factory.route53.list_hosted_zones()
+ """
+ The hardcoded value is the only one that should be returned but due limitations it is not possible to
+ use it.
+ """
+ if hosted_zones["HostedZones"]:
+ model["DistributionHostedZoneId"] = hosted_zones["HostedZones"][0]["Id"]
+ else:
+ model["DistributionHostedZoneId"] = "Z2FDTNDATAQYW2"
+
+ model["DistributionDomainName"] = result.get("distributionDomainName") or result.get(
+ "domainName"
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayDomainNameProperties],
+ ) -> ProgressEvent[ApiGatewayDomainNameProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:*
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayDomainNameProperties],
+ ) -> ProgressEvent[ApiGatewayDomainNameProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:*
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ apigw.delete_domain_name(domainName=model["DomainName"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayDomainNameProperties],
+ ) -> ProgressEvent[ApiGatewayDomainNameProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:*
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.schema.json
new file mode 100644
index 0000000000000..c0b50b24f2c33
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname.schema.json
@@ -0,0 +1,124 @@
+{
+ "typeName": "AWS::ApiGateway::DomainName",
+ "description": "Resource Type definition for AWS::ApiGateway::DomainName.",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "definitions": {
+ "EndpointConfiguration": {
+ "type": "object",
+ "properties": {
+ "Types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false
+ },
+ "MutualTlsAuthentication": {
+ "type": "object",
+ "properties": {
+ "TruststoreUri": {
+ "type": "string"
+ },
+ "TruststoreVersion": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false
+ },
+ "Tag": {
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false
+ }
+ },
+ "properties": {
+ "DomainName": {
+ "type": "string"
+ },
+ "DistributionDomainName": {
+ "type": "string"
+ },
+ "DistributionHostedZoneId": {
+ "type": "string"
+ },
+ "EndpointConfiguration": {
+ "$ref": "#/definitions/EndpointConfiguration"
+ },
+ "MutualTlsAuthentication": {
+ "$ref": "#/definitions/MutualTlsAuthentication"
+ },
+ "RegionalDomainName": {
+ "type": "string"
+ },
+ "RegionalHostedZoneId": {
+ "type": "string"
+ },
+ "CertificateArn": {
+ "type": "string"
+ },
+ "RegionalCertificateArn": {
+ "type": "string"
+ },
+ "OwnershipVerificationCertificateArn": {
+ "type": "string"
+ },
+ "SecurityPolicy": {
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "primaryIdentifier": [
+ "/properties/DomainName"
+ ],
+ "createOnlyProperties": [
+ "/properties/DomainName"
+ ],
+ "readOnlyProperties": [
+ "/properties/RegionalHostedZoneId",
+ "/properties/DistributionDomainName",
+ "/properties/RegionalDomainName",
+ "/properties/DistributionHostedZoneId"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:*"
+ ]
+ },
+ "read": {
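+ # basePath and stage are optional; include them only when the template sets them,
+ # so empty values are never sent to the service.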
+ "permissions": [
+ "apigateway:*"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:*"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:*"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "apigateway:*"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname_plugin.py
new file mode 100644
index 0000000000000..49e6db22f12d8
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_domainname_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayDomainNameProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::DomainName"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_domainname import (
+ ApiGatewayDomainNameProvider,
+ )
+
+ self.factory = ApiGatewayDomainNameProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.py
new file mode 100644
index 0000000000000..bb52d43256e7b
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.py
@@ -0,0 +1,122 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+
+
+class ApiGatewayGatewayResponseProperties(TypedDict):
+ ResponseType: Optional[str]
+ RestApiId: Optional[str]
+ Id: Optional[str]
+ ResponseParameters: Optional[dict]
+ ResponseTemplates: Optional[dict]
+ StatusCode: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayGatewayResponseProvider(ResourceProvider[ApiGatewayGatewayResponseProperties]):
+ TYPE = "AWS::ApiGateway::GatewayResponse" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayGatewayResponseProperties],
+ ) -> ProgressEvent[ApiGatewayGatewayResponseProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - ResponseType
+ - RestApiId
+
+ Create-only properties:
+ - /properties/ResponseType
+ - /properties/RestApiId
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - apigateway:PUT
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+ # TODO: validations
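+ # put_gateway_response does not return an identifier, so derive the CloudFormation-level
+ # physical ID from the logical resource ID.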
+ model["Id"] = util.generate_default_name_without_stack(request.logical_resource_id)
+
+ params = util.select_attributes(
+ model,
+ ["RestApiId", "ResponseType", "StatusCode", "ResponseParameters", "ResponseTemplates"],
+ )
+ params = keys_to_lower(params.copy())
+
+ api.put_gateway_response(**params)
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayGatewayResponseProperties],
+ ) -> ProgressEvent[ApiGatewayGatewayResponseProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayGatewayResponseProperties],
+ ) -> ProgressEvent[ApiGatewayGatewayResponseProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ api.delete_gateway_response(
+ restApiId=model["RestApiId"], responseType=model["ResponseType"]
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayGatewayResponseProperties],
+ ) -> ProgressEvent[ApiGatewayGatewayResponseProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:PUT
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.schema.json
new file mode 100644
index 0000000000000..063b2c6c91ca4
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse.schema.json
@@ -0,0 +1,84 @@
+{
+ "typeName": "AWS::ApiGateway::GatewayResponse",
+ "description": "Resource Type definition for AWS::ApiGateway::GatewayResponse",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "description": "A Cloudformation auto generated ID.",
+ "type": "string"
+ },
+ "RestApiId": {
+ "description": "The identifier of the API.",
+ "type": "string"
+ },
+ "ResponseType": {
+ "description": "The type of the Gateway Response.",
+ "type": "string"
+ },
+ "StatusCode": {
+ "description": "The HTTP status code for the response.",
+ "type": "string"
+ },
+ "ResponseParameters": {
+ "description": "The response parameters (paths, query strings, and headers) for the response.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "ResponseTemplates": {
+ "description": "The response templates for the response.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "required": [
+ "ResponseType",
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/ResponseType",
+ "/properties/RestApiId"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "taggable": false,
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:PUT",
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:PUT"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse_plugin.py
new file mode 100644
index 0000000000000..86f43d46cdd21
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_gatewayresponse_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayGatewayResponseProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::GatewayResponse"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_gatewayresponse import (
+ ApiGatewayGatewayResponseProvider,
+ )
+
+ self.factory = ApiGatewayGatewayResponseProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.py
new file mode 100644
index 0000000000000..64598a4463898
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.py
@@ -0,0 +1,234 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayMethodProperties(TypedDict):
+ HttpMethod: Optional[str]
+ ResourceId: Optional[str]
+ RestApiId: Optional[str]
+ ApiKeyRequired: Optional[bool]
+ AuthorizationScopes: Optional[list[str]]
+ AuthorizationType: Optional[str]
+ AuthorizerId: Optional[str]
+ Integration: Optional[Integration]
+ MethodResponses: Optional[list[MethodResponse]]
+ OperationName: Optional[str]
+ RequestModels: Optional[dict]
+ RequestParameters: Optional[dict]
+ RequestValidatorId: Optional[str]
+
+
+class IntegrationResponse(TypedDict):
+ StatusCode: Optional[str]
+ ContentHandling: Optional[str]
+ ResponseParameters: Optional[dict]
+ ResponseTemplates: Optional[dict]
+ SelectionPattern: Optional[str]
+
+
+class Integration(TypedDict):
+ Type: Optional[str]
+ CacheKeyParameters: Optional[list[str]]
+ CacheNamespace: Optional[str]
+ ConnectionId: Optional[str]
+ ConnectionType: Optional[str]
+ ContentHandling: Optional[str]
+ Credentials: Optional[str]
+ IntegrationHttpMethod: Optional[str]
+ IntegrationResponses: Optional[list[IntegrationResponse]]
+ PassthroughBehavior: Optional[str]
+ RequestParameters: Optional[dict]
+ RequestTemplates: Optional[dict]
+ TimeoutInMillis: Optional[int]
+ Uri: Optional[str]
+
+
+class MethodResponse(TypedDict):
+ StatusCode: Optional[str]
+ ResponseModels: Optional[dict]
+ ResponseParameters: Optional[dict]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayMethodProvider(ResourceProvider[ApiGatewayMethodProperties]):
+ TYPE = "AWS::ApiGateway::Method" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayMethodProperties],
+ ) -> ProgressEvent[ApiGatewayMethodProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+ - /properties/ResourceId
+ - /properties/HttpMethod
+
+ Required properties:
+ - RestApiId
+ - ResourceId
+ - HttpMethod
+
+ Create-only properties:
+ - /properties/RestApiId
+ - /properties/ResourceId
+ - /properties/HttpMethod
+
+
+
+ IAM permissions required:
+ - apigateway:PUT
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+ operation_model = apigw.meta.service_model.operation_model
+
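+ # The botocore operation models are used to translate the CloudFormation property names
+ # into the parameter shapes expected by the respective API calls.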
+ apigw.put_method(
+ **util.convert_request_kwargs(model, operation_model("PutMethod").input_shape)
+ )
+
+ # setting up integrations
+ integration = model.get("Integration")
+ if integration:
+ apigw.put_integration(
+ restApiId=model.get("RestApiId"),
+ resourceId=model.get("ResourceId"),
+ httpMethod=model.get("HttpMethod"),
+ **util.convert_request_kwargs(
+ integration, operation_model("PutIntegration").input_shape
+ ),
+ )
+
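+ # Integration responses are created individually once the integration itself exists.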
+ integration_responses = integration.pop("IntegrationResponses", [])
+ for integration_response in integration_responses:
+ apigw.put_integration_response(
+ restApiId=model.get("RestApiId"),
+ resourceId=model.get("ResourceId"),
+ httpMethod=model.get("HttpMethod"),
+ **util.convert_request_kwargs(
+ integration_response, operation_model("PutIntegrationResponse").input_shape
+ ),
+ )
+
+ responses = model.get("MethodResponses", [])
+ for response in responses:
+ apigw.put_method_response(
+ restApiId=model.get("RestApiId"),
+ resourceId=model.get("ResourceId"),
+ httpMethod=model.get("HttpMethod"),
+ **util.convert_request_kwargs(
+ response, operation_model("PutMethodResponse").input_shape
+ ),
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayMethodProperties],
+ ) -> ProgressEvent[ApiGatewayMethodProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayMethodProperties],
+ ) -> ProgressEvent[ApiGatewayMethodProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+
+ # FIXME: we sometimes get warnings when calling this method, probably because
+ # the RestApi or resource has already been deleted
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ try:
+ apigw.delete_method(
+ **util.convert_request_kwargs(
+ model, apigw.meta.service_model.operation_model("DeleteMethod").input_shape
+ )
+ )
+ except apigw.exceptions.NotFoundException:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayMethodProperties],
+ ) -> ProgressEvent[ApiGatewayMethodProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ - apigateway:PUT
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+ operation_model = apigw.meta.service_model.operation_model
+
+ must_params = util.select_attributes(
+ model,
+ [
+ "RestApiId",
+ "ResourceId",
+ "HttpMethod",
+ ],
+ )
+
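+ # Re-put the integration when one is configured; otherwise re-put the method itself
+ # with its authorization type.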
+ if integration := deepcopy(model.get("Integration")):
+ integration.update(must_params)
+ apigw.put_integration(
+ **util.convert_request_kwargs(
+ integration, operation_model("PutIntegration").input_shape
+ )
+ )
+
+ else:
+ must_params.update({"AuthorizationType": model.get("AuthorizationType")})
+ apigw.put_method(
+ **util.convert_request_kwargs(must_params, operation_model("PutMethod").input_shape)
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.schema.json
new file mode 100644
index 0000000000000..1b64f208e9c6d
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method.schema.json
@@ -0,0 +1,318 @@
+{
+ "typeName": "AWS::ApiGateway::Method",
+ "description": "Resource Type definition for AWS::ApiGateway::Method",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway.git",
+ "definitions": {
+ "Integration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CacheKeyParameters": {
+ "description": "A list of request parameters whose values API Gateway caches.",
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "CacheNamespace": {
+ "description": "An API-specific tag group of related cached parameters.",
+ "type": "string"
+ },
+ "ConnectionId": {
+ "description": "The ID of the VpcLink used for the integration when connectionType=VPC_LINK, otherwise undefined.",
+ "type": "string"
+ },
+ "ConnectionType": {
+ "description": "The type of the network connection to the integration endpoint.",
+ "type": "string",
+ "enum": [
+ "INTERNET",
+ "VPC_LINK"
+ ]
+ },
+ "ContentHandling": {
+ "description": "Specifies how to handle request payload content type conversions.",
+ "type": "string",
+ "enum": [
+ "CONVERT_TO_BINARY",
+ "CONVERT_TO_TEXT"
+ ]
+ },
+ "Credentials": {
+ "description": "The credentials that are required for the integration.",
+ "type": "string"
+ },
+ "IntegrationHttpMethod": {
+ "description": "The integration's HTTP method type.",
+ "type": "string"
+ },
+ "IntegrationResponses": {
+ "description": "The response that API Gateway provides after a method's backend completes processing a request.",
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/IntegrationResponse"
+ }
+ },
+ "PassthroughBehavior": {
+ "description": "Indicates when API Gateway passes requests to the targeted backend.",
+ "type": "string",
+ "enum": [
+ "WHEN_NO_MATCH",
+ "WHEN_NO_TEMPLATES",
+ "NEVER"
+ ]
+ },
+ "RequestParameters": {
+ "description": "The request parameters that API Gateway sends with the backend request.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "RequestTemplates": {
+ "description": "A map of Apache Velocity templates that are applied on the request payload.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "TimeoutInMillis": {
+ "description": "Custom timeout between 50 and 29,000 milliseconds.",
+ "type": "integer",
+ "minimum": 50,
+ "maximum": 29000
+ },
+ "Type": {
+ "description": "The type of backend that your method is running.",
+ "type": "string",
+ "enum": [
+ "AWS",
+ "AWS_PROXY",
+ "HTTP",
+ "HTTP_PROXY",
+ "MOCK"
+ ]
+ },
+ "Uri": {
+ "description": "The Uniform Resource Identifier (URI) for the integration.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "Type"
+ ]
+ },
+ "MethodResponse": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ResponseModels": {
+ "description": "The resources used for the response's content type. Specify response models as key-value pairs (string-to-string maps), with a content type as the key and a Model resource name as the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "ResponseParameters": {
+ "description": "Response parameters that API Gateway sends to the client that called a method. Specify response parameters as key-value pairs (string-to-Boolean maps), with a destination as the key and a Boolean as the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "boolean"
+ }
+ }
+ },
+ "StatusCode": {
+ "description": "The method response's status code, which you map to an IntegrationResponse.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "StatusCode"
+ ]
+ },
+ "IntegrationResponse": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ContentHandling": {
+ "description": "Specifies how to handle request payload content type conversions.",
+ "type": "string",
+ "enum": [
+ "CONVERT_TO_BINARY",
+ "CONVERT_TO_TEXT"
+ ]
+ },
+ "ResponseParameters": {
+ "description": "The response parameters from the backend response that API Gateway sends to the method response.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "ResponseTemplates": {
+ "description": "The templates that are used to transform the integration response body. Specify templates as key-value pairs (string-to-string mappings), with a content type as the key and a template as the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "SelectionPattern": {
+ "description": "A regular expression that specifies which error strings or status codes from the backend map to the integration response.",
+ "type": "string"
+ },
+ "StatusCode": {
+ "description": "The status code that API Gateway uses to map the integration response to a MethodResponse status code.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "StatusCode"
+ ]
+ }
+ },
+ "properties": {
+ "ApiKeyRequired": {
+ "description": "Indicates whether the method requires clients to submit a valid API key.",
+ "type": "boolean"
+ },
+ "AuthorizationScopes": {
+ "description": "A list of authorization scopes configured on the method.",
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "AuthorizationType": {
+ "description": "The method's authorization type.",
+ "type": "string",
+ "enum": [
+ "NONE",
+ "AWS_IAM",
+ "CUSTOM",
+ "COGNITO_USER_POOLS"
+ ]
+ },
+ "AuthorizerId": {
+ "description": "The identifier of the authorizer to use on this method.",
+ "type": "string"
+ },
+ "HttpMethod": {
+ "description": "The backend system that the method calls when it receives a request.",
+ "type": "string"
+ },
+ "Integration": {
+ "description": "The backend system that the method calls when it receives a request.",
+ "$ref": "#/definitions/Integration"
+ },
+ "MethodResponses": {
+ "description": "The responses that can be sent to the client who calls the method.",
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/MethodResponse"
+ }
+ },
+ "OperationName": {
+ "description": "A friendly operation name for the method.",
+ "type": "string"
+ },
+ "RequestModels": {
+ "description": "The resources that are used for the request's content type. Specify request models as key-value pairs (string-to-string mapping), with a content type as the key and a Model resource name as the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "RequestParameters": {
+ "description": "The request parameters that API Gateway accepts. Specify request parameters as key-value pairs (string-to-Boolean mapping), with a source as the key and a Boolean as the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "boolean"
+ }
+ }
+ },
+ "RequestValidatorId": {
+ "description": "The ID of the associated request validator.",
+ "type": "string"
+ },
+ "ResourceId": {
+ "description": "The ID of an API Gateway resource.",
+ "type": "string"
+ },
+ "RestApiId": {
+ "description": "The ID of the RestApi resource in which API Gateway creates the method.",
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "RestApiId",
+ "ResourceId",
+ "HttpMethod"
+ ],
+ "primaryIdentifier": [
+ "/properties/RestApiId",
+ "/properties/ResourceId",
+ "/properties/HttpMethod"
+ ],
+ "createOnlyProperties": [
+ "/properties/RestApiId",
+ "/properties/ResourceId",
+ "/properties/HttpMethod"
+ ],
+ "tagging": {
+ "taggable": false,
+ "tagOnCreate": false,
+ "tagUpdatable": false,
+ "cloudFormationSystemTags": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:PUT",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE",
+ "apigateway:PUT"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method_plugin.py
new file mode 100644
index 0000000000000..34e0cec7971a9
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_method_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayMethodProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Method"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_method import (
+ ApiGatewayMethodProvider,
+ )
+
+ self.factory = ApiGatewayMethodProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.py
new file mode 100644
index 0000000000000..07883e62983ca
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.py
@@ -0,0 +1,134 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayModelProperties(TypedDict):
+ RestApiId: Optional[str]
+ ContentType: Optional[str]
+ Description: Optional[str]
+ Name: Optional[str]
+ Schema: Optional[dict | str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayModelProvider(ResourceProvider[ApiGatewayModelProperties]):
+ TYPE = "AWS::ApiGateway::Model" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayModelProperties],
+ ) -> ProgressEvent[ApiGatewayModelProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+ - /properties/Name
+
+ Required properties:
+ - RestApiId
+
+ Create-only properties:
+ - /properties/ContentType
+ - /properties/Name
+ - /properties/RestApiId
+
+
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ if not model.get("Name"):
+ model["Name"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+
+ if not model.get("ContentType"):
+ model["ContentType"] = "application/json"
+
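+ # create_model expects the schema as a JSON string, so serialize whatever the template
+ # provided (dict or string).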
+ schema = json.dumps(model.get("Schema", {}))
+
+ apigw.create_model(
+ restApiId=model["RestApiId"],
+ name=model["Name"],
+ contentType=model["ContentType"],
+ schema=schema,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayModelProperties],
+ ) -> ProgressEvent[ApiGatewayModelProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayModelProperties],
+ ) -> ProgressEvent[ApiGatewayModelProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+ try:
+ apigw.delete_model(modelName=model["Name"], restApiId=model["RestApiId"])
+ except apigw.exceptions.NotFoundException:
+ # We use try/except because CFN does not currently resolve dependencies between resources
+ # properly, so this resource may already have been removed if its parent was deleted first
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayModelProperties],
+ ) -> ProgressEvent[ApiGatewayModelProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:PATCH
+ - apigateway:GET
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.schema.json
new file mode 100644
index 0000000000000..7196fd5cc44b0
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model.schema.json
@@ -0,0 +1,83 @@
+{
+ "typeName": "AWS::ApiGateway::Model",
+ "description": "Resource Type definition for AWS::ApiGateway::Model",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "ContentType": {
+ "type": "string",
+ "description": "The content type for the model."
+ },
+ "Description": {
+ "type": "string",
+ "description": "A description that identifies this model."
+ },
+ "Name": {
+ "type": "string",
+ "description": "A name for the model. If you don't specify a name, AWS CloudFormation generates a unique physical ID and uses that ID for the model name."
+ },
+ "RestApiId": {
+ "type": "string",
+ "description": "The ID of a REST API with which to associate this model."
+ },
+ "Schema": {
+ "description": "The schema to use to transform data to one or more output formats. Specify null ({}) if you don't want to specify a schema.",
+ "type": [
+ "object",
+ "string"
+ ]
+ }
+ },
+ "required": [
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/ContentType",
+ "/properties/Name",
+ "/properties/RestApiId"
+ ],
+ "primaryIdentifier": [
+ "/properties/RestApiId",
+ "/properties/Name"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:PATCH",
+ "apigateway:GET"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "RestApiId": {
+ "$ref": "resource-schema.json#/properties/RestApiId"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model_plugin.py
new file mode 100644
index 0000000000000..d1bd727b602e5
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_model_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayModelProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Model"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_model import (
+ ApiGatewayModelProvider,
+ )
+
+ self.factory = ApiGatewayModelProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.py
new file mode 100644
index 0000000000000..55d2a3bc4964e
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.py
@@ -0,0 +1,125 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayRequestValidatorProperties(TypedDict):
+ RestApiId: Optional[str]
+ Name: Optional[str]
+ RequestValidatorId: Optional[str]
+ ValidateRequestBody: Optional[bool]
+ ValidateRequestParameters: Optional[bool]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayRequestValidatorProvider(ResourceProvider[ApiGatewayRequestValidatorProperties]):
+ TYPE = "AWS::ApiGateway::RequestValidator" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayRequestValidatorProperties],
+ ) -> ProgressEvent[ApiGatewayRequestValidatorProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+ - /properties/RequestValidatorId
+
+ Required properties:
+ - RestApiId
+
+ Create-only properties:
+ - /properties/Name
+ - /properties/RestApiId
+
+ Read-only properties:
+ - /properties/RequestValidatorId
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ if not model.get("Name"):
+ model["Name"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+ response = api.create_request_validator(
+ name=model["Name"],
+ restApiId=model["RestApiId"],
+ validateRequestBody=model.get("ValidateRequestBody", False),
+ validateRequestParameters=model.get("ValidateRequestParameters", False),
+ )
+ model["RequestValidatorId"] = response["id"]
+ # FIXME: an error occurs when other resources try to reference this one:
+ # "An error occurred (BadRequestException) when calling the PutMethod operation:
+ # Invalid Request Validator identifier specified"
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayRequestValidatorProperties],
+ ) -> ProgressEvent[ApiGatewayRequestValidatorProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayRequestValidatorProperties],
+ ) -> ProgressEvent[ApiGatewayRequestValidatorProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ api.delete_request_validator(
+ restApiId=model["RestApiId"], requestValidatorId=model["RequestValidatorId"]
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayRequestValidatorProperties],
+ ) -> ProgressEvent[ApiGatewayRequestValidatorProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:PATCH
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.schema.json
new file mode 100644
index 0000000000000..39d00e7be7d6d
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator.schema.json
@@ -0,0 +1,80 @@
+{
+ "typeName": "AWS::ApiGateway::RequestValidator",
+ "description": "Resource Type definition for AWS::ApiGateway::RequestValidator",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "RequestValidatorId": {
+ "description": "ID of the request validator.",
+ "type": "string"
+ },
+ "Name": {
+ "description": "Name of the request validator.",
+ "type": "string"
+ },
+ "RestApiId": {
+ "description": "The identifier of the targeted API entity.",
+ "type": "string"
+ },
+ "ValidateRequestBody": {
+ "description": "Indicates whether to validate the request body according to the configured schema for the targeted API and method. ",
+ "type": "boolean"
+ },
+ "ValidateRequestParameters": {
+ "description": "Indicates whether to validate request parameters.",
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/Name",
+ "/properties/RestApiId"
+ ],
+ "readOnlyProperties": [
+ "/properties/RequestValidatorId"
+ ],
+ "primaryIdentifier": [
+ "/properties/RestApiId",
+ "/properties/RequestValidatorId"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:PATCH"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "RestApiId": {
+ "$ref": "resource-schema.json#/properties/RestApiId"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator_plugin.py
new file mode 100644
index 0000000000000..41175341a69de
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_requestvalidator_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayRequestValidatorProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::RequestValidator"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_requestvalidator import (
+ ApiGatewayRequestValidatorProvider,
+ )
+
+ self.factory = ApiGatewayRequestValidatorProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.py
new file mode 100644
index 0000000000000..89b868306e68d
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.py
@@ -0,0 +1,168 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+from botocore.exceptions import ClientError
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.aws.api.cloudcontrol import InvalidRequestException, ResourceNotFoundException
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class ApiGatewayResourceProperties(TypedDict):
+ ParentId: Optional[str]
+ PathPart: Optional[str]
+ RestApiId: Optional[str]
+ ResourceId: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayResourceProvider(ResourceProvider[ApiGatewayResourceProperties]):
+ TYPE = "AWS::ApiGateway::Resource" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayResourceProperties],
+ ) -> ProgressEvent[ApiGatewayResourceProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+ - /properties/ResourceId
+
+ Required properties:
+ - ParentId
+ - PathPart
+ - RestApiId
+
+ Create-only properties:
+ - /properties/PathPart
+ - /properties/ParentId
+ - /properties/RestApiId
+
+ Read-only properties:
+ - /properties/ResourceId
+
+ IAM permissions required:
+ - apigateway:POST
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ params = {
+ "restApiId": model.get("RestApiId"),
+ "pathPart": model.get("PathPart"),
+ "parentId": model.get("ParentId"),
+ }
+ if not params.get("parentId"):
+ # get root resource id
+ resources = apigw.get_resources(restApiId=params["restApiId"])["items"]
+ root_resource = ([r for r in resources if r["path"] == "/"] or [None])[0]
+ if not root_resource:
+ raise Exception(
+ "Unable to find root resource for REST API %s" % params["restApiId"]
+ )
+ params["parentId"] = root_resource["id"]
+ response = apigw.create_resource(**params)
+
+ model["ResourceId"] = response["id"]
+ model["ParentId"] = response["parentId"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayResourceProperties],
+ ) -> ProgressEvent[ApiGatewayResourceProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[ApiGatewayResourceProperties],
+ ) -> ProgressEvent[ApiGatewayResourceProperties]:
+ if "RestApiId" not in request.desired_state:
+ # TODO: parity
+ raise InvalidRequestException(
+ f"Missing or invalid ResourceModel property in {self.TYPE} list handler request input: 'RestApiId'"
+ )
+
+ rest_api_id = request.desired_state["RestApiId"]
+ try:
+ resources = request.aws_client_factory.apigateway.get_resources(restApiId=rest_api_id)[
+ "items"
+ ]
+ except ClientError as exc:
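+ # map botocore's NotFoundException to the Cloud Control ResourceNotFoundException for parity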
+ if exc.response.get("Error", {}).get("Code", {}) == "NotFoundException":
+ raise ResourceNotFoundException(f"Invalid API identifier specified: {rest_api_id}")
+ raise
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ ApiGatewayResourceProperties(
+ RestApiId=rest_api_id,
+ ResourceId=resource["id"],
+ ParentId=resource.get("parentId"),
+ PathPart=resource.get("path"),
+ )
+ for resource in resources
+ ],
+ )
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayResourceProperties],
+ ) -> ProgressEvent[ApiGatewayResourceProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ try:
+ apigw.delete_resource(restApiId=model["RestApiId"], resourceId=model["ResourceId"])
+ except apigw.exceptions.NotFoundException:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayResourceProperties],
+ ) -> ProgressEvent[ApiGatewayResourceProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:PATCH
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.schema.json
new file mode 100644
index 0000000000000..7eaa8175b1827
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource.schema.json
@@ -0,0 +1,80 @@
+{
+ "typeName": "AWS::ApiGateway::Resource",
+ "description": "Resource Type definition for AWS::ApiGateway::Resource",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "ResourceId": {
+ "description": "A unique primary identifier for a Resource",
+ "type": "string"
+ },
+ "RestApiId": {
+ "description": "The ID of the RestApi resource in which you want to create this resource..",
+ "type": "string"
+ },
+ "ParentId": {
+ "description": "The parent resource's identifier.",
+ "type": "string"
+ },
+ "PathPart": {
+ "description": "The last path segment for this resource.",
+ "type": "string"
+ }
+ },
+ "taggable": false,
+ "required": [
+ "ParentId",
+ "PathPart",
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/PathPart",
+ "/properties/ParentId",
+ "/properties/RestApiId"
+ ],
+ "primaryIdentifier": [
+ "/properties/RestApiId",
+ "/properties/ResourceId"
+ ],
+ "readOnlyProperties": [
+ "/properties/ResourceId"
+ ],
+ "handlers": {
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "create": {
+ "permissions": [
+ "apigateway:POST"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:PATCH"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "RestApiId": {
+ "$ref": "resource-schema.json#/properties/RestApiId"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource_plugin.py
new file mode 100644
index 0000000000000..f7ece7204435d
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_resource_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayResourceProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Resource"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_resource import (
+ ApiGatewayResourceProvider,
+ )
+
+ self.factory = ApiGatewayResourceProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.py
new file mode 100644
index 0000000000000..c90e2b36f328b
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.py
@@ -0,0 +1,245 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+from localstack.utils.strings import to_bytes
+
+
+class ApiGatewayRestApiProperties(TypedDict):
+ ApiKeySourceType: Optional[str]
+ BinaryMediaTypes: Optional[list[str]]
+ Body: Optional[dict | str]
+ BodyS3Location: Optional[S3Location]
+ CloneFrom: Optional[str]
+ Description: Optional[str]
+ DisableExecuteApiEndpoint: Optional[bool]
+ EndpointConfiguration: Optional[EndpointConfiguration]
+ FailOnWarnings: Optional[bool]
+ MinimumCompressionSize: Optional[int]
+ Mode: Optional[str]
+ Name: Optional[str]
+ Parameters: Optional[dict | str]
+ Policy: Optional[dict | str]
+ RestApiId: Optional[str]
+ RootResourceId: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class S3Location(TypedDict):
+ Bucket: Optional[str]
+ ETag: Optional[str]
+ Key: Optional[str]
+ Version: Optional[str]
+
+
+class EndpointConfiguration(TypedDict):
+ Types: Optional[list[str]]
+ VpcEndpointIds: Optional[list[str]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayRestApiProvider(ResourceProvider[ApiGatewayRestApiProperties]):
+ TYPE = "AWS::ApiGateway::RestApi" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayRestApiProperties],
+ ) -> ProgressEvent[ApiGatewayRestApiProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+
+
+ Read-only properties:
+ - /properties/RestApiId
+ - /properties/RootResourceId
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:POST
+ - apigateway:UpdateRestApiPolicy
+ - s3:GetObject
+ - iam:PassRole
+
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ # FIXME: a default name should only be generated when Body or BodyS3Location is set;
+ # otherwise the deployment should fail if no Name is provided
+ api_name = model.get("Name")
+ if not api_name:
+ model["Name"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+ params = util.select_attributes(
+ model,
+ [
+ "Name",
+ "Description",
+ "Version",
+ "CloneFrom",
+ "BinaryMediaTypes",
+ "MinimumCompressionSize",
+ "ApiKeySource",
+ "EndpointConfiguration",
+ "Policy",
+ "Tags",
+ "DisableExecuteApiEndpoint",
+ ],
+ )
+ params = keys_to_lower(params, skip_children_of=["policy"])
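+ # convert the CFN tag list of {Key, Value} pairs into the flat dict expected by create_rest_api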
+ params["tags"] = {tag["key"]: tag["value"] for tag in params.get("tags", [])}
+
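+ # attach the aws:cloudformation:* system tags, mirroring what CloudFormation does on AWS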
+ cfn_client = request.aws_client_factory.cloudformation
+ stack_id = cfn_client.describe_stacks(StackName=request.stack_name)["Stacks"][0]["StackId"]
+ params["tags"].update(
+ {
+ "aws:cloudformation:logical-id": request.logical_resource_id,
+ "aws:cloudformation:stack-name": request.stack_name,
+ "aws:cloudformation:stack-id": stack_id,
+ }
+ )
+ if isinstance(params.get("policy"), dict):
+ params["policy"] = json.dumps(params["policy"])
+
+ result = api.create_rest_api(**params)
+ model["RestApiId"] = result["id"]
+
+ body = model.get("Body")
+ s3_body_location = model.get("BodyS3Location")
+ if body or s3_body_location:
+ # the default behavior for imports via CFn is basepath=ignore (validated against AWS)
+ import_parameters = model.get("Parameters", {})
+ import_parameters.setdefault("basepath", "ignore")
+
+ if body:
+ body = json.dumps(body) if isinstance(body, dict) else body
+ else:
+ get_obj_kwargs = {}
+ if version_id := s3_body_location.get("Version"):
+ get_obj_kwargs["VersionId"] = version_id
+
+ # what is the approach when the client call fails? Do we bubble the error up?
+ s3_client = request.aws_client_factory.s3
+ get_obj_req = s3_client.get_object(
+ Bucket=s3_body_location.get("Bucket"),
+ Key=s3_body_location.get("Key"),
+ **get_obj_kwargs,
+ )
+ if etag := s3_body_location.get("ETag"):
+ if etag != get_obj_req["ETag"]:
+ # TODO: validate the exception message
+ raise Exception(
+ "The ETag provided for the S3BodyLocation does not match the S3 Object"
+ )
+ body = get_obj_req["Body"].read()
+
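+ # Mode and FailOnWarnings are optional; only forward them to put_rest_api when set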
+ put_kwargs = {}
+ if import_mode := model.get("Mode"):
+ put_kwargs["mode"] = import_mode
+ if fail_on_warnings_mode := model.get("FailOnWarnings"):
+ put_kwargs["failOnWarnings"] = fail_on_warnings_mode
+
+ api.put_rest_api(
+ restApiId=result["id"],
+ body=to_bytes(body),
+ parameters=import_parameters,
+ **put_kwargs,
+ )
+
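+ # the REST API is created with an implicit root ("/") resource; surface its id as RootResourceId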
+ resources = api.get_resources(restApiId=result["id"])["items"]
+ for res in resources:
+ if res["path"] == "/" and not res.get("parentId"):
+ model["RootResourceId"] = res["id"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayRestApiProperties],
+ ) -> ProgressEvent[ApiGatewayRestApiProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[ApiGatewayRestApiProperties],
+ ) -> ProgressEvent[ApiGatewayRestApiProperties]:
+ # TODO: pagination
+ resources = request.aws_client_factory.apigateway.get_rest_apis()["items"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ ApiGatewayRestApiProperties(RestApiId=resource["id"], Name=resource["name"])
+ for resource in resources
+ ],
+ )
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayRestApiProperties],
+ ) -> ProgressEvent[ApiGatewayRestApiProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ api = request.aws_client_factory.apigateway
+
+ api.delete_rest_api(restApiId=model["RestApiId"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayRestApiProperties],
+ ) -> ProgressEvent[ApiGatewayRestApiProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ - apigateway:PATCH
+ - apigateway:PUT
+ - apigateway:UpdateRestApiPolicy
+ - s3:GetObject
+ - iam:PassRole
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.schema.json
new file mode 100644
index 0000000000000..73e6f5dda9447
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi.schema.json
@@ -0,0 +1,197 @@
+{
+ "typeName": "AWS::ApiGateway::RestApi",
+ "description": "Resource Type definition for AWS::ApiGateway::RestApi.",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "additionalProperties": false,
+ "definitions": {
+ "EndpointConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Types": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "VpcEndpointIds": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ]
+ },
+ "S3Location": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Bucket": {
+ "type": "string"
+ },
+ "ETag": {
+ "type": "string"
+ },
+ "Version": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "properties": {
+ "RestApiId": {
+ "type": "string"
+ },
+ "RootResourceId": {
+ "type": "string"
+ },
+ "ApiKeySourceType": {
+ "type": "string"
+ },
+ "BinaryMediaTypes": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Body": {
+ "type": [
+ "object",
+ "string"
+ ]
+ },
+ "BodyS3Location": {
+ "$ref": "#/definitions/S3Location"
+ },
+ "CloneFrom": {
+ "type": "string"
+ },
+ "EndpointConfiguration": {
+ "$ref": "#/definitions/EndpointConfiguration"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "DisableExecuteApiEndpoint": {
+ "type": "boolean"
+ },
+ "FailOnWarnings": {
+ "type": "boolean"
+ },
+ "Name": {
+ "type": "string"
+ },
+ "MinimumCompressionSize": {
+ "type": "integer"
+ },
+ "Mode": {
+ "type": "string"
+ },
+ "Policy": {
+ "type": [
+ "object",
+ "string"
+ ]
+ },
+ "Parameters": {
+ "type": [
+ "object",
+ "string"
+ ],
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "primaryIdentifier": [
+ "/properties/RestApiId"
+ ],
+ "readOnlyProperties": [
+ "/properties/RestApiId",
+ "/properties/RootResourceId"
+ ],
+ "writeOnlyProperties": [
+ "/properties/Body",
+ "/properties/BodyS3Location",
+ "/properties/CloneFrom",
+ "/properties/FailOnWarnings",
+ "/properties/Mode",
+ "/properties/Parameters"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:POST",
+ "apigateway:UpdateRestApiPolicy",
+ "s3:GetObject",
+ "iam:PassRole"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE",
+ "apigateway:PATCH",
+ "apigateway:PUT",
+ "apigateway:UpdateRestApiPolicy",
+ "s3:GetObject",
+ "iam:PassRole"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi_plugin.py
new file mode 100644
index 0000000000000..e53c4a4d8205f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_restapi_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayRestApiProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::RestApi"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_restapi import (
+ ApiGatewayRestApiProvider,
+ )
+
+ self.factory = ApiGatewayRestApiProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.py
new file mode 100644
index 0000000000000..b2b98bc715455
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.py
@@ -0,0 +1,183 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import copy
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+
+
+class ApiGatewayStageProperties(TypedDict):
+ RestApiId: Optional[str]
+ AccessLogSetting: Optional[AccessLogSetting]
+ CacheClusterEnabled: Optional[bool]
+ CacheClusterSize: Optional[str]
+ CanarySetting: Optional[CanarySetting]
+ ClientCertificateId: Optional[str]
+ DeploymentId: Optional[str]
+ Description: Optional[str]
+ DocumentationVersion: Optional[str]
+ MethodSettings: Optional[list[MethodSetting]]
+ StageName: Optional[str]
+ Tags: Optional[list[Tag]]
+ TracingEnabled: Optional[bool]
+ Variables: Optional[dict]
+
+
+class AccessLogSetting(TypedDict):
+ DestinationArn: Optional[str]
+ Format: Optional[str]
+
+
+class CanarySetting(TypedDict):
+ DeploymentId: Optional[str]
+ PercentTraffic: Optional[float]
+ StageVariableOverrides: Optional[dict]
+ UseStageCache: Optional[bool]
+
+
+class MethodSetting(TypedDict):
+ CacheDataEncrypted: Optional[bool]
+ CacheTtlInSeconds: Optional[int]
+ CachingEnabled: Optional[bool]
+ DataTraceEnabled: Optional[bool]
+ HttpMethod: Optional[str]
+ LoggingLevel: Optional[str]
+ MetricsEnabled: Optional[bool]
+ ResourcePath: Optional[str]
+ ThrottlingBurstLimit: Optional[int]
+ ThrottlingRateLimit: Optional[float]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayStageProvider(ResourceProvider[ApiGatewayStageProperties]):
+ TYPE = "AWS::ApiGateway::Stage" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayStageProperties],
+ ) -> ProgressEvent[ApiGatewayStageProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RestApiId
+ - /properties/StageName
+
+ Required properties:
+ - RestApiId
+
+ Create-only properties:
+ - /properties/RestApiId
+ - /properties/StageName
+
+
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:PATCH
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ stage_name = model.get("StageName", "default")
+ stage_variables = model.get("Variables")
+ # we need to deep copy, as several fields are nested dicts and arrays
+ params = keys_to_lower(copy.deepcopy(model))
+ # TODO: add methodSettings
+ # TODO: add custom CfN tags
+ param_names = [
+ "restApiId",
+ "deploymentId",
+ "description",
+ "cacheClusterEnabled",
+ "cacheClusterSize",
+ "documentationVersion",
+ "canarySettings",
+ "tracingEnabled",
+ "tags",
+ ]
+ params = util.select_attributes(params, param_names)
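+ # after keys_to_lower, tag entries use lowercase "key"/"value" keys; flatten them into a dict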
+ params["tags"] = {t["key"]: t["value"] for t in params.get("tags", [])}
+ params["stageName"] = stage_name
+ if stage_variables:
+ params["variables"] = stage_variables
+
+ result = apigw.create_stage(**params)
+ model["StageName"] = result["stageName"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayStageProperties],
+ ) -> ProgressEvent[ApiGatewayStageProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayStageProperties],
+ ) -> ProgressEvent[ApiGatewayStageProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+ try:
+ # check whether the stage has already been deleted before calling delete_stage
+ apigw.get_stage(restApiId=model["RestApiId"], stageName=model["StageName"])
+ apigw.delete_stage(restApiId=model["RestApiId"], stageName=model["StageName"])
+ except apigw.exceptions.NotFoundException:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayStageProperties],
+ ) -> ProgressEvent[ApiGatewayStageProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:PATCH
+ - apigateway:PUT
+ - apigateway:DELETE
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.schema.json
new file mode 100644
index 0000000000000..fe67c2c0c626f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage.schema.json
@@ -0,0 +1,260 @@
+{
+ "typeName": "AWS::ApiGateway::Stage",
+ "description": "Resource Type definition for AWS::ApiGateway::Stage",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "AccessLogSetting": {
+ "description": "Specifies settings for logging access in this stage.",
+ "$ref": "#/definitions/AccessLogSetting"
+ },
+ "CacheClusterEnabled": {
+ "description": "Indicates whether cache clustering is enabled for the stage.",
+ "type": "boolean"
+ },
+ "CacheClusterSize": {
+ "description": "The stage's cache cluster size.",
+ "type": "string"
+ },
+ "CanarySetting": {
+ "description": "Specifies settings for the canary deployment in this stage.",
+ "$ref": "#/definitions/CanarySetting"
+ },
+ "ClientCertificateId": {
+ "description": "The ID of the client certificate that API Gateway uses to call your integration endpoints in the stage. ",
+ "type": "string"
+ },
+ "DeploymentId": {
+ "description": "The ID of the deployment that the stage is associated with. This parameter is required to create a stage. ",
+ "type": "string"
+ },
+ "Description": {
+ "description": "A description of the stage.",
+ "type": "string"
+ },
+ "DocumentationVersion": {
+ "description": "The version ID of the API documentation snapshot.",
+ "type": "string"
+ },
+ "MethodSettings": {
+ "description": "Settings for all methods in the stage.",
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/MethodSetting"
+ }
+ },
+ "RestApiId": {
+ "description": "The ID of the RestApi resource that you're deploying with this stage.",
+ "type": "string"
+ },
+ "StageName": {
+ "description": "The name of the stage, which API Gateway uses as the first path segment in the invoked Uniform Resource Identifier (URI).",
+ "type": "string"
+ },
+ "Tags": {
+ "description": "An array of arbitrary tags (key-value pairs) to associate with the stage.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "TracingEnabled": {
+ "description": "Specifies whether active X-Ray tracing is enabled for this stage.",
+ "type": "boolean"
+ },
+ "Variables": {
+ "description": "A map (string-to-string map) that defines the stage variables, where the variable name is the key and the variable value is the value.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "definitions": {
+ "CanarySetting": {
+ "description": "Specifies settings for the canary deployment in this stage.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DeploymentId": {
+ "description": "The identifier of the deployment that the stage points to.",
+ "type": "string"
+ },
+ "PercentTraffic": {
+ "description": "The percentage (0-100) of traffic diverted to a canary deployment.",
+ "type": "number",
+ "minimum": 0,
+ "maximum": 100
+ },
+ "StageVariableOverrides": {
+ "description": "Stage variables overridden for a canary release deployment, including new stage variables introduced in the canary. These stage variables are represented as a string-to-string map between stage variable names and their values.",
+ "type": "object",
+ "additionalProperties": false,
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "UseStageCache": {
+ "description": "Whether the canary deployment uses the stage cache or not.",
+ "type": "boolean"
+ }
+ }
+ },
+ "AccessLogSetting": {
+ "description": "Specifies settings for logging access in this stage.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DestinationArn": {
+ "description": "The Amazon Resource Name (ARN) of the CloudWatch Logs log group or Kinesis Data Firehose delivery stream to receive access logs. If you specify a Kinesis Data Firehose delivery stream, the stream name must begin with amazon-apigateway-. This parameter is required to enable access logging.",
+ "type": "string"
+ },
+ "Format": {
+ "description": "A single line format of the access logs of data, as specified by selected $context variables (https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html#context-variable-reference). The format must include at least $context.requestId. This parameter is required to enable access logging.",
+ "type": "string"
+ }
+ }
+ },
+ "MethodSetting": {
+ "description": "Configures settings for all methods in a stage.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CacheDataEncrypted": {
+ "description": "Indicates whether the cached responses are encrypted.",
+ "type": "boolean"
+ },
+ "CacheTtlInSeconds": {
+ "description": "The time-to-live (TTL) period, in seconds, that specifies how long API Gateway caches responses.",
+ "type": "integer"
+ },
+ "CachingEnabled": {
+ "description": "Indicates whether responses are cached and returned for requests. You must enable a cache cluster on the stage to cache responses.",
+ "type": "boolean"
+ },
+ "DataTraceEnabled": {
+ "description": "Indicates whether data trace logging is enabled for methods in the stage. API Gateway pushes these logs to Amazon CloudWatch Logs.",
+ "type": "boolean"
+ },
+ "HttpMethod": {
+ "description": "The HTTP method. You can use an asterisk (*) as a wildcard to apply method settings to multiple methods.",
+ "type": "string"
+ },
+ "LoggingLevel": {
+ "description": "The logging level for this method. For valid values, see the loggingLevel property of the Stage (https://docs.aws.amazon.com/apigateway/api-reference/resource/stage/#loggingLevel) resource in the Amazon API Gateway API Reference.",
+ "type": "string"
+ },
+ "MetricsEnabled": {
+ "description": "Indicates whether Amazon CloudWatch metrics are enabled for methods in the stage.",
+ "type": "boolean"
+ },
+ "ResourcePath": {
+ "description": "The resource path for this method. Forward slashes (/) are encoded as ~1 and the initial slash must include a forward slash. For example, the path value /resource/subresource must be encoded as /~1resource~1subresource. To specify the root path, use only a slash (/). You can use an asterisk (*) as a wildcard to apply method settings to multiple methods.",
+ "type": "string"
+ },
+ "ThrottlingBurstLimit": {
+ "description": "The number of burst requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "integer",
+ "minimum": 0
+ },
+ "ThrottlingRateLimit": {
+ "description": "The number of steady-state requests per second that API Gateway permits across all APIs, stages, and methods in your AWS account.",
+ "type": "number",
+ "minimum": 0
+ }
+ }
+ },
+ "Tag": {
+ "description": "Identify and categorize resources.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:.",
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ]
+ }
+ },
+ "required": [
+ "RestApiId"
+ ],
+ "createOnlyProperties": [
+ "/properties/RestApiId",
+ "/properties/StageName"
+ ],
+ "primaryIdentifier": [
+ "/properties/RestApiId",
+ "/properties/StageName"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:PATCH",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:PATCH",
+ "apigateway:PUT",
+ "apigateway:DELETE"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "RestApiId": {
+ "$ref": "resource-schema.json#/properties/RestApiId"
+ }
+ },
+ "required": [
+ "RestApiId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage_plugin.py
new file mode 100644
index 0000000000000..e0898bae2c695
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_stage_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayStageProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::Stage"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_stage import (
+ ApiGatewayStageProvider,
+ )
+
+ self.factory = ApiGatewayStageProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.py
new file mode 100644
index 0000000000000..1e10c9badfc3f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.py
@@ -0,0 +1,215 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.aws.arns import get_partition
+from localstack.utils.objects import keys_to_lower
+from localstack.utils.strings import first_char_to_lower
+
+
+class ApiGatewayUsagePlanProperties(TypedDict):
+ ApiStages: Optional[list[ApiStage]]
+ Description: Optional[str]
+ Id: Optional[str]
+ Quota: Optional[QuotaSettings]
+ Tags: Optional[list[Tag]]
+ Throttle: Optional[ThrottleSettings]
+ UsagePlanName: Optional[str]
+
+
+class ApiStage(TypedDict):
+ ApiId: Optional[str]
+ Stage: Optional[str]
+ Throttle: Optional[dict]
+
+
+class QuotaSettings(TypedDict):
+ Limit: Optional[int]
+ Offset: Optional[int]
+ Period: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class ThrottleSettings(TypedDict):
+ BurstLimit: Optional[int]
+ RateLimit: Optional[float]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayUsagePlanProvider(ResourceProvider[ApiGatewayUsagePlanProperties]):
+ TYPE = "AWS::ApiGateway::UsagePlan" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ plan_name = model.get("UsagePlanName")
+ if not plan_name:
+ model["UsagePlanName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+
+ params = util.select_attributes(model, ["Description", "ApiStages", "Quota", "Throttle"])
+ params = keys_to_lower(params.copy())
+ params["name"] = model["UsagePlanName"]
+
+ if model.get("Tags"):
+ params["tags"] = {tag["Key"]: tag["Value"] for tag in model["Tags"]}
+
+ # set int and float types
+ if params.get("quota"):
+ params["quota"]["limit"] = int(params["quota"]["limit"])
+
+ if params.get("throttle"):
+ params["throttle"]["burstLimit"] = int(params["throttle"]["burstLimit"])
+ params["throttle"]["rateLimit"] = float(params["throttle"]["rateLimit"])
+
+ response = apigw.create_usage_plan(**params)
+
+ model["Id"] = response["id"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ apigw.delete_usage_plan(usagePlanId=model["Id"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - apigateway:GET
+ - apigateway:DELETE
+ - apigateway:PATCH
+ - apigateway:PUT
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ parameters_to_select = [
+ "UsagePlanName",
+ "Description",
+ "ApiStages",
+ "Quota",
+ "Throttle",
+ "Tags",
+ ]
+ update_config_props = util.select_attributes(model, parameters_to_select)
+
+ updated_tags = update_config_props.pop("Tags", [])
+
+ usage_plan_id = request.previous_state["Id"]
+
+ patch_operations = []
+
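+ # translate the changed properties into the patchOperations entries expected by update_usage_plan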
+ for parameter in update_config_props:
+ value = update_config_props[parameter]
+ if parameter == "ApiStages":
+ for stage in value:
+ patch_operations.append(
+ {
+ "op": "replace",
+ "path": f"/{first_char_to_lower(parameter)}",
+ "value": f"{stage['ApiId']}:{stage['Stage']}",
+ }
+ )
+
+ if "Throttle" in stage:
+ patch_operations.append(
+ {
+ "op": "replace",
+ "path": f"/{first_char_to_lower(parameter)}/{stage['ApiId']}:{stage['Stage']}",
+ "value": json.dumps(stage["Throttle"]),
+ }
+ )
+
+ elif isinstance(value, dict):
+ for item in value:
+ last_value = value[item]
+ path = f"/{first_char_to_lower(parameter)}/{first_char_to_lower(item)}"
+ patch_operations.append({"op": "replace", "path": path, "value": last_value})
+ else:
+ patch_operations.append(
+ {"op": "replace", "path": f"/{first_char_to_lower(parameter)}", "value": value}
+ )
+ apigw.update_usage_plan(usagePlanId=usage_plan_id, patchOperations=patch_operations)
+
+ if updated_tags:
+ tags = {tag["Key"]: tag["Value"] for tag in updated_tags}
+ usage_plan_arn = f"arn:{get_partition(request.region_name)}:apigateway:{request.region_name}::/usageplans/{usage_plan_id}"
+ apigw.tag_resource(resourceArn=usage_plan_arn, tags=tags)
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={**request.previous_state, **request.desired_state},
+ custom_context=request.custom_context,
+ )
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.schema.json
new file mode 100644
index 0000000000000..96f6f07bb01ca
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan.schema.json
@@ -0,0 +1,173 @@
+{
+ "typeName": "AWS::ApiGateway::UsagePlan",
+ "description": "Resource Type definition for AWS::ApiGateway::UsagePlan",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway.git",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string",
+ "description": "The provider-assigned unique ID for this managed resource."
+ },
+ "ApiStages": {
+ "type": "array",
+ "description": "The API stages to associate with this usage plan.",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/ApiStage"
+ }
+ },
+ "Description": {
+ "type": "string",
+ "description": "A description of the usage plan."
+ },
+ "Quota": {
+ "$ref": "#/definitions/QuotaSettings",
+ "description": "Configures the number of requests that users can make within a given interval."
+ },
+ "Tags": {
+ "type": "array",
+ "description": "An array of arbitrary tags (key-value pairs) to associate with the usage plan.",
+ "insertionOrder": false,
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Throttle": {
+ "$ref": "#/definitions/ThrottleSettings",
+ "description": "Configures the overall request rate (average requests per second) and burst capacity."
+ },
+ "UsagePlanName": {
+ "type": "string",
+ "description": "A name for the usage plan."
+ }
+ },
+ "definitions": {
+ "ApiStage": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ApiId": {
+ "type": "string",
+ "description": "The ID of an API that is in the specified Stage property that you want to associate with the usage plan."
+ },
+ "Stage": {
+ "type": "string",
+ "description": "The name of the stage to associate with the usage plan."
+ },
+ "Throttle": {
+ "type": "object",
+ "description": "Map containing method-level throttling information for an API stage in a usage plan. The key for the map is the path and method for which to configure custom throttling, for example, '/pets/GET'. Duplicates are not allowed.",
+ "additionalProperties": false,
+ "patternProperties": {
+ ".*": {
+ "$ref": "#/definitions/ThrottleSettings"
+ }
+ }
+ }
+ }
+ },
+ "ThrottleSettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BurstLimit": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "The maximum API request rate limit over a time ranging from one to a few seconds. The maximum API request rate limit depends on whether the underlying token bucket is at its full capacity."
+ },
+ "RateLimit": {
+ "type": "number",
+ "minimum": 0,
+ "description": "The API request steady-state rate limit (average requests per second over an extended period of time)."
+ }
+ }
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128,
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -."
+ },
+ "Value": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 256,
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -."
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "QuotaSettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Limit": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "The maximum number of requests that users can make within the specified time period."
+ },
+ "Offset": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "For the initial time period, the number of requests to subtract from the specified limit. When you first implement a usage plan, the plan might start in the middle of the week or month. With this property, you can decrease the limit for this initial time period."
+ },
+ "Period": {
+ "type": "string",
+ "description": "The time period for which the maximum limit of requests applies, such as DAY or WEEK. For valid values, see the period property for the UsagePlan resource in the Amazon API Gateway REST API Reference."
+ }
+ }
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "apigateway:GET",
+ "apigateway:DELETE",
+ "apigateway:PATCH",
+ "apigateway:PUT"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan_plugin.py
new file mode 100644
index 0000000000000..154207ac69b58
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplan_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayUsagePlanProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::UsagePlan"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_usageplan import (
+ ApiGatewayUsagePlanProvider,
+ )
+
+ self.factory = ApiGatewayUsagePlanProvider
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.py
new file mode 100644
index 0000000000000..33a6e155d5c4f
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.py
@@ -0,0 +1,114 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.objects import keys_to_lower
+
+
+class ApiGatewayUsagePlanKeyProperties(TypedDict):
+ KeyId: Optional[str]
+ KeyType: Optional[str]
+ UsagePlanId: Optional[str]
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ApiGatewayUsagePlanKeyProvider(ResourceProvider[ApiGatewayUsagePlanKeyProperties]):
+ TYPE = "AWS::ApiGateway::UsagePlanKey" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanKeyProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanKeyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - KeyType
+ - UsagePlanId
+ - KeyId
+
+ Create-only properties:
+ - /properties/KeyId
+ - /properties/UsagePlanId
+ - /properties/KeyType
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - apigateway:POST
+ - apigateway:GET
+
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ params = keys_to_lower(model.copy())
+ result = apigw.create_usage_plan_key(**params)
+
+ model["Id"] = result["id"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanKeyProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanKeyProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - apigateway:GET
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanKeyProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanKeyProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - apigateway:DELETE
+ """
+ model = request.desired_state
+ apigw = request.aws_client_factory.apigateway
+
+ apigw.delete_usage_plan_key(usagePlanId=model["UsagePlanId"], keyId=model["KeyId"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ApiGatewayUsagePlanKeyProperties],
+ ) -> ProgressEvent[ApiGatewayUsagePlanKeyProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.schema.json b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.schema.json
new file mode 100644
index 0000000000000..997f3be9a0d49
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey.schema.json
@@ -0,0 +1,77 @@
+{
+ "typeName": "AWS::ApiGateway::UsagePlanKey",
+ "description": "Resource Type definition for AWS::ApiGateway::UsagePlanKey",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-apigateway",
+ "additionalProperties": false,
+ "properties": {
+ "KeyId": {
+ "description": "The ID of the usage plan key.",
+ "type": "string"
+ },
+ "KeyType": {
+ "description": "The type of usage plan key. Currently, the only valid key type is API_KEY.",
+ "type": "string",
+ "enum": [
+ "API_KEY"
+ ]
+ },
+ "UsagePlanId": {
+ "description": "The ID of the usage plan.",
+ "type": "string"
+ },
+ "Id": {
+ "description": "An autogenerated ID which is a combination of the ID of the key and ID of the usage plan combined with a : such as 123abcdef:abc123.",
+ "type": "string"
+ }
+ },
+ "taggable": false,
+ "required": [
+ "KeyType",
+ "UsagePlanId",
+ "KeyId"
+ ],
+ "createOnlyProperties": [
+ "/properties/KeyId",
+ "/properties/UsagePlanId",
+ "/properties/KeyType"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "apigateway:POST",
+ "apigateway:GET"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "apigateway:GET"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "apigateway:DELETE"
+ ]
+ },
+ "list": {
+ "handlerSchema": {
+ "properties": {
+ "UsagePlanId": {
+ "$ref": "resource-schema.json#/properties/UsagePlanId"
+ }
+ },
+ "required": [
+ "UsagePlanId"
+ ]
+ },
+ "permissions": [
+ "apigateway:GET"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey_plugin.py b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey_plugin.py
new file mode 100644
index 0000000000000..eb21b610bfc22
--- /dev/null
+++ b/localstack-core/localstack/services/apigateway/resource_providers/aws_apigateway_usageplankey_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ApiGatewayUsagePlanKeyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ApiGateway::UsagePlanKey"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.apigateway.resource_providers.aws_apigateway_usageplankey import (
+ ApiGatewayUsagePlanKeyProvider,
+ )
+
+ self.factory = ApiGatewayUsagePlanKeyProvider
diff --git a/localstack/utils/kinesis/__init__.py b/localstack-core/localstack/services/cdk/__init__.py
similarity index 100%
rename from localstack/utils/kinesis/__init__.py
rename to localstack-core/localstack/services/cdk/__init__.py
diff --git a/tests/integration/lambdas/__init__.py b/localstack-core/localstack/services/cdk/resource_providers/__init__.py
similarity index 100%
rename from tests/integration/lambdas/__init__.py
rename to localstack-core/localstack/services/cdk/resource_providers/__init__.py
diff --git a/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.py b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.py
new file mode 100644
index 0000000000000..7e5eb5ca2f988
--- /dev/null
+++ b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.py
@@ -0,0 +1,90 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CDKMetadataProperties(TypedDict):
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CDKMetadataProvider(ResourceProvider[CDKMetadataProperties]):
+ TYPE = "AWS::CDK::Metadata" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CDKMetadataProperties],
+ ) -> ProgressEvent[CDKMetadataProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ model["Id"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[CDKMetadataProperties],
+ ) -> ProgressEvent[CDKMetadataProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CDKMetadataProperties],
+ ) -> ProgressEvent[CDKMetadataProperties]:
+ """
+ Delete a resource
+
+
+ """
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=request.previous_state,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[CDKMetadataProperties],
+ ) -> ProgressEvent[CDKMetadataProperties]:
+ """
+ Update a resource
+
+
+ """
+ model = request.desired_state
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
diff --git a/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.schema.json b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.schema.json
new file mode 100644
index 0000000000000..636fc68e2e9c0
--- /dev/null
+++ b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata.schema.json
@@ -0,0 +1,22 @@
+{
+ "typeName": "AWS::CDK::Metadata" ,
+ "description": "Resource Type definition for AWS::CDK::Metadata",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ },
+ "required": [
+ ],
+ "createOnlyProperties": [
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata_plugin.py b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata_plugin.py
new file mode 100644
index 0000000000000..924ca3cb79eae
--- /dev/null
+++ b/localstack-core/localstack/services/cdk/resource_providers/cdk_metadata_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class LambdaAliasProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CDK::Metadata"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cdk.resource_providers.cdk_metadata import CDKMetadataProvider
+
+ self.factory = CDKMetadataProvider
diff --git a/localstack-core/localstack/services/certificatemanager/__init__.py b/localstack-core/localstack/services/certificatemanager/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/certificatemanager/resource_providers/__init__.py b/localstack-core/localstack/services/certificatemanager/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.py b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.py
new file mode 100644
index 0000000000000..d79d62975e87f
--- /dev/null
+++ b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.py
@@ -0,0 +1,151 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CertificateManagerCertificateProperties(TypedDict):
+ DomainName: Optional[str]
+ CertificateAuthorityArn: Optional[str]
+ CertificateTransparencyLoggingPreference: Optional[str]
+ DomainValidationOptions: Optional[list[DomainValidationOption]]
+ Id: Optional[str]
+ SubjectAlternativeNames: Optional[list[str]]
+ Tags: Optional[list[Tag]]
+ ValidationMethod: Optional[str]
+
+
+class DomainValidationOption(TypedDict):
+ DomainName: Optional[str]
+ HostedZoneId: Optional[str]
+ ValidationDomain: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CertificateManagerCertificateProvider(
+ ResourceProvider[CertificateManagerCertificateProperties]
+):
+ TYPE = "AWS::CertificateManager::Certificate" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CertificateManagerCertificateProperties],
+ ) -> ProgressEvent[CertificateManagerCertificateProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - DomainName
+
+ Create-only properties:
+ - /properties/SubjectAlternativeNames
+ - /properties/DomainValidationOptions
+ - /properties/ValidationMethod
+ - /properties/DomainName
+ - /properties/CertificateAuthorityArn
+
+ Read-only properties:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ acm = request.aws_client_factory.acm
+
+ params = util.select_attributes(
+ model,
+ [
+ "CertificateAuthorityArn",
+ "DomainName",
+ "DomainValidationOptions",
+ "SubjectAlternativeNames",
+ "Tags",
+ "ValidationMethod",
+ ],
+ )
+ # adjust domain validation options
+ valid_opts = params.get("DomainValidationOptions")
+ if valid_opts:
+
+ def _convert(opt):
+ res = util.select_attributes(opt, ["DomainName", "ValidationDomain"])
+ res.setdefault("ValidationDomain", res["DomainName"])
+ return res
+
+ params["DomainValidationOptions"] = [_convert(opt) for opt in valid_opts]
+
+ # adjust logging preferences
+ logging_pref = params.get("CertificateTransparencyLoggingPreference")
+ if logging_pref:
+ params["Options"] = {"CertificateTransparencyLoggingPreference": logging_pref}
+
+ response = acm.request_certificate(**params)
+ model["Id"] = response["CertificateArn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[CertificateManagerCertificateProperties],
+ ) -> ProgressEvent[CertificateManagerCertificateProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CertificateManagerCertificateProperties],
+ ) -> ProgressEvent[CertificateManagerCertificateProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ acm = request.aws_client_factory.acm
+
+ acm.delete_certificate(CertificateArn=model["Id"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[CertificateManagerCertificateProperties],
+ ) -> ProgressEvent[CertificateManagerCertificateProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.schema.json b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.schema.json
new file mode 100644
index 0000000000000..a4d90a42f0839
--- /dev/null
+++ b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate.schema.json
@@ -0,0 +1,95 @@
+{
+ "typeName": "AWS::CertificateManager::Certificate",
+ "description": "Resource Type definition for AWS::CertificateManager::Certificate",
+ "additionalProperties": false,
+ "properties": {
+ "CertificateAuthorityArn": {
+ "type": "string"
+ },
+ "DomainValidationOptions": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/DomainValidationOption"
+ }
+ },
+ "CertificateTransparencyLoggingPreference": {
+ "type": "string"
+ },
+ "DomainName": {
+ "type": "string"
+ },
+ "ValidationMethod": {
+ "type": "string"
+ },
+ "SubjectAlternativeNames": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "definitions": {
+ "DomainValidationOption": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DomainName": {
+ "type": "string"
+ },
+ "ValidationDomain": {
+ "type": "string"
+ },
+ "HostedZoneId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "DomainName"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "DomainName"
+ ],
+ "createOnlyProperties": [
+ "/properties/SubjectAlternativeNames",
+ "/properties/DomainValidationOptions",
+ "/properties/ValidationMethod",
+ "/properties/DomainName",
+ "/properties/CertificateAuthorityArn"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate_plugin.py b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate_plugin.py
new file mode 100644
index 0000000000000..5aae4de01c7b3
--- /dev/null
+++ b/localstack-core/localstack/services/certificatemanager/resource_providers/aws_certificatemanager_certificate_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CertificateManagerCertificateProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CertificateManager::Certificate"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.certificatemanager.resource_providers.aws_certificatemanager_certificate import (
+ CertificateManagerCertificateProvider,
+ )
+
+ self.factory = CertificateManagerCertificateProvider
diff --git a/localstack-core/localstack/services/cloudformation/__init__.py b/localstack-core/localstack/services/cloudformation/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/cloudformation/api_utils.py b/localstack-core/localstack/services/cloudformation/api_utils.py
new file mode 100644
index 0000000000000..556435ed699a7
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/api_utils.py
@@ -0,0 +1,106 @@
+import logging
+import re
+from urllib.parse import urlparse
+
+from localstack import config, constants
+from localstack.aws.connect import connect_to
+from localstack.services.s3.utils import (
+ extract_bucket_name_and_key_from_headers_and_path,
+ normalize_bucket_name,
+)
+from localstack.utils.functions import run_safe
+from localstack.utils.http import safe_requests
+from localstack.utils.strings import to_str
+from localstack.utils.urls import localstack_host
+
+LOG = logging.getLogger(__name__)
+
+
+def prepare_template_body(req_data: dict) -> str | bytes | None: # TODO: mutating and returning
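+ """Ensure the request data carries a usable template body: rewrite local S3 TemplateURLs and,
+ where possible, download the referenced template and inline it as TemplateBody
+ (mutates req_data in place, as noted in the TODO above)."""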
+ template_url = req_data.get("TemplateURL")
+ if template_url:
+ req_data["TemplateURL"] = convert_s3_to_local_url(template_url)
+ url = req_data.get("TemplateURL", "")
+ if is_local_service_url(url):
+ modified_template_body = get_template_body(req_data)
+ if modified_template_body:
+ req_data.pop("TemplateURL", None)
+ req_data["TemplateBody"] = modified_template_body
+ modified_template_body = get_template_body(req_data)
+ if modified_template_body:
+ req_data["TemplateBody"] = modified_template_body
+ return modified_template_body
+
+
+def get_template_body(req_data: dict) -> str:
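+ """Return the template body from the request data: use TemplateBody if present, otherwise
+ download it from TemplateURL (falling back to a direct lookup from local S3 for
+ LocalStack-internal URLs), or raise an exception if neither yields a body."""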
+ body = req_data.get("TemplateBody")
+ if body:
+ return body
+ url = req_data.get("TemplateURL")
+ if url:
+ response = run_safe(lambda: safe_requests.get(url, verify=False))
+ # check error codes, and code 301 - fixes https://github.com/localstack/localstack/issues/1884
+ status_code = 0 if response is None else response.status_code
+ if response is None or status_code == 301 or status_code >= 400:
+ # check if this is an S3 URL, then get the file directly from there
+ url = convert_s3_to_local_url(url)
+ if is_local_service_url(url):
+ parsed_path = urlparse(url).path.lstrip("/")
+ parts = parsed_path.partition("/")
+ client = connect_to().s3
+ LOG.debug(
+ "Download CloudFormation template content from local S3: %s - %s",
+ parts[0],
+ parts[2],
+ )
+ result = client.get_object(Bucket=parts[0], Key=parts[2])
+ body = to_str(result["Body"].read())
+ return body
+ raise Exception(
+ "Unable to fetch template body (code %s) from URL %s" % (status_code, url)
+ )
+ return to_str(response.content)
+ raise Exception("Unable to get template body from input: %s" % req_data)
+
+
+def is_local_service_url(url: str) -> bool:
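+ """Heuristic check whether the given URL points at the local LocalStack endpoint rather than
+ real AWS, e.g. "http://localhost.localstack.cloud:4566/bucket/key" (illustrative example)."""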
+ if not url:
+ return False
+ candidates = (
+ constants.LOCALHOST,
+ constants.LOCALHOST_HOSTNAME,
+ localstack_host().host,
+ )
+ if any(re.match(r"^[^:]+://[^:/]*%s([:/]|$)" % host, url) for host in candidates):
+ return True
+ host = url.split("://")[-1].split("/")[0]
+ return "localhost" in host
+
+
+def convert_s3_to_local_url(url: str) -> str:
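+ """Rewrite an S3-style template URL to the corresponding URL on the local endpoint, e.g.
+ "https://my-bucket.s3.amazonaws.com/tmpl.yaml" -> "<internal_service_url>/my-bucket/tmpl.yaml"
+ (bucket and key values here are illustrative)."""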
+ from localstack.services.cloudformation.provider import ValidationError
+
+ url_parsed = urlparse(url)
+ path = url_parsed.path
+
+ headers = {"host": url_parsed.netloc}
+ bucket_name, key_name = extract_bucket_name_and_key_from_headers_and_path(headers, path)
+
+ if url_parsed.scheme == "s3":
+ raise ValidationError(
+ f"S3 error: Domain name specified in {url_parsed.netloc} is not a valid S3 domain"
+ )
+
+ if not bucket_name or not key_name:
+ if not (url_parsed.netloc.startswith("s3.") or ".s3." in url_parsed.netloc):
+ raise ValidationError("TemplateURL must be a supported URL.")
+
+ # note: make sure to normalize the bucket name here!
+ bucket_name = normalize_bucket_name(bucket_name)
+ local_url = f"{config.internal_service_url()}/{bucket_name}/{key_name}"
+ return local_url
+
+
+def validate_stack_name(stack_name):
+ pattern = r"[a-zA-Z][-a-zA-Z0-9]*|arn:[-a-zA-Z0-9:/._+]*"
+ return re.match(pattern, stack_name) is not None
diff --git a/localstack-core/localstack/services/cloudformation/cfn_utils.py b/localstack-core/localstack/services/cloudformation/cfn_utils.py
new file mode 100644
index 0000000000000..6fcc5d16fb573
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/cfn_utils.py
@@ -0,0 +1,84 @@
+import json
+from typing import Callable
+
+from localstack.utils.objects import recurse_object
+
+
+def rename_params(func, rename_map):
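+ # Wrap a parameter-producing function and rename selected keys of its result according to
+ # rename_map (old name -> new name); keys missing from the result are set to None.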
+ def do_rename(account_id, region_name, params, logical_resource_id, *args, **kwargs):
+ values = (
+ func(account_id, region_name, params, logical_resource_id, *args, **kwargs)
+ if func
+ else params
+ )
+ for old_param, new_param in rename_map.items():
+ values[new_param] = values.pop(old_param, None)
+ return values
+
+ return do_rename
+
+
+def lambda_convert_types(func, types):
+ return (
+ lambda account_id, region_name, params, logical_resource_id, *args, **kwargs: convert_types(
+ func(account_id, region_name, params, *args, **kwargs), types
+ )
+ )
+
+
+def lambda_to_json(attr):
+ return lambda account_id, region_name, params, logical_resource_id, *args, **kwargs: json.dumps(
+ params[attr]
+ )
+
+
+def lambda_rename_attributes(attrs, func=None):
+ def recurse(o, path):
+ if isinstance(o, dict):
+ for k in list(o.keys()):
+ for a in attrs.keys():
+ if k == a:
+ o[attrs[k]] = o.pop(k)
+ return o
+
+ func = func or (lambda account_id, region_name, x, logical_resource_id, *args, **kwargs: x)
+ return (
+ lambda account_id,
+ region_name,
+ params,
+ logical_resource_id,
+ *args,
+ **kwargs: recurse_object(
+ func(account_id, region_name, params, logical_resource_id, *args, **kwargs), recurse
+ )
+ )
+
+
+def convert_types(obj, types):
+ def fix_types(key, type_class):
+ def recurse(o, path):
+ if isinstance(o, dict):
+ for k, v in dict(o).items():
+ key_path = "%s%s" % (path or ".", k)
+ if key in [k, key_path]:
+ o[k] = type_class(v)
+ return o
+
+ return recurse_object(obj, recurse)
+
+ for key, type_class in types.items():
+ fix_types(key, type_class)
+ return obj
+
+
+def get_tags_param(resource_type: str) -> Callable:
+ """Return a tag parameters creation function for the given resource type"""
+
+ def _param(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ tags = params.get("Tags")
+ if not tags:
+ return None
+
+ return [{"ResourceType": resource_type, "Tags": tags}]
+
+ return _param
diff --git a/localstack-core/localstack/services/cloudformation/deploy.html b/localstack-core/localstack/services/cloudformation/deploy.html
new file mode 100644
index 0000000000000..47af619288057
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/deploy.html
@@ -0,0 +1,144 @@
+ LocalStack - CloudFormation Deployment
diff --git a/localstack-core/localstack/services/cloudformation/deploy_ui.py b/localstack-core/localstack/services/cloudformation/deploy_ui.py
new file mode 100644
index 0000000000000..deac95b408b1f
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/deploy_ui.py
@@ -0,0 +1,47 @@
+import json
+import logging
+import os
+
+import requests
+from rolo import Response
+
+from localstack import constants
+from localstack.utils.files import load_file
+from localstack.utils.json import parse_json_or_yaml
+
+LOG = logging.getLogger(__name__)
+
+
+class CloudFormationUi:
+ def on_get(self, request):
+ from localstack.utils.aws.aws_stack import get_valid_regions
+
+ deploy_html_file = os.path.join(
+ constants.MODULE_MAIN_PATH, "services", "cloudformation", "deploy.html"
+ )
+ deploy_html = load_file(deploy_html_file)
+ req_params = request.values
+ params = {
+ "stackName": "stack1",
+ "templateBody": "{}",
+ "errorMessage": "''",
+ "regions": json.dumps(sorted(get_valid_regions())),
+ }
+
+ download_url = req_params.get("templateURL")
+ if download_url:
+ try:
+ LOG.debug("Attempting to download CloudFormation template URL: %s", download_url)
+ template_body = requests.get(download_url).text
+ template_body = parse_json_or_yaml(template_body)
+ params["templateBody"] = json.dumps(template_body)
+ except Exception as e:
+ msg = f"Unable to download CloudFormation template URL: {e}"
+ LOG.info(msg)
+ params["errorMessage"] = json.dumps(msg.replace("\n", " - "))
+
+ # using simple string replacement here, for simplicity (could be replaced with, e.g., jinja)
+ for key, value in params.items():
+ deploy_html = deploy_html.replace(f"<{key}>", value)
+
+ return Response(deploy_html, mimetype="text/html")
diff --git a/localstack-core/localstack/services/cloudformation/deployment_utils.py b/localstack-core/localstack/services/cloudformation/deployment_utils.py
new file mode 100644
index 0000000000000..6355db6b5c27a
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/deployment_utils.py
@@ -0,0 +1,319 @@
+import builtins
+import json
+import logging
+import re
+from copy import deepcopy
+from typing import Callable, List
+
+from localstack import config
+from localstack.utils import common
+from localstack.utils.aws import aws_stack
+from localstack.utils.common import select_attributes, short_uid
+from localstack.utils.functions import run_safe
+from localstack.utils.json import json_safe
+from localstack.utils.objects import recurse_object
+from localstack.utils.strings import is_string
+
+# placeholders
+PLACEHOLDER_AWS_NO_VALUE = "__aws_no_value__"
+
+LOG = logging.getLogger(__name__)
+
+
+def dump_json_params(param_func=None, *param_names):
+ def replace(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ result = (
+ param_func(account_id, region_name, params, logical_resource_id, *args, **kwargs)
+ if param_func
+ else params
+ )
+ for name in param_names:
+ if isinstance(result.get(name), (dict, list)):
+ # Fix for https://github.com/localstack/localstack/issues/2022
+ # Convert any date instances to date strings, etc, Version: "2012-10-17"
+ param_value = common.json_safe(result[name])
+ result[name] = json.dumps(param_value)
+ return result
+
+ return replace
+
+
+# TODO: remove
+def param_defaults(param_func, defaults):
+ def replace(
+ account_id: str,
+ region_name: str,
+ properties: dict,
+ logical_resource_id: str,
+ *args,
+ **kwargs,
+ ):
+ result = param_func(
+ account_id, region_name, properties, logical_resource_id, *args, **kwargs
+ )
+ for key, value in defaults.items():
+ if result.get(key) in ["", None]:
+ result[key] = value
+ return result
+
+ return replace
+
+
+def remove_none_values(params):
+ """Remove None values and AWS::NoValue placeholders (recursively) in the given object."""
+
+ def remove_nones(o, **kwargs):
+ if isinstance(o, dict):
+ for k, v in dict(o).items():
+ if v in [None, PLACEHOLDER_AWS_NO_VALUE]:
+ o.pop(k)
+ if isinstance(o, list):
+ common.run_safe(o.remove, None)
+ common.run_safe(o.remove, PLACEHOLDER_AWS_NO_VALUE)
+ return o
+
+ result = common.recurse_object(params, remove_nones)
+ return result
+
+
+def params_list_to_dict(param_name, key_attr_name="Key", value_attr_name="Value"):
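+ # Build a params function that converts a list of {Key, Value} entries under `param_name`
+ # into a plain dict, e.g. [{"Key": "env", "Value": "dev"}] -> {"env": "dev"} (illustrative values).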
+ def do_replace(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ result = {}
+ for entry in params.get(param_name, []):
+ key = entry[key_attr_name]
+ value = entry[value_attr_name]
+ result[key] = value
+ return result
+
+ return do_replace
+
+
+def lambda_keys_to_lower(key=None, skip_children_of: List[str] = None):
+ return (
+ lambda account_id,
+ region_name,
+ params,
+ logical_resource_id,
+ *args,
+ **kwargs: common.keys_to_lower(
+ obj=(params.get(key) if key else params), skip_children_of=skip_children_of
+ )
+ )
+
+
+def merge_parameters(func1, func2):
+ return (
+ lambda account_id,
+ region_name,
+ properties,
+ logical_resource_id,
+ *args,
+ **kwargs: common.merge_dicts(
+ func1(account_id, region_name, properties, logical_resource_id, *args, **kwargs),
+ func2(account_id, region_name, properties, logical_resource_id, *args, **kwargs),
+ )
+ )
+
+
+def str_or_none(o):
+ return o if o is None else json.dumps(o) if isinstance(o, (dict, list)) else str(o)
+
+
+def params_dict_to_list(param_name, key_attr_name="Key", value_attr_name="Value", wrapper=None):
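+ # Inverse of params_list_to_dict: turn a dict under `param_name` into a list of {Key, Value}
+ # entries, e.g. {"env": "dev"} -> [{"Key": "env", "Value": "dev"}], optionally wrapped under
+ # `wrapper` (illustrative values).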
+ def do_replace(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ result = []
+ for key, value in params.get(param_name, {}).items():
+ result.append({key_attr_name: key, value_attr_name: value})
+ if wrapper:
+ result = {wrapper: result}
+ return result
+
+ return do_replace
+
+
+# TODO: remove
+def params_select_attributes(*attrs):
+ def do_select(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ result = {}
+ for attr in attrs:
+ if params.get(attr) is not None:
+ result[attr] = str_or_none(params.get(attr))
+ return result
+
+ return do_select
+
+
+def param_json_to_str(name):
+ def _convert(account_id: str, region_name: str, params, logical_resource_id, *args, **kwargs):
+ result = params.get(name)
+ if result:
+ result = json.dumps(result)
+ return result
+
+ return _convert
+
+
+def lambda_select_params(*selected):
+ # TODO: remove and merge with function below
+ return select_parameters(*selected)
+
+
+def select_parameters(*param_names):
+ return (
+ lambda account_id,
+ region_name,
+ properties,
+ logical_resource_id,
+ *args,
+ **kwargs: select_attributes(properties, param_names)
+ )
+
+
+def is_none_or_empty_value(value):
+ return not value or value == PLACEHOLDER_AWS_NO_VALUE
+
+
+def generate_default_name(stack_name: str, logical_resource_id: str):
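+ # Default physical name of the form "<stack>-<logical id>-<uid>", truncated so the result
+ # stays within 63 characters (a common AWS resource name length limit).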
+ random_id_part = short_uid()
+ resource_id_part = logical_resource_id[:24]
+ stack_name_part = stack_name[: 63 - 2 - (len(random_id_part) + len(resource_id_part))]
+ return f"{stack_name_part}-{resource_id_part}-{random_id_part}"
+
+
+def generate_default_name_without_stack(logical_resource_id: str):
+ random_id_part = short_uid()
+ resource_id_part = logical_resource_id[: 63 - 1 - len(random_id_part)]
+ return f"{resource_id_part}-{random_id_part}"
+
+
+# Utils for parameter conversion
+
+# TODO: handling of multiple valid types
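+# The regex below matches botocore parameter validation reports of the form:
+#   "Invalid type for parameter Foo.Bar, value: 123, type: <class 'int'>, valid types: <class 'str'>"
+# (parameter name and values here are illustrative)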
+param_validation = re.compile(
+ r"Invalid type for parameter (?P [\w.]+), value: (?P\w+), type: \w+)'>, valid types: \w+)'>"
+)
+
+
+def get_nested(obj: dict, path: str):
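+ # Look up a value via a dot-separated path, e.g. get_nested({"a": {"b": 1}}, "a.b") -> 1 (illustrative).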
+ parts = path.split(".")
+ result = obj
+ for p in parts[:-1]:
+ result = result.get(p, {})
+ return result.get(parts[-1])
+
+
+def set_nested(obj: dict, path: str, value):
+ parts = path.split(".")
+ result = obj
+ for p in parts[:-1]:
+ result = result.get(p, {})
+ result[parts[-1]] = value
+
+
+def fix_boto_parameters_based_on_report(original_params: dict, report: str) -> dict:
+ """
+ Fix invalid type parameter validation errors in boto request parameters
+
+ :param original_params: original boto request parameters that lead to the parameter validation error
+ :param report: error report from botocore ParamValidator
+ :return: a copy of original_params with all values replaced by their correctly cast ones
+ """
+ params = deepcopy(original_params)
+ for found in param_validation.findall(report):
+ param_name, value, wrong_class, valid_class = found
+ cast_class = getattr(builtins, valid_class)
+ old_value = get_nested(params, param_name)
+
+ if cast_class == bool and str(old_value).lower() in ["true", "false"]:
+ new_value = str(old_value).lower() == "true"
+ else:
+ new_value = cast_class(old_value)
+ set_nested(params, param_name, new_value)
+ return params
+
+
+def fix_account_id_in_arns(params: dict, replacement_account_id: str) -> dict:
+ def fix_ids(o, **kwargs):
+ if isinstance(o, dict):
+ for k, v in o.items():
+ if is_string(v, exclude_binary=True):
+ o[k] = aws_stack.fix_account_id_in_arns(v, replacement=replacement_account_id)
+ elif is_string(o, exclude_binary=True):
+ o = aws_stack.fix_account_id_in_arns(o, replacement=replacement_account_id)
+ return o
+
+ result = recurse_object(params, fix_ids)
+ return result
+
+
+def convert_data_types(type_conversions: dict[str, Callable], params: dict) -> dict:
+ """Convert data types in the "params" object, with the type defs
+ specified in the 'types' attribute of "func_details"."""
+ attr_names = type_conversions.keys() or []
+
+ def cast(_obj, _type):
+ if _type == bool:
+ return _obj in ["True", "true", True]
+ if _type == str:
+ if isinstance(_obj, bool):
+ return str(_obj).lower()
+ return str(_obj)
+ if _type in (int, float):
+ return _type(_obj)
+ return _obj
+
+ def fix_types(o, **kwargs):
+ if isinstance(o, dict):
+ for k, v in o.items():
+ if k in attr_names:
+ o[k] = cast(v, type_conversions[k])
+ return o
+
+ result = recurse_object(params, fix_types)
+ return result
+
+
+def log_not_available_message(resource_type: str, message: str):
+ LOG.warning(
+ "%s. To find out if %s is supported in LocalStack Pro, "
+ "please check out our docs at https://docs.localstack.cloud/user-guide/aws/cloudformation/#resources-pro--enterprise-edition",
+ message,
+ resource_type,
+ )
+
+
+def dump_resource_as_json(resource: dict) -> str:
+ return str(run_safe(lambda: json.dumps(json_safe(resource))) or resource)
+
+
+def get_action_name_for_resource_change(res_change: str) -> str:
+ return {"Add": "CREATE", "Remove": "DELETE", "Modify": "UPDATE"}.get(res_change)
+
+
+def check_not_found_exception(e, resource_type, resource, resource_status=None):
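+ """Return True if the given exception looks like a "not found" error for the resource;
+ otherwise log a warning (or re-raise if CFN_VERBOSE_ERRORS is set) and return False."""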
+ # we expect this to be a "not found" exception
+ markers = [
+ "NoSuchBucket",
+ "ResourceNotFound",
+ "NoSuchEntity",
+ "NotFoundException",
+ "404",
+ "not found",
+ "not exist",
+ ]
+
+ markers_hit = [m for m in markers if m in str(e)]
+ if not markers_hit:
+ LOG.warning(
+ "Unexpected error processing resource type %s: Exception: %s - %s - status: %s",
+ resource_type,
+ str(e),
+ resource,
+ resource_status,
+ )
+ if config.CFN_VERBOSE_ERRORS:
+ raise e
+ else:
+ return False
+
+ return True
diff --git a/localstack-core/localstack/services/cloudformation/engine/__init__.py b/localstack-core/localstack/services/cloudformation/engine/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/cloudformation/engine/changes.py b/localstack-core/localstack/services/cloudformation/engine/changes.py
new file mode 100644
index 0000000000000..ae6ced9e5563e
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/changes.py
@@ -0,0 +1,18 @@
+from typing import Literal, Optional, TypedDict
+
+Action = str
+
+
+class ResourceChange(TypedDict):
+ Action: Action
+ LogicalResourceId: str
+ PhysicalResourceId: Optional[str]
+ ResourceType: str
+ Scope: list
+ Details: list
+ Replacement: Optional[Literal["False"]]
+
+
+class ChangeConfig(TypedDict):
+ Type: str
+ ResourceChange: ResourceChange
diff --git a/localstack-core/localstack/services/cloudformation/engine/entities.py b/localstack-core/localstack/services/cloudformation/engine/entities.py
new file mode 100644
index 0000000000000..6151b46801b1c
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/entities.py
@@ -0,0 +1,401 @@
+import logging
+from typing import Optional, TypedDict
+
+from localstack.aws.api.cloudformation import Capability, ChangeSetType, Parameter
+from localstack.services.cloudformation.engine.parameters import (
+ StackParameter,
+ convert_stack_parameters_to_list,
+ mask_no_echo,
+ strip_parameter_type,
+)
+from localstack.utils.aws import arns
+from localstack.utils.collections import select_attributes
+from localstack.utils.id_generator import ExistingIds, ResourceIdentifier, Tags, generate_short_uid
+from localstack.utils.json import clone_safe
+from localstack.utils.objects import recurse_object
+from localstack.utils.strings import long_uid, short_uid
+from localstack.utils.time import timestamp_millis
+
+LOG = logging.getLogger(__name__)
+
+
+class StackSet:
+ """A stack set contains multiple stack instances."""
+
+ # FIXME: confusing name. metadata is the complete incoming request object
+ def __init__(self, metadata: dict):
+ self.metadata = metadata
+ # list of stack instances
+ self.stack_instances = []
+ # maps operation ID to stack set operation details
+ self.operations = {}
+
+ @property
+ def stack_set_name(self):
+ return self.metadata.get("StackSetName")
+
+
+class StackInstance:
+ """A stack instance belongs to a stack set and is specific to a region / account ID."""
+
+ # FIXME: confusing name. metadata is the complete incoming request object
+ def __init__(self, metadata: dict):
+ self.metadata = metadata
+ # reference to the deployed stack belonging to this stack instance
+ self.stack = None
+
+
+class StackMetadata(TypedDict):
+ StackName: str
+ Capabilities: list[Capability]
+ ChangeSetName: Optional[str]
+ ChangSetType: Optional[ChangeSetType]
+ Parameters: list[Parameter]
+
+
+class StackTemplate(TypedDict):
+ StackName: str
+ ChangeSetName: Optional[str]
+ Outputs: dict
+ Resources: dict
+
+
+class StackIdentifier(ResourceIdentifier):
+ service = "cloudformation"
+ resource = "stack"
+
+ def __init__(self, account_id: str, region: str, stack_name: str):
+ super().__init__(account_id, region, stack_name)
+
+ def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str:
+ return generate_short_uid(resource_identifier=self, existing_ids=existing_ids, tags=tags)
+
+
+# TODO: remove metadata (flatten into individual fields)
+class Stack:
+ change_sets: list["StackChangeSet"]
+
+ def __init__(
+ self,
+ account_id: str,
+ region_name: str,
+ metadata: Optional[StackMetadata] = None,
+ template: Optional[StackTemplate] = None,
+ template_body: Optional[str] = None,
+ ):
+ self.account_id = account_id
+ self.region_name = region_name
+
+ if template is None:
+ template = {}
+
+ self.resolved_outputs = list() # TODO
+ self.resolved_parameters: dict[str, StackParameter] = {}
+ self.resolved_conditions: dict[str, bool] = {}
+
+ self.metadata = metadata or {}
+ self.template = template or {}
+ self.template_body = template_body
+ self._template_raw = clone_safe(self.template)
+ self.template_original = clone_safe(self.template)
+ # initialize resources
+ for resource_id, resource in self.template_resources.items():
+ resource["LogicalResourceId"] = self.template_original["Resources"][resource_id][
+ "LogicalResourceId"
+ ] = resource.get("LogicalResourceId") or resource_id
+ # initialize stack template attributes
+ stack_id = self.metadata.get("StackId") or arns.cloudformation_stack_arn(
+ self.stack_name,
+ stack_id=StackIdentifier(
+ account_id=account_id, region=region_name, stack_name=metadata.get("StackName")
+ ).generate(tags=metadata.get("tags")),
+ account_id=account_id,
+ region_name=region_name,
+ )
+ self.template["StackId"] = self.metadata["StackId"] = stack_id
+ self.template["Parameters"] = self.template.get("Parameters") or {}
+ self.template["Outputs"] = self.template.get("Outputs") or {}
+ self.template["Conditions"] = self.template.get("Conditions") or {}
+ # initialize metadata
+ self.metadata["Parameters"] = self.metadata.get("Parameters") or []
+ self.metadata["StackStatus"] = "CREATE_IN_PROGRESS"
+ self.metadata["CreationTime"] = self.metadata.get("CreationTime") or timestamp_millis()
+ self.metadata["LastUpdatedTime"] = self.metadata["CreationTime"]
+ self.metadata.setdefault("Description", self.template.get("Description"))
+ self.metadata.setdefault("RollbackConfiguration", {})
+ self.metadata.setdefault("DisableRollback", False)
+ self.metadata.setdefault("EnableTerminationProtection", False)
+ # maps resource id to resource state
+ self._resource_states = {}
+ # list of stack events
+ self.events = []
+ # list of stack change sets
+ self.change_sets = []
+ # self.evaluated_conditions = {}
+
+ def set_resolved_parameters(self, resolved_parameters: dict[str, StackParameter]):
+ self.resolved_parameters = resolved_parameters
+ if resolved_parameters:
+ self.metadata["Parameters"] = list(resolved_parameters.values())
+
+ def set_resolved_stack_conditions(self, resolved_conditions: dict[str, bool]):
+ self.resolved_conditions = resolved_conditions
+
+ def describe_details(self):
+ attrs = [
+ "StackId",
+ "StackName",
+ "Description",
+ "StackStatusReason",
+ "StackStatus",
+ "Capabilities",
+ "ParentId",
+ "RootId",
+ "RoleARN",
+ "CreationTime",
+ "DeletionTime",
+ "LastUpdatedTime",
+ "ChangeSetId",
+ "RollbackConfiguration",
+ "DisableRollback",
+ "EnableTerminationProtection",
+ "DriftInformation",
+ ]
+ result = select_attributes(self.metadata, attrs)
+ result["Tags"] = self.tags
+ outputs = self.resolved_outputs
+ if outputs:
+ result["Outputs"] = outputs
+ stack_parameters = convert_stack_parameters_to_list(self.resolved_parameters)
+ if stack_parameters:
+ result["Parameters"] = [
+ mask_no_echo(strip_parameter_type(sp)) for sp in stack_parameters
+ ]
+ if not result.get("DriftInformation"):
+ result["DriftInformation"] = {"StackDriftStatus": "NOT_CHECKED"}
+ for attr in ["Tags", "NotificationARNs"]:
+ result.setdefault(attr, [])
+ return result
+
+ def set_stack_status(self, status: str, status_reason: Optional[str] = None):
+ self.metadata["StackStatus"] = status
+ if "FAILED" in status:
+ self.metadata["StackStatusReason"] = status_reason or "Deployment failed"
+ self.log_stack_errors()
+ self.add_stack_event(
+ self.stack_name, self.stack_id, status, status_reason=status_reason or ""
+ )
+
+ def log_stack_errors(self, level=logging.WARNING):
+ for event in self.events:
+ if event["ResourceStatus"].endswith("FAILED"):
+ if reason := event.get("ResourceStatusReason"):
+ reason = reason.replace("\n", "; ")
+ LOG.log(
+ level,
+ "CFn resource failed to deploy: %s (%s)",
+ event["LogicalResourceId"],
+ reason,
+ )
+ else:
+ LOG.warning("CFn resource failed to deploy: %s", event["LogicalResourceId"])
+
+ def set_time_attribute(self, attribute, new_time=None):
+ self.metadata[attribute] = new_time or timestamp_millis()
+
+ def add_stack_event(
+ self,
+ resource_id: str = None,
+ physical_res_id: str = None,
+ status: str = "",
+ status_reason: str = "",
+ ):
+ resource_id = resource_id or self.stack_name
+ physical_res_id = physical_res_id or self.stack_id
+ resource_type = (
+ self.template.get("Resources", {})
+ .get(resource_id, {})
+ .get("Type", "AWS::CloudFormation::Stack")
+ )
+
+ event = {
+ "EventId": long_uid(),
+ "Timestamp": timestamp_millis(),
+ "StackId": self.stack_id,
+ "StackName": self.stack_name,
+ "LogicalResourceId": resource_id,
+ "PhysicalResourceId": physical_res_id,
+ "ResourceStatus": status,
+ "ResourceType": resource_type,
+ }
+
+ if status_reason:
+ event["ResourceStatusReason"] = status_reason
+
+ self.events.insert(0, event)
+
+ def set_resource_status(self, resource_id: str, status: str, status_reason: str = ""):
+ """Update the deployment status of the given resource ID and publish a corresponding stack event."""
+ physical_res_id = self.resources.get(resource_id, {}).get("PhysicalResourceId")
+ self._set_resource_status_details(resource_id, physical_res_id=physical_res_id)
+ state = self.resource_states.setdefault(resource_id, {})
+ state["PreviousResourceStatus"] = state.get("ResourceStatus")
+ state["ResourceStatus"] = status
+ state["LastUpdatedTimestamp"] = timestamp_millis()
+ self.add_stack_event(resource_id, physical_res_id, status, status_reason=status_reason)
+
+ def _set_resource_status_details(self, resource_id: str, physical_res_id: str = None):
+ """Helper function to ensure that the status details for the given resource ID are up-to-date."""
+ resource = self.resources.get(resource_id)
+ if resource is None or resource.get("Type") == "Parameter":
+ # make sure we delete the states for any non-existing/deleted resources
+ self._resource_states.pop(resource_id, None)
+ return
+ state = self._resource_states.setdefault(resource_id, {})
+ attr_defaults = (
+ ("LogicalResourceId", resource_id),
+ ("PhysicalResourceId", physical_res_id),
+ )
+ for res in [resource, state]:
+ for attr, default in attr_defaults:
+ res[attr] = res.get(attr) or default
+ state["StackName"] = state.get("StackName") or self.stack_name
+ state["StackId"] = state.get("StackId") or self.stack_id
+ state["ResourceType"] = state.get("ResourceType") or self.resources[resource_id].get("Type")
+ state["Timestamp"] = timestamp_millis()
+ return state
+
+ def resource_status(self, resource_id: str):
+ result = self._lookup(self.resource_states, resource_id)
+ return result
+
+ def latest_template_raw(self):
+ if self.change_sets:
+ return self.change_sets[-1]._template_raw
+ return self._template_raw
+
+ @property
+ def resource_states(self):
+ for resource_id in list(self._resource_states.keys()):
+ self._set_resource_status_details(resource_id)
+ return self._resource_states
+
+ @property
+ def stack_name(self):
+ return self.metadata["StackName"]
+
+ @property
+ def stack_id(self):
+ return self.metadata["StackId"]
+
+ @property
+ def resources(self):
+ """Return dict of resources"""
+ return dict(self.template_resources)
+
+ @property
+ def template_resources(self):
+ return self.template.setdefault("Resources", {})
+
+ @property
+ def tags(self):
+ return self.metadata.get("Tags", [])
+
+ @property
+ def imports(self):
+ def _collect(o, **kwargs):
+ if isinstance(o, dict):
+ import_val = o.get("Fn::ImportValue")
+ if import_val:
+ result.add(import_val)
+ return o
+
+ result = set()
+ recurse_object(self.resources, _collect)
+ return result
+
+ @property
+ def template_parameters(self):
+ return self.template["Parameters"]
+
+ @property
+ def conditions(self):
+ """Returns the (mutable) dict of stack conditions."""
+ return self.template.setdefault("Conditions", {})
+
+ @property
+ def mappings(self):
+ """Returns the (mutable) dict of stack mappings."""
+ return self.template.setdefault("Mappings", {})
+
+ @property
+ def outputs(self):
+ """Returns the (mutable) dict of stack outputs."""
+ return self.template.setdefault("Outputs", {})
+
+ @property
+ def status(self):
+ return self.metadata["StackStatus"]
+
+ @property
+ def resource_types(self):
+ return [r.get("Type") for r in self.template_resources.values()]
+
+ def resource(self, resource_id):
+ return self._lookup(self.resources, resource_id)
+
+ def _lookup(self, resource_map, resource_id):
+ resource = resource_map.get(resource_id)
+ if not resource:
+ raise Exception(
+ 'Unable to find details for resource "%s" in stack "%s"'
+ % (resource_id, self.stack_name)
+ )
+ return resource
+
+ def copy(self):
+ return Stack(
+ account_id=self.account_id,
+ region_name=self.region_name,
+ metadata=dict(self.metadata),
+ template=dict(self.template),
+ )
+
+
+# FIXME: remove inheritance
+class StackChangeSet(Stack):
+ def __init__(self, account_id: str, region_name: str, stack: Stack, params=None, template=None):
+ if template is None:
+ template = {}
+ if params is None:
+ params = {}
+ super(StackChangeSet, self).__init__(account_id, region_name, params, template)
+
+ name = self.metadata["ChangeSetName"]
+ if not self.metadata.get("ChangeSetId"):
+ self.metadata["ChangeSetId"] = arns.cloudformation_change_set_arn(
+ name, change_set_id=short_uid(), account_id=account_id, region_name=region_name
+ )
+
+ self.account_id = account_id
+ self.region_name = region_name
+ self.stack = stack
+ self.metadata["StackId"] = stack.stack_id
+ self.metadata["Status"] = "CREATE_PENDING"
+
+ @property
+ def change_set_id(self):
+ return self.metadata["ChangeSetId"]
+
+ @property
+ def change_set_name(self):
+ return self.metadata["ChangeSetName"]
+
+ @property
+ def resources(self):
+ return dict(self.stack.resources)
+
+ @property
+ def changes(self):
+ result = self.metadata["Changes"] = self.metadata.get("Changes", [])
+ return result
diff --git a/localstack-core/localstack/services/cloudformation/engine/errors.py b/localstack-core/localstack/services/cloudformation/engine/errors.py
new file mode 100644
index 0000000000000..0ee44f3530e58
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/errors.py
@@ -0,0 +1,4 @@
+class TemplateError(RuntimeError):
+ """
+ Error thrown on a programming error from the user
+ """
diff --git a/localstack-core/localstack/services/cloudformation/engine/parameters.py b/localstack-core/localstack/services/cloudformation/engine/parameters.py
new file mode 100644
index 0000000000000..ba39fafc40db2
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/parameters.py
@@ -0,0 +1,209 @@
+"""
+TODO: ordering & grouping of parameters
+TODO: design proper structure for parameters to facilitate validation etc.
+TODO: clearer language around both parameters and "resolving"
+
+Documentation extracted from AWS docs (https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/parameters-section-structure.html):
+ The following requirements apply when using parameters:
+
+ You can have a maximum of 200 parameters in an AWS CloudFormation template.
+ Each parameter must be given a logical name (also called logical ID), which must be alphanumeric and unique among all logical names within the template.
+ Each parameter must be assigned a parameter type that is supported by AWS CloudFormation. For more information, see Type.
+ Each parameter must be assigned a value at runtime for AWS CloudFormation to successfully provision the stack. You can optionally specify a default value for AWS CloudFormation to use unless another value is provided.
+ Parameters must be declared and referenced from within the same template. You can reference parameters from the Resources and Outputs sections of the template.
+
+ When you create or update stacks and create change sets, AWS CloudFormation uses whatever values exist in Parameter Store at the time the operation is run. If a specified parameter doesn't exist in Parameter Store under the caller's AWS account, AWS CloudFormation returns a validation error.
+
+ For stack updates, the Use existing value option in the console and the UsePreviousValue attribute for update-stack tell AWS CloudFormation to use the existing Systems Manager parameter key, not its value. AWS CloudFormation always fetches the latest values from Parameter Store when it updates stacks.
+
+"""
+
+import logging
+from typing import Literal, Optional, TypedDict
+
+from botocore.exceptions import ClientError
+
+from localstack.aws.api.cloudformation import Parameter, ParameterDeclaration
+from localstack.aws.connect import connect_to
+
+LOG = logging.getLogger(__name__)
+
+
+def extract_stack_parameter_declarations(template: dict) -> dict[str, ParameterDeclaration]:
+ """
+ Extract and build a dict of stack parameter declarations from a CloudFormation stack template
+
+ :param template: the parsed CloudFormation stack template
+ :return: a dictionary of declared parameters, mapping logical IDs to the corresponding parameter declaration
+ """
+ result = {}
+ for param_key, param in template.get("Parameters", {}).items():
+ result[param_key] = ParameterDeclaration(
+ ParameterKey=param_key,
+ DefaultValue=param.get("Default"),
+ ParameterType=param.get("Type"),
+ NoEcho=param.get("NoEcho", False),
+ # TODO: test & implement rest here
+ # ParameterConstraints=?,
+ # Description=?
+ )
+ return result
+
+
+class StackParameter(Parameter):
+ # we need the type information downstream when actually using the resolved value
+ # e.g. in case of lists so that we know that we should interpret the string as a comma-separated list.
+ ParameterType: str
+
+
+def resolve_parameters(
+ account_id: str,
+ region_name: str,
+ parameter_declarations: dict[str, ParameterDeclaration],
+ new_parameters: dict[str, Parameter],
+ old_parameters: dict[str, Parameter],
+) -> dict[str, StackParameter]:
+ """
+ Resolves stack parameters or raises an exception if any parameter can not be resolved.
+
+ Assumptions:
+ - There are no extra undeclared parameters given (validate before calling this method)
+
+ TODO: is UsePreviousValue=False equivalent to not specifying it, in all situations?
+
+ :param parameter_declarations: The parameter declaration from the (potentially new) template, i.e. the "Parameters" section
+ :param new_parameters: The parameters to resolve
+ :param old_parameters: The old parameters from the previous stack deployment, if available
+ :return: a copy of new_parameters with resolved values
+ """
+ resolved_parameters = dict()
+
+ # populate values for every parameter declared in the template
+ for pm in parameter_declarations.values():
+ pm_key = pm["ParameterKey"]
+ resolved_param = StackParameter(ParameterKey=pm_key, ParameterType=pm["ParameterType"])
+ new_parameter = new_parameters.get(pm_key)
+ old_parameter = old_parameters.get(pm_key)
+
+ if new_parameter is None:
+ # since no value has been specified for the deployment, we need to be able to resolve the default or fail
+ default_value = pm["DefaultValue"]
+ if default_value is None:
+ LOG.error("New parameter without a default value: %s", pm_key)
+ raise Exception(
+ f"Invalid. Parameter '{pm_key}' needs to have either param specified or Default."
+ ) # TODO: test and verify
+
+ resolved_param["ParameterValue"] = default_value
+ else:
+ if (
+ new_parameter.get("UsePreviousValue", False)
+ and new_parameter.get("ParameterValue") is not None
+ ):
+ raise Exception(
+ f"Can't set both 'UsePreviousValue' and a concrete value for parameter '{pm_key}'."
+ ) # TODO: test and verify
+
+ if new_parameter.get("UsePreviousValue", False):
+ if old_parameter is None:
+ raise Exception(
+ f"Set 'UsePreviousValue' but stack has no previous value for parameter '{pm_key}'."
+ ) # TODO: test and verify
+
+ resolved_param["ParameterValue"] = old_parameter["ParameterValue"]
+ else:
+ resolved_param["ParameterValue"] = new_parameter["ParameterValue"]
+
+ resolved_param["NoEcho"] = pm.get("NoEcho", False)
+ resolved_parameters[pm_key] = resolved_param
+
+ # Note that SSM parameters always need to be resolved anew here
+ # TODO: support more parameter types
+ if pm["ParameterType"].startswith("AWS::SSM"):
+ if pm["ParameterType"] in [
+ "AWS::SSM::Parameter::Value",
+ "AWS::SSM::Parameter::Value",
+ "AWS::SSM::Parameter::Value",
+ ]:
+ # TODO: error handling (e.g. no permission to lookup SSM parameter or SSM parameter doesn't exist)
+ resolved_param["ResolvedValue"] = resolve_ssm_parameter(
+ account_id, region_name, resolved_param["ParameterValue"]
+ )
+ else:
+ raise Exception(f"Unsupported stack parameter type: {pm['ParameterType']}")
+
+ return resolved_parameters
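+
+
+# Illustrative example (hypothetical values): given a declaration for "Stage" with Default "dev",
+# - new_parameters={} resolves "Stage" to its default "dev",
+# - new_parameters={"Stage": {"ParameterKey": "Stage", "UsePreviousValue": True}} resolves "Stage"
+#   to the value found in old_parameters, and fails if no previous value exists.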
+
+
+# TODO: inject credentials / client factory for proper account/region lookup
+def resolve_ssm_parameter(account_id: str, region_name: str, stack_parameter_value: str) -> str:
+ """
+ Resolve the SSM stack parameter from the SSM service with a name equal to the stack parameter value.
+ """
+ ssm_client = connect_to(aws_access_key_id=account_id, region_name=region_name).ssm
+ try:
+ return ssm_client.get_parameter(Name=stack_parameter_value)["Parameter"]["Value"]
+ except ClientError:
+ LOG.error("client error fetching parameter '%s'", stack_parameter_value)
+ raise
+
+
+def strip_parameter_type(in_param: StackParameter) -> Parameter:
+ result = in_param.copy()
+ result.pop("ParameterType", None)
+ return result
+
+
+def mask_no_echo(in_param: StackParameter) -> Parameter:
+ result = in_param.copy()
+ no_echo = result.pop("NoEcho", False)
+ if no_echo:
+ result["ParameterValue"] = "****"
+ return result
+
+
+def convert_stack_parameters_to_list(
+ in_params: dict[str, StackParameter] | None,
+) -> list[StackParameter]:
+ if not in_params:
+ return []
+ return list(in_params.values())
+
+
+def convert_stack_parameters_to_dict(in_params: list[Parameter] | None) -> dict[str, Parameter]:
+ if not in_params:
+ return {}
+ return {p["ParameterKey"]: p for p in in_params}
+
+
+class LegacyParameterProperties(TypedDict):
+ Value: str
+ ParameterType: str
+ ParameterValue: Optional[str]
+ ResolvedValue: Optional[str]
+
+
+class LegacyParameter(TypedDict):
+ LogicalResourceId: str
+ Type: Literal["Parameter"]
+ Properties: LegacyParameterProperties
+
+
+# TODO: not actually parameter_type but the logical "ID"
+def map_to_legacy_structure(parameter_name: str, new_parameter: StackParameter) -> LegacyParameter:
+ """
+ Helper util to convert a normal (resolved) stack parameter to a legacy parameter structure that can then be merged with stack resources.
+
+ :param new_parameter: a resolved stack parameter
+ :return: legacy parameter that can be merged with stack resources for uniform lookup based on logical ID
+ """
+ return LegacyParameter(
+ LogicalResourceId=new_parameter["ParameterKey"],
+ Type="Parameter",
+ Properties=LegacyParameterProperties(
+ ParameterType=new_parameter.get("ParameterType"),
+ ParameterValue=new_parameter.get("ParameterValue"),
+ ResolvedValue=new_parameter.get("ResolvedValue"),
+ Value=new_parameter.get("ResolvedValue", new_parameter.get("ParameterValue")),
+ ),
+ )
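+
+
+# Illustrative example (hypothetical parameter): a resolved StackParameter
+# {"ParameterKey": "Stage", "ParameterType": "String", "ParameterValue": "dev"}
+# maps to {"LogicalResourceId": "Stage", "Type": "Parameter", "Properties":
+# {"ParameterType": "String", "ParameterValue": "dev", "ResolvedValue": None, "Value": "dev"}}.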
diff --git a/localstack-core/localstack/services/cloudformation/engine/policy_loader.py b/localstack-core/localstack/services/cloudformation/engine/policy_loader.py
new file mode 100644
index 0000000000000..8f3d11be79244
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/policy_loader.py
@@ -0,0 +1,18 @@
+import logging
+
+from samtranslator.translator.managed_policy_translator import ManagedPolicyLoader
+
+from localstack.aws.connect import connect_to
+
+LOG = logging.getLogger(__name__)
+
+
+policy_loader = None
+
+
+def create_policy_loader() -> ManagedPolicyLoader:
+ global policy_loader
+ if not policy_loader:
+ iam_client = connect_to().iam
+ policy_loader = ManagedPolicyLoader(iam_client=iam_client)
+ return policy_loader
diff --git a/localstack-core/localstack/services/cloudformation/engine/quirks.py b/localstack-core/localstack/services/cloudformation/engine/quirks.py
new file mode 100644
index 0000000000000..b38056474b560
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/quirks.py
@@ -0,0 +1,65 @@
+"""
+We can't always automatically determine which value serves as the physical resource ID.
+=> This needs to be determined manually by testing against AWS (!)
+
+There's also a reason that the mapping is located here instead of closer to the resource providers themselves.
+If the resources were compliant with the generic resource provider framework that AWS offers for custom resource types, we wouldn't need this.
+For legacy resources (and even some of the open-sourced ones), AWS still has a layer of "secret sauce" that defines what the actual physical resource ID is.
+An extension schema only defines the primary identifiers, not the physical resource ID that is generated based on them.
+Since this is part of the CloudFormation layer and *not* the resource provider's responsibility, we've put the mapping closer to the CloudFormation engine.
+"""
+
+# note: format here is subject to change (e.g. it might not be a pure str -> str mapping, it could also involve more sophisticated handlers)
+PHYSICAL_RESOURCE_ID_SPECIAL_CASES = {
+ "AWS::ApiGateway::Authorizer": "/properties/AuthorizerId",
+ "AWS::ApiGateway::RequestValidator": "/properties/RequestValidatorId",
+ "AWS::ApiGatewayV2::Authorizer": "/properties/AuthorizerId",
+ "AWS::ApiGatewayV2::Deployment": "/properties/DeploymentId",
+ "AWS::ApiGatewayV2::IntegrationResponse": "/properties/IntegrationResponseId",
+ "AWS::ApiGatewayV2::Route": "/properties/RouteId",
+ "AWS::ApiGateway::BasePathMapping": "/properties/RestApiId",
+ "AWS::ApiGateway::Deployment": "/properties/DeploymentId",
+ "AWS::ApiGateway::Model": "/properties/Name",
+ "AWS::ApiGateway::Resource": "/properties/ResourceId",
+ "AWS::ApiGateway::Stage": "/properties/StageName",
+ "AWS::Cognito::UserPoolClient": "/properties/ClientId",
+ "AWS::ECS::Service": "/properties/ServiceArn",
+ "AWS::EKS::FargateProfile": "|", # composite
+ "AWS::Events::EventBus": "/properties/Name",
+ "AWS::Logs::LogStream": "/properties/LogStreamName",
+ "AWS::Logs::SubscriptionFilter": "/properties/LogGroupName",
+ "AWS::RDS::DBProxyTargetGroup": "/properties/TargetGroupName",
+ "AWS::Glue::SchemaVersionMetadata": "||", # composite
+ "AWS::WAFv2::WebACL": "||",
+ "AWS::WAFv2::WebACLAssociation": "|",
+ "AWS::WAFv2::IPSet": "||",
+ # composite
+}
+
+# You can usually find the available GetAtt targets in the official resource documentation:
+# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-template-resource-type-ref.html
+# Use the scaffolded exploration test to verify against AWS which attributes you can access.
+# This mapping is not in use yet (!)
+VALID_GETATT_PROPERTIES = {
+ # Other Examples
+ # "AWS::ApiGateway::Resource": ["ResourceId"],
+ # "AWS::IAM::User": ["Arn"], # TODO: not validated yet
+ "AWS::SSM::Parameter": ["Type", "Value"], # TODO: not validated yet
+ # "AWS::OpenSearchService::Domain": [
+ # "AdvancedSecurityOptions.AnonymousAuthDisableDate",
+ # "Arn",
+ # "DomainArn",
+ # "DomainEndpoint",
+ # "DomainEndpoints",
+ # "Id",
+ # "ServiceSoftwareOptions",
+ # "ServiceSoftwareOptions.AutomatedUpdateDate",
+ # "ServiceSoftwareOptions.Cancellable",
+ # "ServiceSoftwareOptions.CurrentVersion",
+ # "ServiceSoftwareOptions.Description",
+ # "ServiceSoftwareOptions.NewVersion",
+ # "ServiceSoftwareOptions.OptionalDeployment",
+ # "ServiceSoftwareOptions.UpdateAvailable",
+ # "ServiceSoftwareOptions.UpdateStatus",
+ # ], # TODO: not validated yet
+}
diff --git a/localstack-core/localstack/services/cloudformation/engine/resource_ordering.py b/localstack-core/localstack/services/cloudformation/engine/resource_ordering.py
new file mode 100644
index 0000000000000..53eaf4d9279c6
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/resource_ordering.py
@@ -0,0 +1,109 @@
+from collections import OrderedDict
+
+from localstack.services.cloudformation.engine.changes import ChangeConfig
+from localstack.services.cloudformation.engine.parameters import StackParameter
+from localstack.services.cloudformation.engine.template_utils import get_deps_for_resource
+
+
+class NoResourceInStack(ValueError):
+ """Raised when we preprocess the template and do not find a resource"""
+
+ def __init__(self, logical_resource_id: str):
+ msg = f"Template format error: Unresolved resource dependencies [{logical_resource_id}] in the Resources block of the template"
+
+ super().__init__(msg)
+
+
+def order_resources(
+ resources: dict,
+ resolved_parameters: dict[str, StackParameter],
+ resolved_conditions: dict[str, bool],
+ reverse: bool = False,
+) -> OrderedDict:
+ """
+ Given a dictionary of resources, topologically sort the resources based on
+ inter-resource dependencies (e.g. usages of intrinsic functions).
+ """
+ nodes: dict[str, list[str]] = {}
+ for logical_resource_id, properties in resources.items():
+ nodes.setdefault(logical_resource_id, [])
+ deps = get_deps_for_resource(properties, resolved_conditions)
+ for dep in deps:
+ if dep in resolved_parameters:
+ # we only care about other resources
+ continue
+ nodes.setdefault(dep, [])
+ nodes[dep].append(logical_resource_id)
+
+ # implementation from https://dev.to/leopfeiffer/topological-sort-with-kahns-algorithm-3dl1
+ indegrees = {k: 0 for k in nodes.keys()}
+ for dependencies in nodes.values():
+ for dependency in dependencies:
+ indegrees[dependency] += 1
+
+ # Place all elements with indegree 0 in queue
+ queue = [k for k in nodes.keys() if indegrees[k] == 0]
+
+ sorted_logical_resource_ids = []
+
+ # Continue until all nodes have been dealt with
+ while len(queue) > 0:
+ # node of current iteration is the first one from the queue
+ curr = queue.pop(0)
+ sorted_logical_resource_ids.append(curr)
+
+ # remove the current node from other dependencies
+ for dependency in nodes[curr]:
+ indegrees[dependency] -= 1
+
+ if indegrees[dependency] == 0:
+ queue.append(dependency)
+
+ # check for circular dependencies
+ if len(sorted_logical_resource_ids) != len(nodes):
+ raise Exception("Circular dependency found.")
+
+ sorted_mapping = []
+ for logical_resource_id in sorted_logical_resource_ids:
+ if properties := resources.get(logical_resource_id):
+ sorted_mapping.append((logical_resource_id, properties))
+ else:
+ if (
+ logical_resource_id not in resolved_parameters
+ and logical_resource_id not in resolved_conditions
+ ):
+ raise NoResourceInStack(logical_resource_id)
+
+ if reverse:
+ sorted_mapping = sorted_mapping[::-1]
+ return OrderedDict(sorted_mapping)
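+
+
+# Illustrative example (hypothetical resources): for
+# {"Bucket": {...}, "Policy": {"Properties": {"Bucket": {"Ref": "Bucket"}}}}
+# "Policy" depends on "Bucket", so the resulting order is ["Bucket", "Policy"]
+# (and ["Policy", "Bucket"] with reverse=True, as used when deleting a stack).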
+
+
+def order_changes(
+ given_changes: list[ChangeConfig],
+ resources: dict,
+ resolved_parameters: dict[str, StackParameter],
+ # TODO: remove resolved conditions somehow
+ resolved_conditions: dict[str, bool],
+ reverse: bool = False,
+) -> list[ChangeConfig]:
+ """
+ Given a list of changes, a dictionary of resources and a dictionary of resolved conditions, topologically sort the
+ changes based on inter-resource dependencies (e.g. usages of intrinsic functions).
+ """
+ ordered_resources = order_resources(
+ resources=resources,
+ resolved_parameters=resolved_parameters,
+ resolved_conditions=resolved_conditions,
+ reverse=reverse,
+ )
+ sorted_changes = []
+ for logical_resource_id in ordered_resources.keys():
+ for change in given_changes:
+ if change["ResourceChange"]["LogicalResourceId"] == logical_resource_id:
+ sorted_changes.append(change)
+ break
+ assert len(sorted_changes) > 0
+ if reverse:
+ sorted_changes = sorted_changes[::-1]
+ return sorted_changes
diff --git a/localstack-core/localstack/services/cloudformation/engine/schema.py b/localstack-core/localstack/services/cloudformation/engine/schema.py
new file mode 100644
index 0000000000000..1a8e3d0a9d402
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/schema.py
@@ -0,0 +1,15 @@
+import json
+import os
+import zipfile
+
+
+# TODO: unify with scaffolding
+class SchemaProvider:
+ def __init__(self, zipfile_path: str | os.PathLike[str]):
+ self.schemas = {}
+ with zipfile.ZipFile(zipfile_path) as infile:
+ for filename in infile.namelist():
+ with infile.open(filename) as schema_file:
+ schema = json.load(schema_file)
+ typename = schema["typeName"]
+ self.schemas[typename] = schema
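+
+
+# Illustrative usage (hypothetical zip path): SchemaProvider("provider_schemas.zip").schemas maps each
+# "typeName" found in the zipped JSON schema files (e.g. "AWS::S3::Bucket") to its parsed schema.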
diff --git a/localstack-core/localstack/services/cloudformation/engine/template_deployer.py b/localstack-core/localstack/services/cloudformation/engine/template_deployer.py
new file mode 100644
index 0000000000000..5a451e5171a73
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/template_deployer.py
@@ -0,0 +1,1616 @@
+import base64
+import json
+import logging
+import re
+import traceback
+import uuid
+from typing import Optional
+
+from botocore.exceptions import ClientError
+
+from localstack import config
+from localstack.aws.connect import connect_to
+from localstack.constants import INTERNAL_AWS_SECRET_ACCESS_KEY
+from localstack.services.cloudformation.deployment_utils import (
+ PLACEHOLDER_AWS_NO_VALUE,
+ get_action_name_for_resource_change,
+ remove_none_values,
+)
+from localstack.services.cloudformation.engine.changes import ChangeConfig, ResourceChange
+from localstack.services.cloudformation.engine.entities import Stack, StackChangeSet
+from localstack.services.cloudformation.engine.parameters import StackParameter
+from localstack.services.cloudformation.engine.quirks import VALID_GETATT_PROPERTIES
+from localstack.services.cloudformation.engine.resource_ordering import (
+ order_changes,
+ order_resources,
+)
+from localstack.services.cloudformation.engine.template_utils import (
+ AWS_URL_SUFFIX,
+ fn_equals_type_conversion,
+ get_deps_for_resource,
+)
+from localstack.services.cloudformation.resource_provider import (
+ Credentials,
+ OperationStatus,
+ ProgressEvent,
+ ResourceProviderExecutor,
+ ResourceProviderPayload,
+ get_resource_type,
+)
+from localstack.services.cloudformation.service_models import (
+ DependencyNotYetSatisfied,
+)
+from localstack.services.cloudformation.stores import exports_map, find_stack
+from localstack.utils.aws.arns import get_partition
+from localstack.utils.functions import prevent_stack_overflow
+from localstack.utils.json import clone_safe
+from localstack.utils.strings import to_bytes, to_str
+from localstack.utils.threads import start_worker_thread
+
+from localstack.services.cloudformation.models import * # noqa: F401, F403, isort:skip
+from localstack.utils.urls import localstack_host
+
+ACTION_CREATE = "create"
+ACTION_DELETE = "delete"
+
+REGEX_OUTPUT_APIGATEWAY = re.compile(
+ rf"^(https?://.+\.execute-api\.)(?:[^-]+-){{2,3}}\d\.(amazonaws\.com|{AWS_URL_SUFFIX})/?(.*)$"
+)
+REGEX_DYNAMIC_REF = re.compile("{{resolve:([^:]+):(.+)}}")
+
+LOG = logging.getLogger(__name__)
+
+# list of static attribute references to be replaced in {'Fn::Sub': '...'} strings
+STATIC_REFS = ["AWS::Region", "AWS::Partition", "AWS::StackName", "AWS::AccountId"]
+
+# Mock value for unsupported type references
+MOCK_REFERENCE = "unknown"
+
+
+class NoStackUpdates(Exception):
+ """Exception indicating that no actions are to be performed in a stack update (which is not allowed)"""
+
+ pass
+
+
+# ---------------------
+# CF TEMPLATE HANDLING
+# ---------------------
+
+
+def get_attr_from_model_instance(
+ resource: dict,
+ attribute_name: str,
+ resource_type: str,
+ resource_id: str,
+ attribute_sub_name: Optional[str] = None,
+) -> str:
+ if resource["PhysicalResourceId"] == MOCK_REFERENCE:
+ LOG.warning(
+ "Attribute '%s' requested from unsupported resource with id %s",
+ attribute_name,
+ resource_id,
+ )
+ return MOCK_REFERENCE
+
+ properties = resource.get("Properties", {})
+ # if there's no entry in VALID_GETATT_PROPERTIES for the resource type we still default to "open" and accept anything
+ valid_atts = VALID_GETATT_PROPERTIES.get(resource_type)
+ if valid_atts is not None and attribute_name not in valid_atts:
+ LOG.warning(
+ "Invalid attribute in Fn::GetAtt for %s: | %s.%s",
+ resource_type,
+ resource_id,
+ attribute_name,
+ )
+ raise Exception(
+ f"Resource type {resource_type} does not support attribute {{{attribute_name}}}"
+ ) # TODO: check CFn behavior via snapshot
+
+ attribute_candidate = properties.get(attribute_name)
+ if attribute_sub_name:
+ return attribute_candidate.get(attribute_sub_name)
+ if "." in attribute_name:
+ # was used for legacy, but keeping it since it might have to work for a custom resource as well
+ if attribute_candidate:
+ return attribute_candidate
+
+ # some resources (e.g. ElastiCache) have their readOnly attributes defined as Aa.Bb but the property is named AaBb
+ if attribute_candidate := properties.get(attribute_name.replace(".", "")):
+ return attribute_candidate
+
+ # accessing nested properties
+ parts = attribute_name.split(".")
+ attribute = properties
+ # TODO: the attribute fetching below is a temporary workaround for the dependency resolution.
+ # It is caused by trying to access the resource attribute that has not been deployed yet.
+ # This should be a hard error.
+ for part in parts:
+ if attribute is None:
+ return None
+ attribute = attribute.get(part)
+ return attribute
+
+ # If we couldn't find the attribute, this is actually an irrecoverable error.
+ # After the resource has a state of CREATE_COMPLETE, all attributes should already be set.
+ # TODO: raise here instead
+ # if attribute_candidate is None:
+ # raise Exception(
+ # f"Failed to resolve attribute for Fn::GetAtt in {resource_type}: {resource_id}.{attribute_name}"
+ # ) # TODO: check CFn behavior via snapshot
+ return attribute_candidate
+
+
+def resolve_ref(
+ account_id: str,
+ region_name: str,
+ stack_name: str,
+ resources: dict,
+ parameters: dict[str, StackParameter],
+ ref: str,
+):
+ """
+ ref always needs to be a static string
+ ref can be one of these:
+ 1. a pseudo-parameter (e.g. AWS::Region)
+ 2. a parameter
+ 3. the id of a resource (PhysicalResourceId)
+ """
+ # pseudo parameter
+ if ref == "AWS::Region":
+ return region_name
+ if ref == "AWS::Partition":
+ return get_partition(region_name)
+ if ref == "AWS::StackName":
+ return stack_name
+ if ref == "AWS::StackId":
+ stack = find_stack(account_id, region_name, stack_name)
+ if not stack:
+ raise ValueError(f"No stack {stack_name} found")
+ return stack.stack_id
+ if ref == "AWS::AccountId":
+ return account_id
+ if ref == "AWS::NoValue":
+ return PLACEHOLDER_AWS_NO_VALUE
+ if ref == "AWS::NotificationARNs":
+ # TODO!
+ return {}
+ if ref == "AWS::URLSuffix":
+ return AWS_URL_SUFFIX
+
+ # parameter
+ if parameter := parameters.get(ref):
+ parameter_type: str = parameter["ParameterType"]
+ parameter_value = parameter.get("ResolvedValue") or parameter.get("ParameterValue")
+
+ if "CommaDelimitedList" in parameter_type or parameter_type.startswith("List<"):
+ return [p.strip() for p in parameter_value.split(",")]
+ else:
+ return parameter_value
+
+ # resource
+ resource = resources.get(ref)
+ if not resource:
+ raise Exception(
+ f"Resource target for `Ref {ref}` could not be found. Is there a resource with name {ref} in your stack?"
+ )
+
+ return resources[ref].get("PhysicalResourceId")
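+
+
+# Illustrative example (hypothetical stack): resolve_ref(..., ref="AWS::Region") returns the region
+# name, ref="Stage" returns the resolved value of the "Stage" parameter, and ref="Queue" returns the
+# PhysicalResourceId recorded for the "Queue" resource.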
+
+
+# Using a @prevent_stack_overflow decorator here to avoid infinite recursion
+# in case we load stack exports that have circular dependencies (see issue 3438)
+# TODO: Potentially think about a better approach in the future
+@prevent_stack_overflow(match_parameters=True)
+def resolve_refs_recursively(
+ account_id: str,
+ region_name: str,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict[str, bool],
+ parameters: dict,
+ value,
+):
+ result = _resolve_refs_recursively(
+ account_id, region_name, stack_name, resources, mappings, conditions, parameters, value
+ )
+
+ # localstack specific patches
+ if isinstance(result, str):
+ # we're trying to filter constructed API urls here (e.g. via Join in the template)
+ api_match = REGEX_OUTPUT_APIGATEWAY.match(result)
+ if api_match and result in config.CFN_STRING_REPLACEMENT_DENY_LIST:
+ return result
+ elif api_match:
+ prefix = api_match[1]
+ host = api_match[2]
+ path = api_match[3]
+ port = localstack_host().port
+ return f"{prefix}{host}:{port}/{path}"
+
+ # basic dynamic reference support
+ # see: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/dynamic-references.html
+ # technically there are more restrictions for each of these services but checking each of these
+ # isn't really necessary for the current level of emulation
+ dynamic_ref_match = REGEX_DYNAMIC_REF.match(result)
+ if dynamic_ref_match:
+ service_name = dynamic_ref_match[1]
+ reference_key = dynamic_ref_match[2]
+
+ # only these 3 services are supported for dynamic references right now
+ if service_name == "ssm":
+ ssm_client = connect_to(aws_access_key_id=account_id, region_name=region_name).ssm
+ try:
+ return ssm_client.get_parameter(Name=reference_key)["Parameter"]["Value"]
+ except ClientError as e:
+ LOG.error("client error accessing SSM parameter '%s': %s", reference_key, e)
+ raise
+ elif service_name == "ssm-secure":
+ ssm_client = connect_to(aws_access_key_id=account_id, region_name=region_name).ssm
+ try:
+ return ssm_client.get_parameter(Name=reference_key, WithDecryption=True)[
+ "Parameter"
+ ]["Value"]
+ except ClientError as e:
+ LOG.error("client error accessing SSM parameter '%s': %s", reference_key, e)
+ raise
+ elif service_name == "secretsmanager":
+ # reference key needs to be parsed further
+ # because {{resolve:secretsmanager:secret-id:secret-string:json-key:version-stage:version-id}}
+ # we match for "secret-id:secret-string:json-key:version-stage:version-id"
+ # where
+ # secret-id can either be the secret name or the full ARN of the secret
+ # secret-string *must* be SecretString
+ # all other values are optional
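+ # Illustrative example (hypothetical secret): the reference key
+ # "my-secret:SecretString:password" yields secret_id="my-secret" and json_key="password",
+ # while version_stage and version_id stay empty.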
+ secret_id = reference_key
+ [json_key, version_stage, version_id] = [None, None, None]
+ if "SecretString" in reference_key:
+ parts = reference_key.split(":SecretString:")
+ secret_id = parts[0]
+ # json-key, version-stage and version-id are optional.
+ [json_key, version_stage, version_id] = f"{parts[1]}::".split(":")[:3]
+
+ kwargs = {} # optional args for get_secret_value
+ if version_id:
+ kwargs["VersionId"] = version_id
+ if version_stage:
+ kwargs["VersionStage"] = version_stage
+
+ secretsmanager_client = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).secretsmanager
+ try:
+ secret_value = secretsmanager_client.get_secret_value(
+ SecretId=secret_id, **kwargs
+ )["SecretString"]
+ except ClientError:
+ LOG.error("client error while trying to access key '%s': %s", secret_id)
+ raise
+
+ if json_key:
+ json_secret = json.loads(secret_value)
+ if json_key not in json_secret:
+ raise DependencyNotYetSatisfied(
+ resource_ids=secret_id,
+ message=f"Key {json_key} is not yet available in secret {secret_id}.",
+ )
+ return json_secret[json_key]
+ else:
+ return secret_value
+ else:
+ LOG.warning(
+ "Unsupported service for dynamic parameter: service_name=%s", service_name
+ )
+
+ return result
+
+
+@prevent_stack_overflow(match_parameters=True)
+def _resolve_refs_recursively(
+ account_id: str,
+ region_name: str,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict,
+ parameters: dict,
+ value: dict | list | str | bytes | None,
+):
+ if isinstance(value, dict):
+ keys_list = list(value.keys())
+ stripped_fn_lower = keys_list[0].lower().split("::")[-1] if len(keys_list) == 1 else None
+
+ # process special operators
+ if keys_list == ["Ref"]:
+ ref = resolve_ref(
+ account_id, region_name, stack_name, resources, parameters, value["Ref"]
+ )
+ if ref is None:
+ msg = 'Unable to resolve Ref for resource "%s" (yet)' % value["Ref"]
+ LOG.debug("%s - %s", msg, resources.get(value["Ref"]) or set(resources.keys()))
+
+ raise DependencyNotYetSatisfied(resource_ids=value["Ref"], message=msg)
+
+ ref = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ ref,
+ )
+ return ref
+
+ if stripped_fn_lower == "getatt":
+ attr_ref = value[keys_list[0]]
+ attr_ref = attr_ref.split(".") if isinstance(attr_ref, str) else attr_ref
+ resource_logical_id = attr_ref[0]
+ attribute_name = attr_ref[1]
+ attribute_sub_name = attr_ref[2] if len(attr_ref) > 2 else None
+
+ # the attribute name can be a Ref
+ attribute_name = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ attribute_name,
+ )
+ resource = resources.get(resource_logical_id)
+
+ resource_type = get_resource_type(resource)
+ resolved_getatt = get_attr_from_model_instance(
+ resource,
+ attribute_name,
+ resource_type,
+ resource_logical_id,
+ attribute_sub_name,
+ )
+
+ # TODO: we should check the deployment state and not try to GetAtt from a resource that is still IN_PROGRESS or hasn't started yet.
+ if resolved_getatt is None:
+ raise DependencyNotYetSatisfied(
+ resource_ids=resource_logical_id,
+ message=f"Could not resolve attribute '{attribute_name}' on resource '{resource_logical_id}'",
+ )
+
+ return resolved_getatt
+
+ if stripped_fn_lower == "join":
+ join_values = value[keys_list[0]][1]
+
+ # this can actually be another ref that produces a list as output
+ if isinstance(join_values, dict):
+ join_values = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ join_values,
+ )
+
+ join_values = [
+ resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ v,
+ )
+ for v in join_values
+ ]
+
+ none_values = [v for v in join_values if v is None]
+ if none_values:
+ LOG.warning(
+ "Cannot resolve Fn::Join '%s' due to null values: '%s'", value, join_values
+ )
+ raise Exception(
+ f"Cannot resolve CF Fn::Join {value} due to null values: {join_values}"
+ )
+ return value[keys_list[0]][0].join(
+ [str(v) for v in join_values if v != "__aws_no_value__"]
+ )
+
+ if stripped_fn_lower == "sub":
+ item_to_sub = value[keys_list[0]]
+
+ attr_refs = {r: {"Ref": r} for r in STATIC_REFS}
+ if not isinstance(item_to_sub, list):
+ item_to_sub = [item_to_sub, {}]
+ result = item_to_sub[0]
+ item_to_sub[1].update(attr_refs)
+
+ for key, val in item_to_sub[1].items():
+ resolved_val = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ val,
+ )
+
+ if isinstance(resolved_val, (list, dict, tuple)):
+ # We don't have access to the resource that's a dependency in this case,
+ # so do the best we can with the resource ids
+ raise DependencyNotYetSatisfied(
+ resource_ids=key, message=f"Could not resolve {val} to terminal value type"
+ )
+ result = result.replace("${%s}" % key, str(resolved_val))
+
+ # resolve placeholders
+ result = resolve_placeholders_in_string(
+ account_id,
+ region_name,
+ result,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ )
+ return result
+
+ if stripped_fn_lower == "findinmap":
+ # "Fn::FindInMap"
+ mapping_id = value[keys_list[0]][0]
+
+ if isinstance(mapping_id, dict) and "Ref" in mapping_id:
+ # TODO: ??
+ mapping_id = resolve_ref(
+ account_id, region_name, stack_name, resources, parameters, mapping_id["Ref"]
+ )
+
+ selected_map = mappings.get(mapping_id)
+ if not selected_map:
+ raise Exception(
+ f"Cannot find Mapping with ID {mapping_id} for Fn::FindInMap: {value[keys_list[0]]} {list(resources.keys())}"
+ # TODO: verify
+ )
+
+ first_level_attribute = value[keys_list[0]][1]
+ first_level_attribute = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ first_level_attribute,
+ )
+
+ if first_level_attribute not in selected_map:
+ raise Exception(
+ f"Cannot find map key '{first_level_attribute}' in mapping '{mapping_id}'"
+ )
+ first_level_mapping = selected_map[first_level_attribute]
+
+ second_level_attribute = value[keys_list[0]][2]
+ if not isinstance(second_level_attribute, str):
+ second_level_attribute = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ second_level_attribute,
+ )
+ if second_level_attribute not in first_level_mapping:
+ raise Exception(
+ f"Cannot find map key '{second_level_attribute}' in mapping '{mapping_id}' under key '{first_level_attribute}'"
+ )
+
+ return first_level_mapping[second_level_attribute]
+
+ if stripped_fn_lower == "importvalue":
+ import_value_key = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ value[keys_list[0]],
+ )
+ exports = exports_map(account_id, region_name)
+ stack_export = exports.get(import_value_key) or {}
+ if not stack_export.get("Value"):
+ LOG.info(
+ 'Unable to find export "%s" in stack "%s", existing export names: %s',
+ import_value_key,
+ stack_name,
+ list(exports.keys()),
+ )
+ return None
+ return stack_export["Value"]
+
+ if stripped_fn_lower == "if":
+ condition_name, option1, option2 = value[keys_list[0]]
+ condition = conditions.get(condition_name)
+ if condition is None:
+ LOG.warning(
+ "Cannot find condition '%s' in conditions mapping: '%s'",
+ condition_name,
+ conditions.keys(),
+ )
+ raise KeyError(
+ f"Cannot find condition '{condition_name}' in conditions mapping: '{conditions.keys()}'"
+ )
+
+ result = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ option1 if condition else option2,
+ )
+ return result
+
+ if stripped_fn_lower == "condition":
+ # FIXME: this should only allow strings, no evaluation should be performed here
+ # see https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/intrinsic-function-reference-condition.html
+ key = value[keys_list[0]]
+ result = conditions.get(key)
+ if result is None:
+ LOG.warning("Cannot find key '%s' in conditions: '%s'", key, conditions.keys())
+ raise KeyError(f"Cannot find key '{key}' in conditions: '{conditions.keys()}'")
+ return result
+
+ if stripped_fn_lower == "not":
+ condition = value[keys_list[0]][0]
+ condition = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ condition,
+ )
+ return not condition
+
+ if stripped_fn_lower in ["and", "or"]:
+ operands = value[keys_list[0]]
+ results = [
+ resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ cond,
+ )
+ for cond in operands
+ ]
+ result = all(results) if stripped_fn_lower == "and" else any(results)
+ return result
+
+ if stripped_fn_lower == "equals":
+ operand1, operand2 = value[keys_list[0]]
+ operand1 = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ operand1,
+ )
+ operand2 = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ operand2,
+ )
+ # TODO: investigate type coercion here
+ return fn_equals_type_conversion(operand1) == fn_equals_type_conversion(operand2)
+
+ if stripped_fn_lower == "select":
+ index, values = value[keys_list[0]]
+ index = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ index,
+ )
+ values = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ values,
+ )
+ try:
+ return values[index]
+ except TypeError:
+ return values[int(index)]
+
+ if stripped_fn_lower == "split":
+ delimiter, string = value[keys_list[0]]
+ delimiter = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ delimiter,
+ )
+ string = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ string,
+ )
+ return string.split(delimiter)
+
+ if stripped_fn_lower == "getazs":
+ region = (
+ resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ value["Fn::GetAZs"],
+ )
+ or region_name
+ )
+
+ ec2_client = connect_to(aws_access_key_id=account_id, region_name=region).ec2
+ try:
+ get_availability_zones = ec2_client.describe_availability_zones()[
+ "AvailabilityZones"
+ ]
+ except ClientError:
+ LOG.error("client error describing availability zones")
+ raise
+
+ azs = [az["ZoneName"] for az in get_availability_zones]
+
+ return azs
+
+ if stripped_fn_lower == "base64":
+ value_to_encode = value[keys_list[0]]
+ value_to_encode = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ value_to_encode,
+ )
+ return to_str(base64.b64encode(to_bytes(value_to_encode)))
+
+ for key, val in dict(value).items():
+ value[key] = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ val,
+ )
+
+ if isinstance(value, list):
+ # in some cases, intrinsic functions are passed in as, e.g., `[['Fn::Sub', '${MyRef}']]`
+ if len(value) == 1 and isinstance(value[0], list) and len(value[0]) == 2:
+ inner_list = value[0]
+ if str(inner_list[0]).lower().startswith("fn::"):
+ return resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ {inner_list[0]: inner_list[1]},
+ )
+
+ for i in range(len(value)):
+ value[i] = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ value[i],
+ )
+
+ return value
+
+
+def resolve_placeholders_in_string(
+ account_id: str,
+ region_name: str,
+ result,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict[str, bool],
+ parameters: dict,
+):
+ """
+ Resolve individual Fn::Sub variable replacements
+
+ Variables can be template parameter names, resource logical IDs, resource attributes, or a variable in a key-value map
+ """
+
+ def _validate_result_type(value: str):
+ is_another_account_id = value.isdigit() and len(value) == len(account_id)
+ if value == account_id or is_another_account_id:
+ return value
+
+ if value.isdigit():
+ return int(value)
+ else:
+ try:
+ res = float(value)
+ return res
+ except ValueError:
+ return value
+
+ def _replace(match):
+ ref_expression = match.group(1)
+ parts = ref_expression.split(".")
+ if len(parts) >= 2:
+ # Resource attributes specified => Use GetAtt to resolve
+ logical_resource_id, _, attr_name = ref_expression.partition(".")
+ resolved = get_attr_from_model_instance(
+ resources[logical_resource_id],
+ attr_name,
+ get_resource_type(resources[logical_resource_id]),
+ logical_resource_id,
+ )
+ if resolved is None:
+ raise DependencyNotYetSatisfied(
+ resource_ids=logical_resource_id,
+ message=f"Unable to resolve attribute ref {ref_expression}",
+ )
+ if not isinstance(resolved, str):
+ resolved = str(resolved)
+ return resolved
+ if len(parts) == 1:
+ if parts[0] in resources or parts[0].startswith("AWS::"):
+ # Logical resource ID or parameter name specified => Use Ref for lookup
+ result = resolve_ref(
+ account_id, region_name, stack_name, resources, parameters, parts[0]
+ )
+
+ if result is None:
+ raise DependencyNotYetSatisfied(
+ resource_ids=parts[0],
+ message=f"Unable to resolve attribute ref {ref_expression}",
+ )
+ # TODO: is this valid?
+ # make sure we resolve any functions/placeholders in the extracted string
+ result = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ parameters,
+ result,
+ )
+ # make sure we convert the result to string
+ # TODO: do this more systematically
+ result = "" if result is None else str(result)
+ return result
+ elif parts[0] in parameters:
+ parameter = parameters[parts[0]]
+ parameter_type: str = parameter["ParameterType"]
+ parameter_value = parameter.get("ResolvedValue") or parameter.get("ParameterValue")
+
+ if parameter_type in ["CommaDelimitedList"] or parameter_type.startswith("List<"):
+ return [p.strip() for p in parameter_value.split(",")]
+ elif parameter_type == "Number":
+ return str(parameter_value)
+ else:
+ return parameter_value
+ else:
+ raise DependencyNotYetSatisfied(
+ resource_ids=parts[0],
+ message=f"Unable to resolve attribute ref {ref_expression}",
+ )
+ # TODO raise exception here?
+ return match.group(0)
+
+ regex = r"\$\{([^\}]+)\}"
+ result = re.sub(regex, _replace, result)
+ return _validate_result_type(result)
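+
+
+# Illustrative example (hypothetical resources and parameters): in the Fn::Sub string
+# "${AWS::Region}-${Stage}-${Queue.QueueName}", _replace resolves "AWS::Region" via resolve_ref,
+# "Stage" from the stack parameters, and "Queue.QueueName" via get_attr_from_model_instance.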
+
+
+def evaluate_resource_condition(conditions: dict[str, bool], resource: dict) -> bool:
+ if condition := resource.get("Condition"):
+ return conditions.get(condition, True)
+ return True
+
+
+# -----------------------
+# MAIN TEMPLATE DEPLOYER
+# -----------------------
+
+
+class TemplateDeployer:
+ def __init__(self, account_id: str, region_name: str, stack):
+ self.stack = stack
+ self.account_id = account_id
+ self.region_name = region_name
+
+ @property
+ def resources(self):
+ return self.stack.resources
+
+ @property
+ def mappings(self):
+ return self.stack.mappings
+
+ @property
+ def stack_name(self):
+ return self.stack.stack_name
+
+ # ------------------
+ # MAIN ENTRY POINTS
+ # ------------------
+
+ def deploy_stack(self):
+ self.stack.set_stack_status("CREATE_IN_PROGRESS")
+ try:
+ self.apply_changes(
+ self.stack,
+ self.stack,
+ initialize=True,
+ action="CREATE",
+ )
+ except Exception as e:
+ log_method = LOG.info
+ if config.CFN_VERBOSE_ERRORS:
+ log_method = LOG.exception
+ log_method("Unable to create stack %s: %s", self.stack.stack_name, e)
+ self.stack.set_stack_status("CREATE_FAILED")
+ raise
+
+ def apply_change_set(self, change_set: StackChangeSet):
+ action = (
+ "UPDATE"
+ if change_set.stack.status in {"CREATE_COMPLETE", "UPDATE_COMPLETE"}
+ else "CREATE"
+ )
+ change_set.stack.set_stack_status(f"{action}_IN_PROGRESS")
+ # update parameters on parent stack
+ change_set.stack.set_resolved_parameters(change_set.resolved_parameters)
+ # update conditions on parent stack
+ change_set.stack.set_resolved_stack_conditions(change_set.resolved_conditions)
+
+ # update attributes that the stack inherits from the changeset
+ change_set.stack.metadata["Capabilities"] = change_set.metadata.get("Capabilities")
+
+ try:
+ self.apply_changes(
+ change_set.stack,
+ change_set,
+ action=action,
+ )
+ except Exception as e:
+ LOG.info(
+ "Unable to apply change set %s: %s", change_set.metadata.get("ChangeSetName"), e
+ )
+ change_set.metadata["Status"] = f"{action}_FAILED"
+ self.stack.set_stack_status(f"{action}_FAILED")
+ raise
+
+ def update_stack(self, new_stack):
+ self.stack.set_stack_status("UPDATE_IN_PROGRESS")
+ # apply changes
+ self.apply_changes(self.stack, new_stack, action="UPDATE")
+ self.stack.set_time_attribute("LastUpdatedTime")
+
+ # ----------------------------
+ # DEPENDENCY RESOLUTION UTILS
+ # ----------------------------
+
+ def is_deployed(self, resource):
+ return self.stack.resource_states.get(resource["LogicalResourceId"], {}).get(
+ "ResourceStatus"
+ ) in [
+ "CREATE_COMPLETE",
+ "UPDATE_COMPLETE",
+ ]
+
+ def all_resource_dependencies_satisfied(self, resource) -> bool:
+ unsatisfied = self.get_unsatisfied_dependencies(resource)
+ return not unsatisfied
+
+ def get_unsatisfied_dependencies(self, resource):
+ res_deps = self.get_resource_dependencies(
+ resource
+ ) # the output here is currently a set of merged IDs from both resources and parameters
+ parameter_deps = {d for d in res_deps if d in self.stack.resolved_parameters}
+ resource_deps = res_deps.difference(parameter_deps)
+ res_deps_mapped = {v: self.stack.resources.get(v) for v in resource_deps}
+ return self.get_unsatisfied_dependencies_for_resources(res_deps_mapped, resource)
+
+ def get_unsatisfied_dependencies_for_resources(
+ self, resources, depending_resource=None, return_first=True
+ ):
+ result = {}
+ for resource_id, resource in resources.items():
+ if not resource:
+ raise Exception(
+ f"Resource '{resource_id}' not found in stack {self.stack.stack_name}"
+ )
+ if not self.is_deployed(resource):
+ LOG.debug(
+ "Dependency for resource %s not yet deployed: %s %s",
+ depending_resource,
+ resource_id,
+ resource,
+ )
+ result[resource_id] = resource
+ if return_first:
+ break
+ return result
+
+ def get_resource_dependencies(self, resource: dict) -> set[str]:
+ """
+ Takes a resource and returns the set of logical IDs of its dependencies (other resources, and currently also parameters)
+ """
+ # Note: using the original, unmodified template here to preserve Ref's ...
+ raw_resources = self.stack.template_original["Resources"]
+ raw_resource = raw_resources[resource["LogicalResourceId"]]
+ return get_deps_for_resource(raw_resource, self.stack.resolved_conditions)
+
+ # -----------------
+ # DEPLOYMENT UTILS
+ # -----------------
+
+ def init_resource_status(self, resources=None, stack=None, action="CREATE"):
+ resources = resources or self.resources
+ stack = stack or self.stack
+ for resource_id, resource in resources.items():
+ stack.set_resource_status(resource_id, f"{action}_IN_PROGRESS")
+
+ def get_change_config(
+ self, action: str, resource: dict, change_set_id: Optional[str] = None
+ ) -> ChangeConfig:
+ result = ChangeConfig(
+ **{
+ "Type": "Resource",
+ "ResourceChange": ResourceChange(
+ **{
+ "Action": action,
+ # TODO(srw): how can the resource not contain a logical resource id?
+ "LogicalResourceId": resource.get("LogicalResourceId"),
+ "PhysicalResourceId": resource.get("PhysicalResourceId"),
+ "ResourceType": resource["Type"],
+ # TODO ChangeSetId is only set for *nested* change sets
+ # "ChangeSetId": change_set_id,
+ "Scope": [], # TODO
+ "Details": [], # TODO
+ }
+ ),
+ }
+ )
+ if action == "Modify":
+ result["ResourceChange"]["Replacement"] = "False"
+ return result
+
+ def resource_config_differs(self, resource_new):
+ """Return whether the given resource properties differ from the existing config (for stack updates)."""
+ # TODO: this is broken for default fields and result_handler property modifications when they're added to the properties in the model
+ resource_id = resource_new["LogicalResourceId"]
+ resource_old = self.resources[resource_id]
+ props_old = resource_old.get("SpecifiedProperties", {})
+ props_new = resource_new["Properties"]
+ ignored_keys = ["LogicalResourceId", "PhysicalResourceId"]
+ old_keys = set(props_old.keys()) - set(ignored_keys)
+ new_keys = set(props_new.keys()) - set(ignored_keys)
+ if old_keys != new_keys:
+ return True
+ for key in old_keys:
+ if props_old[key] != props_new[key]:
+ return True
+ old_status = self.stack.resource_states.get(resource_id) or {}
+ previous_state = (
+ old_status.get("PreviousResourceStatus") or old_status.get("ResourceStatus") or ""
+ )
+ if old_status and "DELETE" in previous_state:
+ return True
+
+ # TODO: ?
+ def merge_properties(self, resource_id: str, old_stack, new_stack) -> None:
+ old_resources = old_stack.template["Resources"]
+ new_resources = new_stack.template["Resources"]
+ new_resource = new_resources[resource_id]
+
+ old_resource = old_resources[resource_id] = old_resources.get(resource_id) or {}
+ for key, value in new_resource.items():
+ if key == "Properties":
+ continue
+ old_resource[key] = old_resource.get(key, value)
+ old_res_props = old_resource["Properties"] = old_resource.get("Properties", {})
+ for key, value in new_resource["Properties"].items():
+ old_res_props[key] = value
+
+ old_res_props = {
+ k: v for k, v in old_res_props.items() if k in new_resource["Properties"].keys()
+ }
+ old_resource["Properties"] = old_res_props
+
+ # overwrite original template entirely
+ old_stack.template_original["Resources"][resource_id] = new_stack.template_original[
+ "Resources"
+ ][resource_id]
+
+ def construct_changes(
+ self,
+ existing_stack,
+ new_stack,
+ # TODO: remove initialize argument from here, and determine action based on resource status
+ initialize: Optional[bool] = False,
+ change_set_id=None,
+ append_to_changeset: Optional[bool] = False,
+ filter_unchanged_resources: Optional[bool] = False,
+ ) -> list[ChangeConfig]:
+ old_resources = existing_stack.template["Resources"]
+ new_resources = new_stack.template["Resources"]
+ deletes = [val for key, val in old_resources.items() if key not in new_resources]
+ adds = [val for key, val in new_resources.items() if initialize or key not in old_resources]
+ modifies = [
+ val for key, val in new_resources.items() if not initialize and key in old_resources
+ ]
+
+ changes = []
+ for action, items in (("Remove", deletes), ("Add", adds), ("Modify", modifies)):
+ for item in items:
+ item["Properties"] = item.get("Properties", {})
+ if (
+ not filter_unchanged_resources # TODO: find out purpose of this
+ or action != "Modify"
+ or self.resource_config_differs(item)
+ ):
+ change = self.get_change_config(action, item, change_set_id=change_set_id)
+ changes.append(change)
+
+ # append changes to change set
+ if append_to_changeset and isinstance(new_stack, StackChangeSet):
+ new_stack.changes.extend(changes)
+
+ return changes
+
+ def apply_changes(
+ self,
+ existing_stack: Stack,
+ new_stack: StackChangeSet,
+ change_set_id: Optional[str] = None,
+ initialize: Optional[bool] = False,
+ action: Optional[str] = None,
+ ):
+ old_resources = existing_stack.template["Resources"]
+ new_resources = new_stack.template["Resources"]
+ action = action or "CREATE"
+ # TODO: this seems wrong, not every resource here will be in an UPDATE_IN_PROGRESS state? (only the ones that will actually be updated)
+ self.init_resource_status(old_resources, action="UPDATE")
+
+ # apply parameter changes to existing stack
+ # self.apply_parameter_changes(existing_stack, new_stack)
+
+ # construct changes
+ changes = self.construct_changes(
+ existing_stack,
+ new_stack,
+ initialize=initialize,
+ change_set_id=change_set_id,
+ )
+
+ # check if we have actual changes in the stack, and prepare properties
+ contains_changes = False
+ for change in changes:
+ res_action = change["ResourceChange"]["Action"]
+ resource = new_resources.get(change["ResourceChange"]["LogicalResourceId"])
+ # FIXME: we need to resolve refs before diffing to detect if for example a parameter causes the change or not
+ # unfortunately this would currently cause issues because we might not be able to resolve everything yet
+ # resource = resolve_refs_recursively(
+ # self.stack_name,
+ # self.resources,
+ # self.mappings,
+ # self.stack.resolved_conditions,
+ # self.stack.resolved_parameters,
+ # resource,
+ # )
+ if res_action in ["Add", "Remove"] or self.resource_config_differs(resource):
+ contains_changes = True
+ if res_action in ["Modify", "Add"]:
+ # mutating call that overwrites resource properties with new properties and overwrites the template in old stack with new template
+ self.merge_properties(resource["LogicalResourceId"], existing_stack, new_stack)
+ if not contains_changes:
+ raise NoStackUpdates("No updates are to be performed.")
+
+ # merge stack outputs and conditions
+ existing_stack.outputs.update(new_stack.outputs)
+ existing_stack.conditions.update(new_stack.conditions)
+
+ # TODO: ideally the entire template has to be replaced, but tricky at this point
+ existing_stack.template["Metadata"] = new_stack.template.get("Metadata")
+ existing_stack.template_body = new_stack.template_body
+
+ # start deployment loop
+ return self.apply_changes_in_loop(
+ changes, existing_stack, action=action, new_stack=new_stack
+ )
+
+ def apply_changes_in_loop(
+ self,
+ changes: list[ChangeConfig],
+ stack: Stack,
+ action: Optional[str] = None,
+ new_stack=None,
+ ):
+ def _run(*args):
+ status_reason = None
+ try:
+ self.do_apply_changes_in_loop(changes, stack)
+ status = f"{action}_COMPLETE"
+ except Exception as e:
+ log_method = LOG.debug
+ if config.CFN_VERBOSE_ERRORS:
+ log_method = LOG.exception
+ log_method(
+ 'Error applying changes for CloudFormation stack "%s": %s %s',
+ stack.stack_name,
+ e,
+ traceback.format_exc(),
+ )
+ status = f"{action}_FAILED"
+ status_reason = str(e)
+ stack.set_stack_status(status, status_reason)
+ if isinstance(new_stack, StackChangeSet):
+ new_stack.metadata["Status"] = status
+ exec_result = "EXECUTE_FAILED" if "FAILED" in status else "EXECUTE_COMPLETE"
+ new_stack.metadata["ExecutionStatus"] = exec_result
+ result = "failed" if "FAILED" in status else "succeeded"
+ new_stack.metadata["StatusReason"] = status_reason or f"Deployment {result}"
+
+ # run deployment in background loop, to avoid client network timeouts
+ return start_worker_thread(_run)
+
+ def prepare_should_deploy_change(
+ self, resource_id: str, change: ResourceChange, stack, new_resources: dict
+ ) -> bool:
+ """
+ TODO: document
+ """
+ resource = new_resources[resource_id]
+ res_change = change["ResourceChange"]
+ action = res_change["Action"]
+
+ # check resource condition, if present
+ if not evaluate_resource_condition(stack.resolved_conditions, resource):
+ LOG.debug(
+ 'Skipping deployment of "%s", as resource condition evaluates to false', resource_id
+ )
+ return False
+
+ # resolve refs in resource details
+ resolve_refs_recursively(
+ self.account_id,
+ self.region_name,
+ stack.stack_name,
+ stack.resources,
+ stack.mappings,
+ stack.resolved_conditions,
+ stack.resolved_parameters,
+ resource,
+ )
+
+ if action in ["Add", "Modify"]:
+ is_deployed = self.is_deployed(resource)
+ # TODO: Attaching the cached _deployed info here, as we should not change the "Add"/"Modify" attribute
+ # here, which is used further down the line to determine the resource action CREATE/UPDATE. This is a
+ # temporary workaround for now - to be refactored once we introduce proper stack resource state models.
+ res_change["_deployed"] = is_deployed
+ if not is_deployed:
+ return True
+ if action == "Add":
+ return False
+ elif action == "Remove":
+ return True
+ return True
+
+ # Stack is needed here
+ def apply_change(self, change: ChangeConfig, stack: Stack) -> None:
+ change_details = change["ResourceChange"]
+ action = change_details["Action"]
+ resource_id = change_details["LogicalResourceId"]
+ resources = stack.resources
+ resource = resources[resource_id]
+
+ # TODO: this should not be needed as resources are filtered out if the
+ # condition evaluates to False.
+ if not evaluate_resource_condition(stack.resolved_conditions, resource):
+ return
+
+ # remove AWS::NoValue entries
+ resource_props = resource.get("Properties")
+ if resource_props:
+ resource["Properties"] = remove_none_values(resource_props)
+
+ executor = self.create_resource_provider_executor()
+ resource_provider_payload = self.create_resource_provider_payload(
+ action, logical_resource_id=resource_id
+ )
+
+ resource_provider = executor.try_load_resource_provider(get_resource_type(resource))
+ if resource_provider is not None:
+ # add in-progress event
+ resource_status = f"{get_action_name_for_resource_change(action)}_IN_PROGRESS"
+ physical_resource_id = None
+ if action in ("Modify", "Remove"):
+ previous_state = self.resources[resource_id].get("_last_deployed_state")
+ if not previous_state:
+ # TODO: can this happen?
+ previous_state = self.resources[resource_id]["Properties"]
+ physical_resource_id = executor.extract_physical_resource_id_from_model_with_schema(
+ resource_model=previous_state,
+ resource_type=resource["Type"],
+ resource_type_schema=resource_provider.SCHEMA,
+ )
+ stack.add_stack_event(
+ resource_id=resource_id,
+ physical_res_id=physical_resource_id,
+ status=resource_status,
+ )
+
+ # perform the deploy
+ progress_event = executor.deploy_loop(
+ resource_provider, resource, resource_provider_payload
+ )
+ else:
+ resource["PhysicalResourceId"] = MOCK_REFERENCE
+ progress_event = ProgressEvent(OperationStatus.SUCCESS, resource_model={})
+
+ # TODO: clean up the surrounding loop (do_apply_changes_in_loop) so that the responsibilities are clearer
+ stack_action = get_action_name_for_resource_change(action)
+ match progress_event.status:
+ case OperationStatus.FAILED:
+ stack.set_resource_status(
+ resource_id,
+ f"{stack_action}_FAILED",
+ status_reason=progress_event.message or "",
+ )
+ # TODO: remove exception raising here?
+ # TODO: fix request token
+ raise Exception(
+ f'Resource handler returned message: "{progress_event.message}" (RequestToken: 10c10335-276a-33d3-5c07-018b684c3d26, HandlerErrorCode: InvalidRequest){progress_event.error_code}'
+ )
+ case OperationStatus.SUCCESS:
+ stack.set_resource_status(resource_id, f"{stack_action}_COMPLETE")
+ case OperationStatus.PENDING:
+ # signal to the main loop that we should come back to this resource in the future
+ raise DependencyNotYetSatisfied(
+ resource_ids=[], message="Resource dependencies not yet satisfied"
+ )
+ case OperationStatus.IN_PROGRESS:
+ raise Exception("Resource deployment loop should not finish in this state")
+ case unknown_status:
+ raise Exception(f"Unknown operation status: {unknown_status}")
+
+ # TODO: this is probably already done in executor, try removing this
+ resource["Properties"] = progress_event.resource_model
+
+ def create_resource_provider_executor(self) -> ResourceProviderExecutor:
+ return ResourceProviderExecutor(
+ stack_name=self.stack.stack_name,
+ stack_id=self.stack.stack_id,
+ )
+
+ def create_resource_provider_payload(
+ self, action: str, logical_resource_id: str
+ ) -> ResourceProviderPayload:
+ # FIXME: use proper credentials
+ creds: Credentials = {
+ "accessKeyId": self.account_id,
+ "secretAccessKey": INTERNAL_AWS_SECRET_ACCESS_KEY,
+ "sessionToken": "",
+ }
+ resource = self.resources[logical_resource_id]
+
+ resource_provider_payload: ResourceProviderPayload = {
+ "awsAccountId": self.account_id,
+ "callbackContext": {},
+ "stackId": self.stack.stack_name,
+ "resourceType": resource["Type"],
+ "resourceTypeVersion": "000000",
+ # TODO: not actually a UUID
+ "bearerToken": str(uuid.uuid4()),
+ "region": self.region_name,
+ "action": action,
+ "requestData": {
+ "logicalResourceId": logical_resource_id,
+ "resourceProperties": resource["Properties"],
+ "previousResourceProperties": resource.get("_last_deployed_state"), # TODO
+ "callerCredentials": creds,
+ "providerCredentials": creds,
+ "systemTags": {},
+ "previousSystemTags": {},
+ "stackTags": {},
+ "previousStackTags": {},
+ },
+ }
+ return resource_provider_payload
+
+ def delete_stack(self):
+ if not self.stack:
+ return
+ self.stack.set_stack_status("DELETE_IN_PROGRESS")
+ stack_resources = list(self.stack.resources.values())
+ resources = {r["LogicalResourceId"]: clone_safe(r) for r in stack_resources}
+ original_resources = self.stack.template_original["Resources"]
+
+ # TODO: what is this doing?
+ for key, resource in resources.items():
+ resource["Properties"] = resource.get(
+ "Properties", clone_safe(resource)
+ ) # TODO: why is there a fallback?
+ resource["ResourceType"] = get_resource_type(resource)
+
+ def _safe_lookup_is_deleted(r_id):
+ """handles the case where self.stack.resource_status(..) fails for whatever reason"""
+ try:
+ return self.stack.resource_status(r_id).get("ResourceStatus") == "DELETE_COMPLETE"
+ except Exception:
+ if config.CFN_VERBOSE_ERRORS:
+ LOG.exception("failed to lookup if resource %s is deleted", r_id)
+ return True # just an assumption
+
+ ordered_resource_ids = list(
+ order_resources(
+ resources=original_resources,
+ resolved_conditions=self.stack.resolved_conditions,
+ resolved_parameters=self.stack.resolved_parameters,
+ reverse=True,
+ ).keys()
+ )
+ for i, resource_id in enumerate(ordered_resource_ids):
+ resource = resources[resource_id]
+ try:
+ # TODO: cache condition value in resource details on deployment and use cached value here
+ if not evaluate_resource_condition(
+ self.stack.resolved_conditions,
+ resource,
+ ):
+ continue
+
+ executor = self.create_resource_provider_executor()
+ resource_provider_payload = self.create_resource_provider_payload(
+ "Remove", logical_resource_id=resource_id
+ )
+ LOG.debug(
+ 'Handling "Remove" for resource "%s" (%s/%s) type "%s"',
+ resource_id,
+ i + 1,
+ len(resources),
+ resource["ResourceType"],
+ )
+ resource_provider = executor.try_load_resource_provider(get_resource_type(resource))
+ if resource_provider is not None:
+ event = executor.deploy_loop(
+ resource_provider, resource, resource_provider_payload
+ )
+ else:
+ event = ProgressEvent(OperationStatus.SUCCESS, resource_model={})
+ match event.status:
+ case OperationStatus.SUCCESS:
+ self.stack.set_resource_status(resource_id, "DELETE_COMPLETE")
+ case OperationStatus.PENDING:
+ # the resource is still being deleted, specifically the provider has
+ # signalled that the deployment loop should skip this resource this
+ # time and come back to it later, likely due to unmet child
+ # resources still existing because we don't delete things in the
+ # correct order yet.
+ continue
+ case OperationStatus.FAILED:
+ LOG.exception(
+ "Failed to delete resource with id %s. Reason: %s",
+ resource_id,
+ event.message or "unknown",
+ )
+ case OperationStatus.IN_PROGRESS:
+ # the resource provider executor should not return this state, so
+ # this state is a programming error
+ raise Exception(
+ "Programming error: ResourceProviderExecutor cannot return IN_PROGRESS"
+ )
+ case other_status:
+ raise Exception(f"Use of unsupported status found: {other_status}")
+
+ except Exception as e:
+ LOG.exception(
+ "Failed to delete resource with id %s. Final exception: %s",
+ resource_id,
+ e,
+ )
+
+ # update status
+ self.stack.set_stack_status("DELETE_COMPLETE")
+ self.stack.set_time_attribute("DeletionTime")
+
+ def do_apply_changes_in_loop(self, changes: list[ChangeConfig], stack: Stack) -> list:
+ # apply changes in a retry loop, to resolve resource dependencies and converge to the target state
+ changes_done = []
+ new_resources = stack.resources
+
+ sorted_changes = order_changes(
+ given_changes=changes,
+ resources=new_resources,
+ resolved_conditions=stack.resolved_conditions,
+ resolved_parameters=stack.resolved_parameters,
+ )
+ for change_idx, change in enumerate(sorted_changes):
+ res_change = change["ResourceChange"]
+ action = res_change["Action"]
+ is_add_or_modify = action in ["Add", "Modify"]
+ resource_id = res_change["LogicalResourceId"]
+
+ # TODO: do resolve_refs_recursively once here
+ try:
+ if is_add_or_modify:
+ should_deploy = self.prepare_should_deploy_change(
+ resource_id, change, stack, new_resources
+ )
+ LOG.debug(
+ 'Handling "%s" for resource "%s" (%s/%s) type "%s" (should_deploy=%s)',
+ action,
+ resource_id,
+ change_idx + 1,
+ len(changes),
+ res_change["ResourceType"],
+ should_deploy,
+ )
+ if not should_deploy:
+ stack_action = get_action_name_for_resource_change(action)
+ stack.set_resource_status(resource_id, f"{stack_action}_COMPLETE")
+ continue
+ elif action == "Remove":
+ should_remove = self.prepare_should_deploy_change(
+ resource_id, change, stack, new_resources
+ )
+ if not should_remove:
+ continue
+ LOG.debug(
+ 'Handling "%s" for resource "%s" (%s/%s) type "%s"',
+ action,
+ resource_id,
+ change_idx + 1,
+ len(changes),
+ res_change["ResourceType"],
+ )
+ self.apply_change(change, stack=stack)
+ changes_done.append(change)
+ except Exception as e:
+ status_action = {
+ "Add": "CREATE",
+ "Modify": "UPDATE",
+ "Dynamic": "UPDATE",
+ "Remove": "DELETE",
+ }[action]
+ stack.add_stack_event(
+ resource_id=resource_id,
+ physical_res_id=new_resources[resource_id].get("PhysicalResourceId"),
+ status=f"{status_action}_FAILED",
+ status_reason=str(e),
+ )
+ if config.CFN_VERBOSE_ERRORS:
+ LOG.exception("Failed to deploy resource %s, stack deploy failed", resource_id)
+ raise
+
+ # clean up references to deleted resources in stack
+ deletes = [c for c in changes_done if c["ResourceChange"]["Action"] == "Remove"]
+ for delete in deletes:
+ stack.template["Resources"].pop(delete["ResourceChange"]["LogicalResourceId"], None)
+
+ # resolve outputs
+ stack.resolved_outputs = resolve_outputs(self.account_id, self.region_name, stack)
+
+ return changes_done
+
+
+# FIXME: resolve_refs_recursively should not be needed, the resources themselves should have those values available already
+def resolve_outputs(account_id: str, region_name: str, stack) -> list[dict]:
+ result = []
+ for k, details in stack.outputs.items():
+ if not evaluate_resource_condition(stack.resolved_conditions, details):
+ continue
+ value = None
+ try:
+ resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack.stack_name,
+ stack.resources,
+ stack.mappings,
+ stack.resolved_conditions,
+ stack.resolved_parameters,
+ details,
+ )
+ value = details["Value"]
+ except Exception as e:
+ log_method = LOG.debug
+ if config.CFN_VERBOSE_ERRORS:
+ raise # unresolvable outputs cause a stack failure
+ # log_method = getattr(LOG, "exception")
+ log_method("Unable to resolve references in stack outputs: %s - %s", details, e)
+ exports = details.get("Export") or {}
+ export = exports.get("Name")
+ export = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack.stack_name,
+ stack.resources,
+ stack.mappings,
+ stack.resolved_conditions,
+ stack.resolved_parameters,
+ export,
+ )
+ description = details.get("Description")
+ entry = {
+ "OutputKey": k,
+ "OutputValue": value,
+ "Description": description,
+ "ExportName": export,
+ }
+ result.append(entry)
+ return result
diff --git a/localstack-core/localstack/services/cloudformation/engine/template_preparer.py b/localstack-core/localstack/services/cloudformation/engine/template_preparer.py
new file mode 100644
index 0000000000000..8206a7d6a99fc
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/template_preparer.py
@@ -0,0 +1,68 @@
+import json
+import logging
+
+from localstack.services.cloudformation.engine import yaml_parser
+from localstack.services.cloudformation.engine.transformers import (
+ apply_global_transformations,
+ apply_intrinsic_transformations,
+)
+from localstack.utils.json import clone_safe
+
+LOG = logging.getLogger(__name__)
+
+
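+# parse_template accepts both JSON and YAML template bodies, e.g. both
+# parse_template('{"Resources": {}}') and parse_template("Resources: {}")
+# return {"Resources": {}}.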
+def parse_template(template: str) -> dict:
+ try:
+ return json.loads(template)
+ except Exception:
+ try:
+ return clone_safe(yaml_parser.parse_yaml(template))
+ except Exception as e:
+ LOG.debug("Unable to parse CloudFormation template (%s): %s", e, template)
+ raise
+
+
+def template_to_json(template: str) -> str:
+ template = parse_template(template)
+ return json.dumps(template)
+
+
+# TODO: consider moving to transformers.py as well
+def transform_template(
+ account_id: str,
+ region_name: str,
+ template: dict,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict[str, bool],
+ resolved_parameters: dict,
+) -> dict:
+ processed_template = dict(template)
+
+ # apply 'Fn::Transform' intrinsic functions (note: needs to be applied before global
+ # transforms below, as some utils - incl samtransformer - expect them to be resolved already)
+ processed_template = apply_intrinsic_transformations(
+ account_id,
+ region_name,
+ processed_template,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ resolved_parameters,
+ )
+
+ # apply global transforms
+ processed_template = apply_global_transformations(
+ account_id,
+ region_name,
+ processed_template,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ resolved_parameters,
+ )
+
+ return processed_template
diff --git a/localstack-core/localstack/services/cloudformation/engine/template_utils.py b/localstack-core/localstack/services/cloudformation/engine/template_utils.py
new file mode 100644
index 0000000000000..062e4a3f1f840
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/template_utils.py
@@ -0,0 +1,430 @@
+import re
+from typing import Any
+
+from localstack.services.cloudformation.deployment_utils import PLACEHOLDER_AWS_NO_VALUE
+from localstack.services.cloudformation.engine.errors import TemplateError
+from localstack.utils.urls import localstack_host
+
+AWS_URL_SUFFIX = localstack_host().host # value is "amazonaws.com" in real AWS
+
+
+def get_deps_for_resource(resource: dict, evaluated_conditions: dict[str, bool]) -> set[str]:
+ """
+ :param resource: the resource definition to be checked for dependencies
+ :param evaluated_conditions:
+ :return: a set of logical resource IDs which this resource depends on
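+
+ Example: a resource with {"DependsOn": "Queue", "Properties": {"TopicArn": {"Fn::GetAtt": ["Topic", "Arn"]}}}
+ depends on the logical IDs {"Queue", "Topic"} ("TopicArn" is an illustrative property name).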
+ """
+ property_dependencies = resolve_dependencies(
+ resource.get("Properties", {}), evaluated_conditions
+ )
+ explicit_dependencies = resource.get("DependsOn", [])
+ if not isinstance(explicit_dependencies, list):
+ explicit_dependencies = [explicit_dependencies]
+ return property_dependencies.union(explicit_dependencies)
+
+
+def resolve_dependencies(d: dict, evaluated_conditions: dict[str, bool]) -> set[str]:
+ items = set()
+
+ if isinstance(d, dict):
+ for k, v in d.items():
+ if k == "Fn::If":
+ # check the condition and only traverse down the correct path
+ condition_name, true_value, false_value = v
+ if evaluated_conditions[condition_name]:
+ items = items.union(resolve_dependencies(true_value, evaluated_conditions))
+ else:
+ items = items.union(resolve_dependencies(false_value, evaluated_conditions))
+ elif k == "Ref":
+ items.add(v)
+ elif k == "Fn::GetAtt":
+ items.add(v[0] if isinstance(v, list) else v.split(".")[0])
+ elif k == "Fn::Sub":
+ # we can assume anything in there is a ref
+ if isinstance(v, str):
+ # { "Fn::Sub" : "Hello ${Name}" }
+ variables_found = re.findall("\\${([^}]+)}", v)
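+ # e.g. "arn:${AWS::Partition}:s3:::${Bucket}" yields ["AWS::Partition", "Bucket"];
+ # pseudo parameters ("AWS::...") are filtered out again at the end of this function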
+ for var in variables_found:
+ if "." in var:
+ var = var.split(".")[0]
+ items.add(var)
+ elif isinstance(v, list):
+ # { "Fn::Sub" : [ "Hello ${Name}", { "Name": "SomeName" } ] }
+ variables_found = re.findall("\\${([^}]+)}", v[0])
+ for var in variables_found:
+ if var in v[1]:
+ # variable is included in provided mapping and can either be a static value or another reference
+ if isinstance(v[1][var], dict):
+ # e.g. { "Fn::Sub" : [ "Hello ${Name}", { "Name": {"Ref": "NameParam"} } ] }
+ # the values can have references, so we need to go deeper
+ items = items.union(
+ resolve_dependencies(v[1][var], evaluated_conditions)
+ )
+ else:
+ # it's now either a GetAtt call or a direct reference
+ if "." in var:
+ var = var.split(".")[0]
+ items.add(var)
+ else:
+ raise Exception(f"Invalid template structure in Fn::Sub: {v}")
+ elif isinstance(v, dict):
+ items = items.union(resolve_dependencies(v, evaluated_conditions))
+ elif isinstance(v, list):
+ for item in v:
+ # TODO: assumption that every element is a dict might not be true
+ items = items.union(resolve_dependencies(item, evaluated_conditions))
+ else:
+ pass
+ elif isinstance(d, list):
+ for item in d:
+ items = items.union(resolve_dependencies(item, evaluated_conditions))
+ r = {i for i in items if not i.startswith("AWS::")}
+ return r
+
+
+def resolve_stack_conditions(
+ account_id: str,
+ region_name: str,
+ conditions: dict,
+ parameters: dict,
+ mappings: dict,
+ stack_name: str,
+) -> dict[str, bool]:
+ """
+ Within each condition, you can reference another:
+ condition
+ parameter value
+ mapping
+
+ You can use the following intrinsic functions to define conditions:
+ Fn::And
+ Fn::Equals
+ Fn::If
+ Fn::Not
+ Fn::Or
+
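+ Example: {"IsProd": {"Fn::Equals": [{"Ref": "Stage"}, "prod"]}} resolves to
+ {"IsProd": True} or {"IsProd": False}, depending on the value of the "Stage"
+ parameter ("Stage" is an illustrative parameter name).
+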
+ TODO: more checks on types from references (e.g. in a mapping value)
+ TODO: does a ref ever return a non-string value?
+ TODO: when unifying/reworking intrinsic functions rework this to a class structure
+ """
+ result = {}
+ for condition_name, condition in conditions.items():
+ result[condition_name] = resolve_condition(
+ account_id, region_name, condition, conditions, parameters, mappings, stack_name
+ )
+ return result
+
+
+def resolve_pseudo_parameter(
+ account_id: str, region_name: str, pseudo_parameter: str, stack_name: str
+) -> Any:
+ """
+ TODO: this function needs access to more stack context
+ """
+ # pseudo parameters
+ match pseudo_parameter:
+ case "AWS::Region":
+ return region_name
+ case "AWS::Partition":
+ return "aws"
+ case "AWS::StackName":
+ return stack_name
+ case "AWS::StackId":
+ # TODO return proper stack id!
+ return stack_name
+ case "AWS::AccountId":
+ return account_id
+ case "AWS::NoValue":
+ return PLACEHOLDER_AWS_NO_VALUE
+ case "AWS::NotificationARNs":
+ # TODO!
+ return {}
+ case "AWS::URLSuffix":
+ return AWS_URL_SUFFIX
+
+
+def resolve_conditional_mapping_ref(
+ ref_name, account_id: str, region_name: str, stack_name: str, parameters
+):
+ if ref_name.startswith("AWS::"):
+ ref_value = resolve_pseudo_parameter(account_id, region_name, ref_name, stack_name)
+ if ref_value is None:
+ raise TemplateError(f"Invalid pseudo parameter '{ref_name}'")
+ else:
+ param = parameters.get(ref_name)
+ if not param:
+ raise TemplateError(
+ f"Invalid reference: '{ref_name}' does not exist in parameters: '{parameters}'"
+ )
+ ref_value = param.get("ResolvedValue") or param.get("ParameterValue")
+
+ return ref_value
+
+
+def resolve_condition(
+ account_id: str, region_name: str, condition, conditions, parameters, mappings, stack_name
+):
+ if isinstance(condition, dict):
+ for k, v in condition.items():
+ match k:
+ case "Ref":
+ if isinstance(v, str) and v.startswith("AWS::"):
+ return resolve_pseudo_parameter(
+ account_id, region_name, v, stack_name
+ ) # TODO: this pseudo parameter resolving needs context(!)
+ # TODO: add util function for resolving individual refs (e.g. one util for resolving pseudo parameters)
+ # TODO: pseudo-parameters like AWS::Region
+ # can only really be a parameter here
+ # TODO: how are conditions references written here? as {"Condition": "ConditionA"} or via Ref?
+ # TODO: test for a boolean parameter?
+ param = parameters[v]
+ parameter_type: str = param["ParameterType"]
+ parameter_value = param.get("ResolvedValue") or param.get("ParameterValue")
+
+ if parameter_type in ["CommaDelimitedList"] or parameter_type.startswith(
+ "List<"
+ ):
+ return [p.strip() for p in parameter_value.split(",")]
+ else:
+ return parameter_value
+
+ case "Condition":
+ return resolve_condition(
+ account_id,
+ region_name,
+ conditions[v],
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ case "Fn::FindInMap":
+ map_name, top_level_key, second_level_key = v
+ if isinstance(map_name, dict) and "Ref" in map_name:
+ ref_name = map_name["Ref"]
+ map_name = resolve_conditional_mapping_ref(
+ ref_name, account_id, region_name, stack_name, parameters
+ )
+
+ if isinstance(top_level_key, dict) and "Ref" in top_level_key:
+ ref_name = top_level_key["Ref"]
+ top_level_key = resolve_conditional_mapping_ref(
+ ref_name, account_id, region_name, stack_name, parameters
+ )
+
+ if isinstance(second_level_key, dict) and "Ref" in second_level_key:
+ ref_name = second_level_key["Ref"]
+ second_level_key = resolve_conditional_mapping_ref(
+ ref_name, account_id, region_name, stack_name, parameters
+ )
+
+ mapping = mappings.get(map_name)
+ if not mapping:
+ raise TemplateError(
+ f"Invalid reference: '{map_name}' could not be found in the template mappings: '{list(mappings.keys())}'"
+ )
+
+ top_level_map = mapping.get(top_level_key)
+ if not top_level_map:
+ raise TemplateError(
+ f"Invalid reference: '{top_level_key}' could not be found in the '{map_name}' mapping: '{list(mapping.keys())}'"
+ )
+
+ value = top_level_map.get(second_level_key)
+ if not value:
+ raise TemplateError(
+ f"Invalid reference: '{second_level_key}' could not be found in the '{top_level_key}' mapping: '{top_level_map}'"
+ )
+
+ return value
+ case "Fn::If":
+ if_condition_name, true_branch, false_branch = v
+ if resolve_condition(
+ account_id,
+ region_name,
+ if_condition_name,
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ ):
+ return resolve_condition(
+ account_id,
+ region_name,
+ true_branch,
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ else:
+ return resolve_condition(
+ account_id,
+ region_name,
+ false_branch,
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ case "Fn::Not":
+ return not resolve_condition(
+ account_id, region_name, v[0], conditions, parameters, mappings, stack_name
+ )
+ case "Fn::And":
+ # TODO: should actually restrict this a bit
+ return resolve_condition(
+ account_id, region_name, v[0], conditions, parameters, mappings, stack_name
+ ) and resolve_condition(
+ account_id, region_name, v[1], conditions, parameters, mappings, stack_name
+ )
+ case "Fn::Or":
+ return resolve_condition(
+ account_id, region_name, v[0], conditions, parameters, mappings, stack_name
+ ) or resolve_condition(
+ account_id, region_name, v[1], conditions, parameters, mappings, stack_name
+ )
+ case "Fn::Equals":
+ left = resolve_condition(
+ account_id, region_name, v[0], conditions, parameters, mappings, stack_name
+ )
+ right = resolve_condition(
+ account_id, region_name, v[1], conditions, parameters, mappings, stack_name
+ )
+ return fn_equals_type_conversion(left) == fn_equals_type_conversion(right)
+ case "Fn::Join":
+ join_list = v[1]
+ if isinstance(v[1], dict):
+ join_list = resolve_condition(
+ account_id,
+ region_name,
+ v[1],
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ result = v[0].join(
+ [
+ resolve_condition(
+ account_id,
+ region_name,
+ x,
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ for x in join_list
+ ]
+ )
+ return result
+ case "Fn::Select":
+ index = v[0]
+ options = v[1]
+ for i, option in enumerate(options):
+ if isinstance(option, dict):
+ options[i] = resolve_condition(
+ account_id,
+ region_name,
+ option,
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ return options[index]
+ case "Fn::Sub":
+ # we can assume anything in there is a ref
+ if isinstance(v, str):
+ # { "Fn::Sub" : "Hello ${Name}" }
+ result = v
+ variables_found = re.findall("\\${([^}]+)}", v)
+ for var in variables_found:
+ # can't be a resource here (!), so also not attribute access
+ if var.startswith("AWS::"):
+ # pseudo-parameter
+ resolved_pseudo_param = resolve_pseudo_parameter(
+ account_id, region_name, var, stack_name
+ )
+ result = result.replace(f"${{{var}}}", resolved_pseudo_param)
+ else:
+ # parameter
+ param = parameters[var]
+ parameter_type: str = param["ParameterType"]
+ resolved_parameter = param.get("ResolvedValue") or param.get(
+ "ParameterValue"
+ )
+
+ if parameter_type in [
+ "CommaDelimitedList"
+ ] or parameter_type.startswith("List<"):
+ resolved_parameter = [
+ p.strip() for p in resolved_parameter.split(",")
+ ]
+
+ result = result.replace(f"${{{var}}}", resolved_parameter)
+
+ return result
+ elif isinstance(v, list):
+ # { "Fn::Sub" : [ "Hello ${Name}", { "Name": "SomeName" } ] }
+ result = v[0]
+ variables_found = re.findall("\\${([^}]+)}", v[0])
+ for var in variables_found:
+ if var in v[1]:
+ # variable is included in provided mapping and can either be a static value or another reference
+ if isinstance(v[1][var], dict):
+ # e.g. { "Fn::Sub" : [ "Hello ${Name}", { "Name": {"Ref": "NameParam"} } ] }
+ # the values can have references, so we need to go deeper
+ resolved_var = resolve_condition(
+ account_id,
+ region_name,
+ v[1][var],
+ conditions,
+ parameters,
+ mappings,
+ stack_name,
+ )
+ result = result.replace(f"${{{var}}}", resolved_var)
+ else:
+ result = result.replace(f"${{{var}}}", v[1][var])
+ else:
+ # it's now either a GetAtt call or a direct reference
+ if var.startswith("AWS::"):
+ # pseudo-parameter
+ resolved_pseudo_param = resolve_pseudo_parameter(
+ account_id, region_name, var, stack_name
+ )
+ result = result.replace(f"${{{var}}}", resolved_pseudo_param)
+ else:
+ # parameter
+ param = parameters[var]
+ parameter_type: str = param["ParameterType"]
+ resolved_parameter = param.get("ResolvedValue") or param.get(
+ "ParameterValue"
+ )
+
+ if parameter_type in [
+ "CommaDelimitedList"
+ ] or parameter_type.startswith("List<"):
+ resolved_parameter = [
+ p.strip() for p in resolved_parameter.split(",")
+ ]
+
+ result = result.replace(f"${{{var}}}", resolved_parameter)
+ return result
+ else:
+ raise Exception(f"Invalid template structure in Fn::Sub: {v}")
+ case _:
+ raise Exception(f"Invalid condition structure encountered: {condition=}")
+ else:
+ return condition
+
+
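+# Fn::Equals compares string representations, so e.g. the boolean True and the
+# string "true" are treated as equal after passing through this conversion.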
+def fn_equals_type_conversion(value) -> str:
+ if isinstance(value, str):
+ return value
+ elif isinstance(value, bool):
+ return "true" if value else "false"
+ else:
+ return str(value) # TODO: investigate correct behavior
diff --git a/localstack-core/localstack/services/cloudformation/engine/transformers.py b/localstack-core/localstack/services/cloudformation/engine/transformers.py
new file mode 100644
index 0000000000000..fea83f5ca4533
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/transformers.py
@@ -0,0 +1,304 @@
+import json
+import logging
+import os
+from copy import deepcopy
+from typing import Dict, Optional, Type, Union
+
+import boto3
+from botocore.exceptions import ClientError
+from samtranslator.translator.transform import transform as transform_sam
+
+from localstack.aws.api import CommonServiceException
+from localstack.aws.connect import connect_to
+from localstack.services.cloudformation.engine.policy_loader import create_policy_loader
+from localstack.services.cloudformation.engine.template_deployer import resolve_refs_recursively
+from localstack.services.cloudformation.stores import get_cloudformation_store
+from localstack.utils import testutil
+from localstack.utils.objects import recurse_object
+from localstack.utils.strings import long_uid
+
+LOG = logging.getLogger(__name__)
+
+SERVERLESS_TRANSFORM = "AWS::Serverless-2016-10-31"
+EXTENSIONS_TRANSFORM = "AWS::LanguageExtensions"
+SECRETSMANAGER_TRANSFORM = "AWS::SecretsManager-2020-07-23"
+
+TransformResult = Union[dict, str]
+
+
+class Transformer:
+ """Abstract class for Fn::Transform intrinsic functions"""
+
+ def transform(self, account_id: str, region_name: str, parameters: dict) -> TransformResult:
+ """Apply the transformer to the given parameters and return the modified construct"""
+
+
+class AwsIncludeTransformer(Transformer):
+ """Implements the 'AWS::Include' transform intrinsic function"""
+
+ def transform(self, account_id: str, region_name: str, parameters: dict) -> TransformResult:
+ from localstack.services.cloudformation.engine.template_preparer import parse_template
+
+ location = parameters.get("Location")
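+ # e.g. Location: "s3://my-bucket/includes/snippet.yaml" (hypothetical bucket/key)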
+ if location and location.startswith("s3://"):
+ s3_client = connect_to(aws_access_key_id=account_id, region_name=region_name).s3
+ bucket, _, path = location.removeprefix("s3://").partition("/")
+ try:
+ content = testutil.download_s3_object(s3_client, bucket, path)
+ except ClientError:
+ LOG.error("client error downloading S3 object '%s/%s'", bucket, path)
+ raise
+ content = parse_template(content)
+ return content
+ else:
+ LOG.warning("Unexpected Location parameter for AWS::Include transformer: %s", location)
+ return parameters
+
+
+# maps transformer names to implementing classes
+transformers: Dict[str, Type] = {"AWS::Include": AwsIncludeTransformer}
+
+
+def apply_intrinsic_transformations(
+ account_id: str,
+ region_name: str,
+ template: dict,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict[str, bool],
+ stack_parameters: dict,
+) -> dict:
+ """Resolve constructs using the 'Fn::Transform' intrinsic function."""
+
+ def _visit(obj, path, **_):
+ if isinstance(obj, dict) and "Fn::Transform" in obj:
+ transform = (
+ obj["Fn::Transform"]
+ if isinstance(obj["Fn::Transform"], dict)
+ else {"Name": obj["Fn::Transform"]}
+ )
+ transform_name = transform.get("Name")
+ transformer_class = transformers.get(transform_name)
+ macro_store = get_cloudformation_store(account_id, region_name).macros
+ parameters = transform.get("Parameters") or {}
+ parameters = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ stack_parameters,
+ parameters,
+ )
+ if transformer_class:
+ transformer = transformer_class()
+ transformed = transformer.transform(account_id, region_name, parameters)
+ obj_copy = deepcopy(obj)
+ obj_copy.pop("Fn::Transform")
+ obj_copy.update(transformed)
+ return obj_copy
+
+ elif transform_name in macro_store:
+ obj_copy = deepcopy(obj)
+ obj_copy.pop("Fn::Transform")
+ result = execute_macro(
+ account_id, region_name, obj_copy, transform, stack_parameters, parameters, True
+ )
+ return result
+ else:
+ LOG.warning(
+ "Unsupported transform function '%s' used in %s", transform_name, stack_name
+ )
+ return obj
+
+ return recurse_object(template, _visit)
+
+
+def apply_global_transformations(
+ account_id: str,
+ region_name: str,
+ template: dict,
+ stack_name: str,
+ resources: dict,
+ mappings: dict,
+ conditions: dict[str, bool],
+ stack_parameters: dict,
+) -> dict:
+ processed_template = deepcopy(template)
+ transformations = format_template_transformations_into_list(
+ processed_template.get("Transform", [])
+ )
+ for transformation in transformations:
+ transformation_parameters = resolve_refs_recursively(
+ account_id,
+ region_name,
+ stack_name,
+ resources,
+ mappings,
+ conditions,
+ stack_parameters,
+ transformation.get("Parameters", {}),
+ )
+
+ if not isinstance(transformation["Name"], str):
+ # TODO this should be done during template validation
+ raise CommonServiceException(
+ code="ValidationError",
+ status_code=400,
+ message="Key Name of transform definition must be a string.",
+ sender_fault=True,
+ )
+ elif transformation["Name"] == SERVERLESS_TRANSFORM:
+ processed_template = apply_serverless_transformation(
+ account_id, region_name, processed_template, stack_parameters
+ )
+ elif transformation["Name"] == EXTENSIONS_TRANSFORM:
+ continue
+ elif transformation["Name"] == SECRETSMANAGER_TRANSFORM:
+ # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/transform-aws-secretsmanager.html
+ LOG.warning("%s is not yet supported. Ignoring.", SECRETSMANAGER_TRANSFORM)
+ else:
+ processed_template = execute_macro(
+ account_id,
+ region_name,
+ parsed_template=template,
+ macro=transformation,
+ stack_parameters=stack_parameters,
+ transformation_parameters=transformation_parameters,
+ )
+
+ return processed_template
+
+
+def format_template_transformations_into_list(transforms: list | dict | str) -> list[dict]:
+ """
+ The value of the Transform attribute can be:
+ - a transformation name
+ - an object like {Name: transformation, Parameters:{}}
+ - a list of names of the transformations to apply
+ - a list of objects defining a transformation
+ This function normalizes the Transform value into a list of transformation objects.
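+ e.g. "AWS::Serverless-2016-10-31" is normalized to [{"Name": "AWS::Serverless-2016-10-31"}]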
+ """
+ formatted_transformations = []
+ if isinstance(transforms, str):
+ formatted_transformations.append({"Name": transforms})
+
+ if isinstance(transforms, dict):
+ formatted_transformations.append(transforms)
+
+ if isinstance(transforms, list):
+ for transformation in transforms:
+ if isinstance(transformation, str):
+ formatted_transformations.append({"Name": transformation})
+ if isinstance(transformation, dict):
+ formatted_transformations.append(transformation)
+
+ return formatted_transformations
+
+
+def execute_macro(
+ account_id: str,
+ region_name: str,
+ parsed_template: dict,
+ macro: dict,
+ stack_parameters: dict,
+ transformation_parameters: dict,
+ is_intrinsic=False,
+) -> str:
+ macro_definition = get_cloudformation_store(account_id, region_name).macros.get(macro["Name"])
+ if not macro_definition:
+ raise FailedTransformationException(
+ macro["Name"], f"Transformation {macro['Name']} is not supported."
+ )
+
+ formatted_stack_parameters = {}
+ for key, value in stack_parameters.items():
+ # TODO: we want to support other types of parameters
+ if value.get("ParameterType") == "CommaDelimitedList":
+ formatted_stack_parameters[key] = value.get("ParameterValue").split(",")
+ else:
+ formatted_stack_parameters[key] = value.get("ParameterValue")
+
+ transformation_id = f"{account_id}::{macro['Name']}"
+ event = {
+ "region": region_name,
+ "accountId": account_id,
+ "fragment": parsed_template,
+ "transformId": transformation_id,
+ "params": transformation_parameters,
+ "requestId": long_uid(),
+ "templateParameterValues": formatted_stack_parameters,
+ }
+
+ client = connect_to(aws_access_key_id=account_id, region_name=region_name).lambda_
+ try:
+ invocation = client.invoke(
+ FunctionName=macro_definition["FunctionName"], Payload=json.dumps(event)
+ )
+ except ClientError:
+ LOG.error(
+ "client error executing lambda function '%s' with payload '%s'",
+ macro_definition["FunctionName"],
+ json.dumps(event),
+ )
+ raise
+ if invocation.get("StatusCode") != 200 or invocation.get("FunctionError") == "Unhandled":
+ raise FailedTransformationException(
+ transformation=macro["Name"],
+ message=f"Received malformed response from transform {transformation_id}. Rollback requested by user.",
+ )
+ result = json.loads(invocation["Payload"].read())
+
+ if result.get("status") != "success":
+ error_message = result.get("errorMessage")
+ message = (
+ f"Transform {transformation_id} failed with: {error_message}. Rollback requested by user."
+ if error_message
+ else f"Transform {transformation_id} failed without an error message.. Rollback requested by user."
+ )
+ raise FailedTransformationException(transformation=macro["Name"], message=message)
+
+ if not isinstance(result.get("fragment"), dict) and not is_intrinsic:
+ raise FailedTransformationException(
+ transformation=macro["Name"],
+ message="Template format error: unsupported structure.. Rollback requested by user.",
+ )
+
+ return result.get("fragment")
+
+
+def apply_serverless_transformation(
+ account_id: str, region_name: str, parsed_template: dict, template_parameters: dict
+) -> Optional[str]:
+ """only returns string when parsing SAM template, otherwise None"""
+ # TODO: we might also want to override the access key ID to account ID
+ region_before = os.environ.get("AWS_DEFAULT_REGION")
+ if boto3.session.Session().region_name is None:
+ os.environ["AWS_DEFAULT_REGION"] = region_name
+ loader = create_policy_loader()
+ simplified_parameters = {
+ k: v.get("ResolvedValue") or v["ParameterValue"] for k, v in template_parameters.items()
+ }
+
+ try:
+ transformed = transform_sam(parsed_template, simplified_parameters, loader)
+ return transformed
+ except Exception as e:
+ raise FailedTransformationException(transformation=SERVERLESS_TRANSFORM, message=str(e))
+ finally:
+ # Note: we need to fix boto3 region, otherwise AWS SAM transformer fails
+ os.environ.pop("AWS_DEFAULT_REGION", None)
+ if region_before is not None:
+ os.environ["AWS_DEFAULT_REGION"] = region_before
+
+
+class FailedTransformationException(Exception):
+ transformation: str
+ message: str
+
+ def __init__(self, transformation: str, message: str = ""):
+ self.transformation = transformation
+ self.message = message
+ super().__init__(self.message)
diff --git a/localstack-core/localstack/services/cloudformation/engine/types.py b/localstack-core/localstack/services/cloudformation/engine/types.py
new file mode 100644
index 0000000000000..2a4f6efa06031
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/types.py
@@ -0,0 +1,45 @@
+from typing import Any, Callable, Optional, TypedDict
+
+# ---------------------
+# TYPES
+# ---------------------
+
+# Callable here takes the arguments:
+# - resource_props
+# - stack_name
+# - resources
+# - resource_id
+ResourceProp = str | Callable[[dict, str, dict, str], dict]
+ResourceDefinition = dict[str, ResourceProp]
+
+
+class FuncDetailsValue(TypedDict):
+ # Callable here takes the arguments:
+ # - logical_resource_id
+ # - resource
+ # - stack_name
+ function: str | Callable[[str, dict, str], Any]
+ """Either an api method to call directly with `parameters` or a callable to directly invoke"""
+ # Callable here takes the arguments:
+ # - resource_props
+ # - stack_name
+ # - resources
+ # - resource_id
+ parameters: Optional[ResourceDefinition | Callable[[dict, str, list[dict], str], dict]]
+ """arguments to the function, or a function that generates the arguments to the function"""
+ # Callable here takes the arguments
+ # - result
+ # - resource_id
+ # - resources
+ # - resource_type
+ result_handler: Optional[Callable[[dict, str, list[dict], str], None]]
+ """Take the result of the operation and patch the state of the resources, yuck..."""
+ types: Optional[dict[str, Callable]]
+ """Possible type conversions"""
+
+
+# Type definition for func_details supplied to invoke_function
+FuncDetails = list[FuncDetailsValue] | FuncDetailsValue
+
+# Type definition returned by GenericBaseModel.get_deploy_templates
+DeployTemplates = dict[str, FuncDetails | Callable]
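+
+# A minimal sketch of the expected shape (hypothetical resource/action names):
+# {"create": {"function": "create_topic", "parameters": {"Name": "TopicName"}}}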
diff --git a/localstack-core/localstack/services/cloudformation/engine/validations.py b/localstack-core/localstack/services/cloudformation/engine/validations.py
new file mode 100644
index 0000000000000..c65d0a5b307fc
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/validations.py
@@ -0,0 +1,86 @@
+"""
+Provide validations for use within the CFn engine
+"""
+
+from typing import Protocol
+
+from localstack.aws.api import CommonServiceException
+
+
+class ValidationError(CommonServiceException):
+ """General validation error type (defined in the AWS docs, but not part of the botocore spec)"""
+
+ def __init__(self, message=None):
+ super().__init__("ValidationError", message=message, sender_fault=True)
+
+
+class TemplateValidationStep(Protocol):
+ """
+ Base class for static analysis of the template
+ """
+
+ def __call__(self, template: dict):
+ """
+ Execute a specific validation on the template
+ """
+
+
+def outputs_have_values(template: dict):
+ outputs: dict[str, dict] = template.get("Outputs", {})
+
+ for output_name, output_defn in outputs.items():
+ if "Value" not in output_defn:
+ raise ValidationError(
+ "Template format error: Every Outputs member must contain a Value object"
+ )
+
+ if output_defn["Value"] is None:
+ key = f"/Outputs/{output_name}/Value"
+ raise ValidationError(f"[{key}] 'null' values are not allowed in templates")
+
+
+# TODO: this would need to be split into different validations pre- and post- transform
+def resources_top_level_keys(template: dict):
+ """
+ Validate the Resources section of the template:
+ - the template defines a `Resources` key
+ - every resource includes a `Type` member
+ - resources do not include any top-level keys that are not allowed
+ """
+ resources = template.get("Resources")
+ if resources is None:
+ raise ValidationError(
+ "Template format error: At least one Resources member must be defined."
+ )
+
+ allowed_keys = {
+ "Type",
+ "Properties",
+ "DependsOn",
+ "CreationPolicy",
+ "DeletionPolicy",
+ "Metadata",
+ "UpdatePolicy",
+ "UpdateReplacePolicy",
+ "Condition",
+ }
+ for resource_id, resource in resources.items():
+ if "Type" not in resource:
+ raise ValidationError(
+ f"Template format error: [/Resources/{resource_id}] Every Resources object must contain a Type member."
+ )
+
+ # check for invalid keys
+ for key in resource:
+ if key not in allowed_keys:
+ raise ValidationError(f"Invalid template resource property '{key}'")
+
+
+DEFAULT_TEMPLATE_VALIDATIONS: list[TemplateValidationStep] = [
+ # FIXME: disabled for now due to the template validation not fitting well with the template that we use here.
+ # We don't have access to a "raw" processed template here and it's questionable if we should have it at all,
+ # since later transformations can again introduce issues.
+ # => Reevaluate this when reworking how we mutate the template dict in the provider
+ # outputs_have_values,
+ # resources_top_level_keys,
+]
diff --git a/localstack-core/localstack/services/cloudformation/engine/yaml_parser.py b/localstack-core/localstack/services/cloudformation/engine/yaml_parser.py
new file mode 100644
index 0000000000000..c0b72ead58f8f
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/engine/yaml_parser.py
@@ -0,0 +1,64 @@
+import yaml
+
+
+def construct_raw(_, node):
+ return node.value
+
+
+class NoDatesSafeLoader(yaml.SafeLoader):
+ @classmethod
+ def remove_tag_constructor(cls, tag):
+ """
+ Remove the YAML constructor for a given tag and replace it with a raw constructor
+ """
+ # needed to make sure we're not changing the constructors of the base class
+ # otherwise usage across the code base is affected as well
+ if "yaml_constructors" not in cls.__dict__:
+ cls.yaml_constructors = cls.yaml_constructors.copy()
+
+ cls.yaml_constructors[tag] = construct_raw
+
+
+NoDatesSafeLoader.remove_tag_constructor("tag:yaml.org,2002:timestamp")
+
+
+def shorthand_constructor(loader: yaml.Loader, tag_suffix: str, node: yaml.Node):
+ """
+ TODO: proper exceptions (introduce this when fixing the provider)
+ TODO: fix select & split (is this even necessary?)
+ { "Fn::Select" : [ "2", { "Fn::Split": [",", {"Fn::ImportValue": "AccountSubnetIDs"}]}] }
+ !Select [2, !Split [",", !ImportValue AccountSubnetIDs]]
+ shorthand: 2 => canonical "2"
+ """
+ match tag_suffix:
+ case "Ref":
+ fn_name = "Ref"
+ case "Condition":
+ fn_name = "Condition"
+ case _:
+ fn_name = f"Fn::{tag_suffix}"
+
+ if tag_suffix == "GetAtt" and isinstance(node, yaml.ScalarNode):
+ # !GetAtt A.B.C => {"Fn::GetAtt": ["A", "B.C"]}
+ parts = node.value.partition(".")
+ if not parts[1]:
+ # str.partition always returns a 3-tuple; an empty separator means there was no "."
+ raise ValueError(f"Node value contains unexpected format for !GetAtt: {node.value}")
+ return {fn_name: [parts[0], parts[2]]}
+
+ if isinstance(node, yaml.ScalarNode):
+ return {fn_name: node.value}
+ elif isinstance(node, yaml.SequenceNode):
+ return {fn_name: loader.construct_sequence(node)}
+ elif isinstance(node, yaml.MappingNode):
+ return {fn_name: loader.construct_mapping(node)}
+ else:
+ raise ValueError(f"Unexpected yaml Node type: {type(node)}")
+
+
+customloader = NoDatesSafeLoader
+
+yaml.add_multi_constructor("!", shorthand_constructor, customloader)
+
+
+def parse_yaml(input_data: str):
+ return yaml.load(input_data, customloader)
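+
+
+# A minimal usage sketch: parse_yaml("Value: !Ref MyParam") returns
+# {"Value": {"Ref": "MyParam"}}, since the "!" multi-constructor above expands
+# shorthand tags into their canonical long form.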
diff --git a/localstack-core/localstack/services/cloudformation/models/__init__.py b/localstack-core/localstack/services/cloudformation/models/__init__.py
new file mode 100644
index 0000000000000..a9a2c5b3bb437
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/models/__init__.py
@@ -0,0 +1 @@
+__all__ = []
diff --git a/localstack-core/localstack/services/cloudformation/plugins.py b/localstack-core/localstack/services/cloudformation/plugins.py
new file mode 100644
index 0000000000000..72ef0104aaeb2
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/plugins.py
@@ -0,0 +1,12 @@
+from rolo import Resource
+
+from localstack.runtime import hooks
+
+
+@hooks.on_infra_start()
+def register_cloudformation_deploy_ui():
+ from localstack.services.internal import get_internal_apis
+
+ from .deploy_ui import CloudFormationUi
+
+ get_internal_apis().add(Resource("/_localstack/cloudformation/deploy", CloudFormationUi()))
diff --git a/localstack-core/localstack/services/cloudformation/provider.py b/localstack-core/localstack/services/cloudformation/provider.py
new file mode 100644
index 0000000000000..7bf2a110a9d9f
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/provider.py
@@ -0,0 +1,1326 @@
+import copy
+import json
+import logging
+import re
+from collections import defaultdict
+from copy import deepcopy
+
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.cloudformation import (
+ AlreadyExistsException,
+ CallAs,
+ ChangeSetNameOrId,
+ ChangeSetNotFoundException,
+ ChangeSetType,
+ ClientRequestToken,
+ CloudformationApi,
+ CreateChangeSetInput,
+ CreateChangeSetOutput,
+ CreateStackInput,
+ CreateStackInstancesInput,
+ CreateStackInstancesOutput,
+ CreateStackOutput,
+ CreateStackSetInput,
+ CreateStackSetOutput,
+ DeleteChangeSetOutput,
+ DeleteStackInstancesInput,
+ DeleteStackInstancesOutput,
+ DeleteStackSetOutput,
+ DeletionMode,
+ DescribeChangeSetOutput,
+ DescribeStackEventsOutput,
+ DescribeStackResourceOutput,
+ DescribeStackResourcesOutput,
+ DescribeStackSetOperationOutput,
+ DescribeStackSetOutput,
+ DescribeStacksOutput,
+ DisableRollback,
+ EnableTerminationProtection,
+ ExecuteChangeSetOutput,
+ ExecutionStatus,
+ ExportName,
+ GetTemplateOutput,
+ GetTemplateSummaryInput,
+ GetTemplateSummaryOutput,
+ IncludePropertyValues,
+ InsufficientCapabilitiesException,
+ InvalidChangeSetStatusException,
+ ListChangeSetsOutput,
+ ListExportsOutput,
+ ListImportsOutput,
+ ListStackInstancesInput,
+ ListStackInstancesOutput,
+ ListStackResourcesOutput,
+ ListStackSetsInput,
+ ListStackSetsOutput,
+ ListStacksOutput,
+ ListTypesInput,
+ ListTypesOutput,
+ LogicalResourceId,
+ NextToken,
+ Parameter,
+ PhysicalResourceId,
+ RegisterTypeInput,
+ RegisterTypeOutput,
+ RegistryType,
+ RetainExceptOnCreate,
+ RetainResources,
+ RoleARN,
+ StackName,
+ StackNameOrId,
+ StackSetName,
+ StackStatus,
+ StackStatusFilter,
+ TemplateParameter,
+ TemplateStage,
+ TypeSummary,
+ UpdateStackInput,
+ UpdateStackOutput,
+ UpdateStackSetInput,
+ UpdateStackSetOutput,
+ UpdateTerminationProtectionOutput,
+ ValidateTemplateInput,
+ ValidateTemplateOutput,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.cloudformation import api_utils
+from localstack.services.cloudformation.engine import parameters as param_resolver
+from localstack.services.cloudformation.engine import template_deployer, template_preparer
+from localstack.services.cloudformation.engine.entities import (
+ Stack,
+ StackChangeSet,
+ StackInstance,
+ StackSet,
+)
+from localstack.services.cloudformation.engine.parameters import mask_no_echo, strip_parameter_type
+from localstack.services.cloudformation.engine.resource_ordering import (
+ NoResourceInStack,
+ order_resources,
+)
+from localstack.services.cloudformation.engine.template_deployer import (
+ NoStackUpdates,
+)
+from localstack.services.cloudformation.engine.template_utils import resolve_stack_conditions
+from localstack.services.cloudformation.engine.transformers import (
+ FailedTransformationException,
+)
+from localstack.services.cloudformation.engine.validations import (
+ DEFAULT_TEMPLATE_VALIDATIONS,
+ ValidationError,
+)
+from localstack.services.cloudformation.resource_provider import (
+ PRO_RESOURCE_PROVIDERS,
+ ResourceProvider,
+)
+from localstack.services.cloudformation.stores import (
+ cloudformation_stores,
+ find_active_stack_by_name_or_id,
+ find_change_set,
+ find_stack,
+ find_stack_by_id,
+ get_cloudformation_store,
+)
+from localstack.state import StateVisitor
+from localstack.utils.collections import (
+ remove_attributes,
+ select_attributes,
+ select_from_typed_dict,
+)
+from localstack.utils.json import clone
+from localstack.utils.strings import long_uid, short_uid
+
+LOG = logging.getLogger(__name__)
+
+ARN_CHANGESET_REGEX = re.compile(
+ r"arn:(aws|aws-us-gov|aws-cn):cloudformation:[-a-zA-Z0-9]+:\d{12}:changeSet/[a-zA-Z][-a-zA-Z0-9]*/[-a-zA-Z0-9:/._+]+"
+)
+ARN_STACK_REGEX = re.compile(
+ r"arn:(aws|aws-us-gov|aws-cn):cloudformation:[-a-zA-Z0-9]+:\d{12}:stack/[a-zA-Z][-a-zA-Z0-9]*/[-a-zA-Z0-9:/._+]+"
+)
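+# e.g. (illustrative) "arn:aws:cloudformation:us-east-1:000000000000:stack/my-stack/1a2b3c4d"
+# matches ARN_STACK_REGEX, while a bare stack name such as "my-stack" does not.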
+
+
+def clone_stack_params(stack_params):
+ try:
+ return clone(stack_params)
+ except Exception as e:
+ LOG.info("Unable to clone stack parameters: %s", e)
+ return stack_params
+
+
+def find_stack_instance(stack_set: StackSet, account: str, region: str):
+ for instance in stack_set.stack_instances:
+ if instance.metadata["Account"] == account and instance.metadata["Region"] == region:
+ return instance
+ return None
+
+
+def stack_not_found_error(stack_name: str):
+ # FIXME
+ raise ValidationError("Stack with id %s does not exist" % stack_name)
+
+
+def not_found_error(message: str):
+ # FIXME
+ raise ResourceNotFoundException(message)
+
+
+class ResourceNotFoundException(CommonServiceException):
+ def __init__(self, message=None):
+ super().__init__("ResourceNotFoundException", status_code=404, message=message)
+
+
+class InternalFailure(CommonServiceException):
+ def __init__(self, message=None):
+ super().__init__("InternalFailure", status_code=500, message=message, sender_fault=False)
+
+
+class CloudformationProvider(CloudformationApi):
+ def _stack_status_is_active(self, stack_status: str) -> bool:
+ return stack_status not in [StackStatus.DELETE_COMPLETE]
+
+ def accept_state_visitor(self, visitor: StateVisitor):
+ visitor.visit(cloudformation_stores)
+
+ @handler("CreateStack", expand=False)
+ def create_stack(self, context: RequestContext, request: CreateStackInput) -> CreateStackOutput:
+ # TODO: test what happens when both TemplateUrl and Body are specified
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ stack_name = request.get("StackName")
+
+ # get stacks by name
+ active_stack_candidates = [
+ s
+ for s in state.stacks.values()
+ if s.stack_name == stack_name and self._stack_status_is_active(s.status)
+ ]
+
+ # TODO: fix/implement this code path
+ # this needs more investigation how Cloudformation handles it (e.g. normal stack create or does it create a separate changeset?)
+ # REVIEW_IN_PROGRESS is another special status
+ # in this case existing changesets are set to obsolete and the stack is created
+ # review_stack_candidates = [s for s in stack_candidates if s.status == StackStatus.REVIEW_IN_PROGRESS]
+ # if review_stack_candidates:
+ # set changesets to obsolete
+ # for cs in review_stack_candidates[0].change_sets:
+ # cs.execution_status = ExecutionStatus.OBSOLETE
+
+ if active_stack_candidates:
+ raise AlreadyExistsException(f"Stack [{stack_name}] already exists")
+
+ template_body = request.get("TemplateBody") or ""
+ if len(template_body) > 51200:
+ raise ValidationError(
+ f"1 validation error detected: Value '{request['TemplateBody']}' at 'templateBody' "
+ "failed to satisfy constraint: Member must have length less than or equal to 51200"
+ )
+ api_utils.prepare_template_body(request) # TODO: avoid mutating request directly
+
+ template = template_preparer.parse_template(request["TemplateBody"])
+
+ stack_name = template["StackName"] = request.get("StackName")
+ if api_utils.validate_stack_name(stack_name) is False:
+ raise ValidationError(
+ f"1 validation error detected: Value '{stack_name}' at 'stackName' failed to satisfy constraint:\
+ Member must satisfy regular expression pattern: [a-zA-Z][-a-zA-Z0-9]*|arn:[-a-zA-Z0-9:/._+]*"
+ )
+
+ if (
+ "CAPABILITY_AUTO_EXPAND" not in request.get("Capabilities", [])
+ and "Transform" in template.keys()
+ ):
+ raise InsufficientCapabilitiesException(
+ "Requires capabilities : [CAPABILITY_AUTO_EXPAND]"
+ )
+
+ # resolve stack parameters
+ new_parameters = param_resolver.convert_stack_parameters_to_dict(request.get("Parameters"))
+ parameter_declarations = param_resolver.extract_stack_parameter_declarations(template)
+ resolved_parameters = param_resolver.resolve_parameters(
+ account_id=context.account_id,
+ region_name=context.region,
+ parameter_declarations=parameter_declarations,
+ new_parameters=new_parameters,
+ old_parameters={},
+ )
+
+ # handle conditions
+ stack = Stack(context.account_id, context.region, request, template)
+
+ try:
+ template = template_preparer.transform_template(
+ context.account_id,
+ context.region,
+ template,
+ stack.stack_name,
+ stack.resources,
+ stack.mappings,
+ {}, # TODO
+ resolved_parameters,
+ )
+ except FailedTransformationException as e:
+ stack.add_stack_event(
+ stack.stack_name,
+ stack.stack_id,
+ status="ROLLBACK_IN_PROGRESS",
+ status_reason=e.message,
+ )
+ stack.set_stack_status("ROLLBACK_COMPLETE")
+ state.stacks[stack.stack_id] = stack
+ return CreateStackOutput(StackId=stack.stack_id)
+
+ # perform basic static analysis on the template
+ for validation_fn in DEFAULT_TEMPLATE_VALIDATIONS:
+ validation_fn(template)
+
+ stack = Stack(context.account_id, context.region, request, template)
+
+ # resolve conditions
+ raw_conditions = template.get("Conditions", {})
+ resolved_stack_conditions = resolve_stack_conditions(
+ account_id=context.account_id,
+ region_name=context.region,
+ conditions=raw_conditions,
+ parameters=resolved_parameters,
+ mappings=stack.mappings,
+ stack_name=stack_name,
+ )
+ stack.set_resolved_stack_conditions(resolved_stack_conditions)
+
+ stack.set_resolved_parameters(resolved_parameters)
+ stack.template_body = template_body
+ state.stacks[stack.stack_id] = stack
+ LOG.debug(
+ 'Creating stack "%s" with %s resources ...',
+ stack.stack_name,
+ len(stack.template_resources),
+ )
+ deployer = template_deployer.TemplateDeployer(context.account_id, context.region, stack)
+ try:
+ deployer.deploy_stack()
+ except Exception as e:
+ stack.set_stack_status("CREATE_FAILED")
+ msg = 'Unable to create stack "%s": %s' % (stack.stack_name, e)
+ LOG.exception("%s", msg)
+ raise ValidationError(msg) from e
+
+ return CreateStackOutput(StackId=stack.stack_id)
+
+ @handler("DeleteStack")
+ def delete_stack(
+ self,
+ context: RequestContext,
+ stack_name: StackName,
+ retain_resources: RetainResources = None,
+ role_arn: RoleARN = None,
+ client_request_token: ClientRequestToken = None,
+ deletion_mode: DeletionMode = None,
+ **kwargs,
+ ) -> None:
+ stack = find_active_stack_by_name_or_id(context.account_id, context.region, stack_name)
+ if not stack:
+ # aws will silently ignore invalid stack names - we should do the same
+ return
+ deployer = template_deployer.TemplateDeployer(context.account_id, context.region, stack)
+ deployer.delete_stack()
+
+ @handler("UpdateStack", expand=False)
+ def update_stack(
+ self,
+ context: RequestContext,
+ request: UpdateStackInput,
+ ) -> UpdateStackOutput:
+ stack_name = request.get("StackName")
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ return not_found_error(f'Unable to update non-existing stack "{stack_name}"')
+
+ api_utils.prepare_template_body(request)
+ template = template_preparer.parse_template(request["TemplateBody"])
+
+ if (
+ "CAPABILITY_AUTO_EXPAND" not in request.get("Capabilities", [])
+ and "Transform" in template.keys()
+ ):
+ raise InsufficientCapabilitiesException(
+ "Requires capabilities : [CAPABILITY_AUTO_EXPAND]"
+ )
+
+ new_parameters: dict[str, Parameter] = param_resolver.convert_stack_parameters_to_dict(
+ request.get("Parameters")
+ )
+ parameter_declarations = param_resolver.extract_stack_parameter_declarations(template)
+ resolved_parameters = param_resolver.resolve_parameters(
+ account_id=context.account_id,
+ region_name=context.region,
+ parameter_declarations=parameter_declarations,
+ new_parameters=new_parameters,
+ old_parameters=stack.resolved_parameters,
+ )
+
+ resolved_stack_conditions = resolve_stack_conditions(
+ account_id=context.account_id,
+ region_name=context.region,
+ conditions=template.get("Conditions", {}),
+ parameters=resolved_parameters,
+ mappings=template.get("Mappings", {}),
+ stack_name=stack_name,
+ )
+
+ raw_new_template = copy.deepcopy(template)
+ try:
+ template = template_preparer.transform_template(
+ context.account_id,
+ context.region,
+ template,
+ stack.stack_name,
+ stack.resources,
+ stack.mappings,
+ resolved_stack_conditions,
+ resolved_parameters,
+ )
+ processed_template = copy.deepcopy(
+ template
+ ) # copying it here since it's being mutated somewhere downstream
+ except FailedTransformationException as e:
+ stack.add_stack_event(
+ stack.stack_name,
+ stack.stack_id,
+ status="ROLLBACK_IN_PROGRESS",
+ status_reason=e.message,
+ )
+ stack.set_stack_status("ROLLBACK_COMPLETE")
+ return UpdateStackOutput(StackId=stack.stack_id)
+
+ # perform basic static analysis on the template
+ for validation_fn in DEFAULT_TEMPLATE_VALIDATIONS:
+ validation_fn(template)
+
+ # update the template
+ stack.template_original = template
+
+ deployer = template_deployer.TemplateDeployer(context.account_id, context.region, stack)
+ # TODO: there shouldn't be a "new" stack on update
+ new_stack = Stack(
+ context.account_id, context.region, request, template, request["TemplateBody"]
+ )
+ new_stack.set_resolved_parameters(resolved_parameters)
+ stack.set_resolved_parameters(resolved_parameters)
+ stack.set_resolved_stack_conditions(resolved_stack_conditions)
+ try:
+ deployer.update_stack(new_stack)
+ except NoStackUpdates as e:
+ stack.set_stack_status("UPDATE_COMPLETE")
+ if raw_new_template != processed_template:
+ # processed templates seem to never return an exception here
+ return UpdateStackOutput(StackId=stack.stack_id)
+ raise ValidationError(str(e))
+ except Exception as e:
+ stack.set_stack_status("UPDATE_FAILED")
+ msg = f'Unable to update stack "{stack_name}": {e}'
+ LOG.exception("%s", msg)
+ raise ValidationError(msg) from e
+
+ return UpdateStackOutput(StackId=stack.stack_id)
+
+ @handler("DescribeStacks")
+ def describe_stacks(
+ self,
+ context: RequestContext,
+ stack_name: StackName = None,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> DescribeStacksOutput:
+ # TODO: test & implement pagination
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ if stack_name:
+ if ARN_STACK_REGEX.match(stack_name):
+ # we can get the stack directly since we index the store by ARN/stackID
+ stack = state.stacks.get(stack_name)
+ stacks = [stack.describe_details()] if stack else []
+ else:
+ # otherwise we have to find the active stack with the given name
+ stack_candidates: list[Stack] = [
+ s for stack_arn, s in state.stacks.items() if s.stack_name == stack_name
+ ]
+ active_stack_candidates = [
+ s for s in stack_candidates if self._stack_status_is_active(s.status)
+ ]
+ stacks = [s.describe_details() for s in active_stack_candidates]
+ else:
+ # return all active stacks
+ stack_list = list(state.stacks.values())
+ stacks = [
+ s.describe_details() for s in stack_list if self._stack_status_is_active(s.status)
+ ]
+
+ if stack_name and not stacks:
+ raise ValidationError(f"Stack with id {stack_name} does not exist")
+
+ return DescribeStacksOutput(Stacks=stacks)
+
+ @handler("ListStacks")
+ def list_stacks(
+ self,
+ context: RequestContext,
+ next_token: NextToken = None,
+ stack_status_filter: StackStatusFilter = None,
+ **kwargs,
+ ) -> ListStacksOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ stacks = [
+ s.describe_details()
+ for s in state.stacks.values()
+ if not stack_status_filter or s.status in stack_status_filter
+ ]
+
+ attrs = [
+ "StackId",
+ "StackName",
+ "TemplateDescription",
+ "CreationTime",
+ "LastUpdatedTime",
+ "DeletionTime",
+ "StackStatus",
+ "StackStatusReason",
+ "ParentId",
+ "RootId",
+ "DriftInformation",
+ ]
+ stacks = [select_attributes(stack, attrs) for stack in stacks]
+ return ListStacksOutput(StackSummaries=stacks)
+
+ @handler("GetTemplate")
+ def get_template(
+ self,
+ context: RequestContext,
+ stack_name: StackName = None,
+ change_set_name: ChangeSetNameOrId = None,
+ template_stage: TemplateStage = None,
+ **kwargs,
+ ) -> GetTemplateOutput:
+ if change_set_name:
+ stack = find_change_set(
+ context.account_id, context.region, stack_name=stack_name, cs_name=change_set_name
+ )
+ else:
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ return stack_not_found_error(stack_name)
+
+ if template_stage == TemplateStage.Processed and "Transform" in stack.template_body:
+ copy_template = clone(stack.template_original)
+ copy_template.pop("ChangeSetName", None)
+ copy_template.pop("StackName", None)
+ for resource in copy_template.get("Resources", {}).values():
+ resource.pop("LogicalResourceId", None)
+ template_body = json.dumps(copy_template)
+ else:
+ template_body = stack.template_body
+
+ return GetTemplateOutput(
+ TemplateBody=template_body,
+ StagesAvailable=[TemplateStage.Original, TemplateStage.Processed],
+ )
+
+ @handler("GetTemplateSummary", expand=False)
+ def get_template_summary(
+ self,
+ context: RequestContext,
+ request: GetTemplateSummaryInput,
+ ) -> GetTemplateSummaryOutput:
+ stack_name = request.get("StackName")
+
+ if stack_name:
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ return stack_not_found_error(stack_name)
+ template = stack.template
+ else:
+ api_utils.prepare_template_body(request)
+ template = template_preparer.parse_template(request["TemplateBody"])
+ request["StackName"] = "tmp-stack"
+ stack = Stack(context.account_id, context.region, request, template)
+
+ result: GetTemplateSummaryOutput = stack.describe_details()
+
+ # build parameter declarations
+ result["Parameters"] = list(
+ param_resolver.extract_stack_parameter_declarations(template).values()
+ )
+
+ id_summaries = defaultdict(list)
+ for resource_id, resource in stack.template_resources.items():
+ res_type = resource["Type"]
+ id_summaries[res_type].append(resource_id)
+
+ result["ResourceTypes"] = list(id_summaries.keys())
+ result["ResourceIdentifierSummaries"] = [
+ {"ResourceType": key, "LogicalResourceIds": values}
+ for key, values in id_summaries.items()
+ ]
+ result["Metadata"] = stack.template.get("Metadata")
+ result["Version"] = stack.template.get("AWSTemplateFormatVersion", "2010-09-09")
+ # these do not appear in the output
+ result.pop("Capabilities", None)
+
+ return select_from_typed_dict(GetTemplateSummaryOutput, result)
+
+ def update_termination_protection(
+ self,
+ context: RequestContext,
+ enable_termination_protection: EnableTerminationProtection,
+ stack_name: StackNameOrId,
+ **kwargs,
+ ) -> UpdateTerminationProtectionOutput:
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ raise ValidationError(f"Stack '{stack_name}' does not exist.")
+ stack.metadata["EnableTerminationProtection"] = enable_termination_protection
+ return UpdateTerminationProtectionOutput(StackId=stack.stack_id)
+
+ @handler("CreateChangeSet", expand=False)
+ def create_change_set(
+ self, context: RequestContext, request: CreateChangeSetInput
+ ) -> CreateChangeSetOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ req_params = request
+ change_set_type = req_params.get("ChangeSetType", "UPDATE")
+ stack_name = req_params.get("StackName")
+ change_set_name = req_params.get("ChangeSetName")
+ template_body = req_params.get("TemplateBody")
+ # s3 or secretsmanager url
+ template_url = req_params.get("TemplateURL")
+
+ # validate and resolve template
+ if template_body and template_url:
+ raise ValidationError(
+ "Specify exactly one of 'TemplateBody' or 'TemplateUrl'"
+ ) # TODO: check proper message
+
+ if not template_body and not template_url:
+ raise ValidationError(
+ "Specify exactly one of 'TemplateBody' or 'TemplateUrl'"
+ ) # TODO: check proper message
+
+ api_utils.prepare_template_body(
+ req_params
+ ) # TODO: function has too many unclear responsibilities
+ if not template_body:
+ template_body = req_params[
+ "TemplateBody"
+ ] # should then have been set by prepare_template_body
+ template = template_preparer.parse_template(req_params["TemplateBody"])
+
+ del req_params["TemplateBody"] # TODO: stop mutating req_params
+ template["StackName"] = stack_name
+ # TODO: validate with AWS what this is actually doing?
+ template["ChangeSetName"] = change_set_name
+
+ # this is intentionally not in a util yet. Let's first see how the different operations deal with these before generalizing
+ # handle ARN stack_name here (not valid for initial CREATE, since stack doesn't exist yet)
+ if ARN_STACK_REGEX.match(stack_name):
+ if not (stack := state.stacks.get(stack_name)):
+ raise ValidationError(f"Stack '{stack_name}' does not exist.")
+ else:
+ # stack name specified, so fetch the stack by name
+ stack_candidates: list[Stack] = [
+ s for stack_arn, s in state.stacks.items() if s.stack_name == stack_name
+ ]
+ active_stack_candidates = [
+ s for s in stack_candidates if self._stack_status_is_active(s.status)
+ ]
+
+ # on a CREATE an empty Stack should be generated if we didn't find an active one
+ if not active_stack_candidates and change_set_type == ChangeSetType.CREATE:
+ empty_stack_template = dict(template)
+ empty_stack_template["Resources"] = {}
+ req_params_copy = clone_stack_params(req_params)
+ stack = Stack(
+ context.account_id,
+ context.region,
+ req_params_copy,
+ empty_stack_template,
+ template_body=template_body,
+ )
+ state.stacks[stack.stack_id] = stack
+ stack.set_stack_status("REVIEW_IN_PROGRESS")
+ else:
+ if not active_stack_candidates:
+ raise ValidationError(f"Stack '{stack_name}' does not exist.")
+ stack = active_stack_candidates[0]
+
+ # TODO: test if rollback status is allowed as well
+ if (
+ change_set_type == ChangeSetType.CREATE
+ and stack.status != StackStatus.REVIEW_IN_PROGRESS
+ ):
+ raise ValidationError(
+ f"Stack [{stack_name}] already exists and cannot be created again with the changeSet [{change_set_name}]."
+ )
+
+ old_parameters: dict[str, Parameter] = {}
+ match change_set_type:
+ case ChangeSetType.UPDATE:
+ # add changeset to existing stack
+ old_parameters = {
+ k: mask_no_echo(strip_parameter_type(v))
+ for k, v in stack.resolved_parameters.items()
+ }
+ case ChangeSetType.IMPORT:
+ raise NotImplementedError() # TODO: implement importing resources
+ case ChangeSetType.CREATE:
+ pass
+ case _:
+ msg = (
+ f"1 validation error detected: Value '{change_set_type}' at 'changeSetType' failed to satisfy "
+ f"constraint: Member must satisfy enum value set: [IMPORT, UPDATE, CREATE] "
+ )
+ raise ValidationError(msg)
+
+ # resolve parameters
+ new_parameters: dict[str, Parameter] = param_resolver.convert_stack_parameters_to_dict(
+ request.get("Parameters")
+ )
+ parameter_declarations = param_resolver.extract_stack_parameter_declarations(template)
+ resolved_parameters = param_resolver.resolve_parameters(
+ account_id=context.account_id,
+ region_name=context.region,
+ parameter_declarations=parameter_declarations,
+ new_parameters=new_parameters,
+ old_parameters=old_parameters,
+ )
+
+ # TODO: remove this when fixing Stack.resources and transformation order
+ # currently we need to create a stack with existing resources + parameters so that resolving refs recursively in here will work.
+ # The correct way would be to do this at a later stage, just like a normal intrinsic function
+ req_params_copy = clone_stack_params(req_params)
+ temp_stack = Stack(context.account_id, context.region, req_params_copy, template)
+ temp_stack.set_resolved_parameters(resolved_parameters)
+
+ # TODO: everything below should be async
+ # apply template transformations
+ transformed_template = template_preparer.transform_template(
+ context.account_id,
+ context.region,
+ template,
+ stack_name=temp_stack.stack_name,
+ resources=temp_stack.resources,
+ mappings=temp_stack.mappings,
+ conditions={}, # TODO: we don't have any resolved conditions yet at this point but we need the conditions because of the samtranslator...
+ resolved_parameters=resolved_parameters,
+ )
+
+ # perform basic static analysis on the template
+ for validation_fn in DEFAULT_TEMPLATE_VALIDATIONS:
+ validation_fn(template)
+
+ # create change set for the stack and apply changes
+ change_set = StackChangeSet(
+ context.account_id, context.region, stack, req_params, transformed_template
+ )
+ # only set parameters for the changeset, then switch to stack on execute_change_set
+ change_set.set_resolved_parameters(resolved_parameters)
+ change_set.template_body = template_body
+
+ # TODO: evaluate conditions
+ raw_conditions = transformed_template.get("Conditions", {})
+ resolved_stack_conditions = resolve_stack_conditions(
+ account_id=context.account_id,
+ region_name=context.region,
+ conditions=raw_conditions,
+ parameters=resolved_parameters,
+ mappings=temp_stack.mappings,
+ stack_name=stack_name,
+ )
+ change_set.set_resolved_stack_conditions(resolved_stack_conditions)
+
+ # a bit gross but use the template ordering to validate missing resources
+ try:
+ order_resources(
+ transformed_template["Resources"],
+ resolved_parameters=resolved_parameters,
+ resolved_conditions=resolved_stack_conditions,
+ )
+ except NoResourceInStack as e:
+ raise ValidationError(str(e)) from e
+
+ deployer = template_deployer.TemplateDeployer(
+ context.account_id, context.region, change_set
+ )
+ changes = deployer.construct_changes(
+ stack,
+ change_set,
+ change_set_id=change_set.change_set_id,
+ append_to_changeset=True,
+ filter_unchanged_resources=True,
+ )
+ stack.change_sets.append(change_set)
+ if not changes:
+ change_set.metadata["Status"] = "FAILED"
+ change_set.metadata["ExecutionStatus"] = "UNAVAILABLE"
+ change_set.metadata["StatusReason"] = (
+ "The submitted information didn't contain changes. Submit different information to create a change set."
+ )
+ else:
+ change_set.metadata["Status"] = (
+ "CREATE_COMPLETE" # technically for some time this should first be CREATE_PENDING
+ )
+ change_set.metadata["ExecutionStatus"] = (
+ "AVAILABLE" # technically for some time this should first be UNAVAILABLE
+ )
+
+ return CreateChangeSetOutput(StackId=change_set.stack_id, Id=change_set.change_set_id)
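The handler above follows the AWS change-set lifecycle: a `CREATE` change set against a non-existent stack first materializes an empty `REVIEW_IN_PROGRESS` stack, and the template is only applied later via `ExecuteChangeSet`. A minimal sketch (not part of the diff) of driving that flow with boto3 against a running LocalStack instance; the edge endpoint URL, dummy credentials, and the sample template are assumptions for local testing:

```python
import json
import boto3

cfn = boto3.client(
    "cloudformation",
    endpoint_url="http://localhost:4566",  # assumed LocalStack edge endpoint
    region_name="us-east-1",
    aws_access_key_id="test",              # dummy credentials for LocalStack
    aws_secret_access_key="test",
)

template = {
    "Resources": {
        "Topic": {"Type": "AWS::SNS::Topic", "Properties": {"TopicName": "demo-topic"}}
    }
}

# ChangeSetType=CREATE creates an empty REVIEW_IN_PROGRESS stack if none exists yet
cs = cfn.create_change_set(
    StackName="demo-stack",
    ChangeSetName="initial",
    ChangeSetType="CREATE",
    TemplateBody=json.dumps(template),
)
cfn.get_waiter("change_set_create_complete").wait(ChangeSetName=cs["Id"])

# executing the change set hands the transformed template to the TemplateDeployer
cfn.execute_change_set(ChangeSetName=cs["Id"])
```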
+
+ @handler("DescribeChangeSet")
+ def describe_change_set(
+ self,
+ context: RequestContext,
+ change_set_name: ChangeSetNameOrId,
+ stack_name: StackNameOrId = None,
+ next_token: NextToken = None,
+ include_property_values: IncludePropertyValues = None,
+ **kwargs,
+ ) -> DescribeChangeSetOutput:
+ # TODO add support for include_property_values
+ # only relevant if change_set_name isn't an ARN
+ if not ARN_CHANGESET_REGEX.match(change_set_name):
+ if not stack_name:
+ raise ValidationError(
+ "StackName must be specified if ChangeSetName is not specified as an ARN."
+ )
+
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ raise ValidationError(f"Stack [{stack_name}] does not exist")
+
+ change_set = find_change_set(
+ context.account_id, context.region, change_set_name, stack_name=stack_name
+ )
+ if not change_set:
+ raise ChangeSetNotFoundException(f"ChangeSet [{change_set_name}] does not exist")
+
+ attrs = [
+ "ChangeSetType",
+ "StackStatus",
+ "LastUpdatedTime",
+ "DisableRollback",
+ "EnableTerminationProtection",
+ "Transform",
+ ]
+ result = remove_attributes(deepcopy(change_set.metadata), attrs)
+ # TODO: replace this patch with a better solution
+ result["Parameters"] = [
+ mask_no_echo(strip_parameter_type(p)) for p in result.get("Parameters", [])
+ ]
+ return result
+
+ @handler("DeleteChangeSet")
+ def delete_change_set(
+ self,
+ context: RequestContext,
+ change_set_name: ChangeSetNameOrId,
+ stack_name: StackNameOrId = None,
+ **kwargs,
+ ) -> DeleteChangeSetOutput:
+ # only relevant if change_set_name isn't an ARN
+ if not ARN_CHANGESET_REGEX.match(change_set_name):
+ if not stack_name:
+ raise ValidationError(
+ "StackName must be specified if ChangeSetName is not specified as an ARN."
+ )
+
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ raise ValidationError(f"Stack [{stack_name}] does not exist")
+
+ change_set = find_change_set(
+ context.account_id, context.region, change_set_name, stack_name=stack_name
+ )
+ if not change_set:
+ raise ChangeSetNotFoundException(f"ChangeSet [{change_set_name}] does not exist")
+ change_set.stack.change_sets = [
+ cs
+ for cs in change_set.stack.change_sets
+ if change_set_name not in (cs.change_set_name, cs.change_set_id)
+ ]
+ return DeleteChangeSetOutput()
+
+ @handler("ExecuteChangeSet")
+ def execute_change_set(
+ self,
+ context: RequestContext,
+ change_set_name: ChangeSetNameOrId,
+ stack_name: StackNameOrId = None,
+ client_request_token: ClientRequestToken = None,
+ disable_rollback: DisableRollback = None,
+ retain_except_on_create: RetainExceptOnCreate = None,
+ **kwargs,
+ ) -> ExecuteChangeSetOutput:
+ change_set = find_change_set(
+ context.account_id,
+ context.region,
+ change_set_name,
+ stack_name=stack_name,
+ active_only=True,
+ )
+ if not change_set:
+ raise ChangeSetNotFoundException(f"ChangeSet [{change_set_name}] does not exist")
+ if change_set.metadata.get("ExecutionStatus") != ExecutionStatus.AVAILABLE:
+ LOG.debug("Change set %s not in execution status 'AVAILABLE'", change_set_name)
+ raise InvalidChangeSetStatusException(
+ f"ChangeSet [{change_set.metadata['ChangeSetId']}] cannot be executed in its current status of [{change_set.metadata.get('Status')}]"
+ )
+ stack_name = change_set.stack.stack_name
+ LOG.debug(
+ 'Executing change set "%s" for stack "%s" with %s resources ...',
+ change_set_name,
+ stack_name,
+ len(change_set.template_resources),
+ )
+ deployer = template_deployer.TemplateDeployer(
+ context.account_id, context.region, change_set.stack
+ )
+ try:
+ deployer.apply_change_set(change_set)
+ change_set.stack.metadata["ChangeSetId"] = change_set.change_set_id
+ except NoStackUpdates:
+ # TODO: parity-check if this exception should be re-raised or swallowed
+ raise ValidationError("No updates to be performed for stack change set")
+
+ return ExecuteChangeSetOutput()
+
+ @handler("ListChangeSets")
+ def list_change_sets(
+ self,
+ context: RequestContext,
+ stack_name: StackNameOrId,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> ListChangeSetsOutput:
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ return not_found_error(f'Unable to find stack "{stack_name}"')
+ result = [cs.metadata for cs in stack.change_sets]
+ return ListChangeSetsOutput(Summaries=result)
+
+ @handler("ListExports")
+ def list_exports(
+ self, context: RequestContext, next_token: NextToken = None, **kwargs
+ ) -> ListExportsOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ return ListExportsOutput(Exports=state.exports)
+
+ @handler("ListImports")
+ def list_imports(
+ self,
+ context: RequestContext,
+ export_name: ExportName,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> ListImportsOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ importing_stack_names = []
+ for stack in state.stacks.values():
+ if export_name in stack.imports:
+ importing_stack_names.append(stack.stack_name)
+
+ return ListImportsOutput(Imports=importing_stack_names)
+
+ @handler("DescribeStackEvents")
+ def describe_stack_events(
+ self,
+ context: RequestContext,
+ stack_name: StackName = None,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> DescribeStackEventsOutput:
+ if stack_name is None:
+ raise ValidationError(
+ "1 validation error detected: Value null at 'stackName' failed to satisfy constraint: Member must not be null"
+ )
+
+ stack = find_active_stack_by_name_or_id(context.account_id, context.region, stack_name)
+ if not stack:
+ stack = find_stack_by_id(
+ account_id=context.account_id, region_name=context.region, stack_id=stack_name
+ )
+ if not stack:
+ raise ValidationError(f"Stack [{stack_name}] does not exist")
+ return DescribeStackEventsOutput(StackEvents=stack.events)
+
+ @handler("DescribeStackResource")
+ def describe_stack_resource(
+ self,
+ context: RequestContext,
+ stack_name: StackName,
+ logical_resource_id: LogicalResourceId,
+ **kwargs,
+ ) -> DescribeStackResourceOutput:
+ stack = find_stack(context.account_id, context.region, stack_name)
+
+ if not stack:
+ return stack_not_found_error(stack_name)
+
+ details = stack.resource_status(logical_resource_id)
+ return DescribeStackResourceOutput(StackResourceDetail=details)
+
+ @handler("DescribeStackResources")
+ def describe_stack_resources(
+ self,
+ context: RequestContext,
+ stack_name: StackName = None,
+ logical_resource_id: LogicalResourceId = None,
+ physical_resource_id: PhysicalResourceId = None,
+ **kwargs,
+ ) -> DescribeStackResourcesOutput:
+ if physical_resource_id and stack_name:
+ raise ValidationError("Cannot specify both StackName and PhysicalResourceId")
+ # TODO: filter stack by PhysicalResourceId!
+ stack = find_stack(context.account_id, context.region, stack_name)
+ if not stack:
+ return stack_not_found_error(stack_name)
+ statuses = [
+ res_status
+ for res_id, res_status in stack.resource_states.items()
+ if logical_resource_id in [res_id, None]
+ ]
+ for status in statuses:
+ status.setdefault("DriftInformation", {"StackResourceDriftStatus": "NOT_CHECKED"})
+ return DescribeStackResourcesOutput(StackResources=statuses)
+
+ @handler("ListStackResources")
+ def list_stack_resources(
+ self, context: RequestContext, stack_name: StackName, next_token: NextToken = None, **kwargs
+ ) -> ListStackResourcesOutput:
+ result = self.describe_stack_resources(context, stack_name)
+
+ resources = deepcopy(result.get("StackResources", []))
+ for resource in resources:
+ attrs = ["StackName", "StackId", "Timestamp", "PreviousResourceStatus"]
+ remove_attributes(resource, attrs)
+
+ return ListStackResourcesOutput(StackResourceSummaries=resources)
+
+ @handler("ValidateTemplate", expand=False)
+ def validate_template(
+ self, context: RequestContext, request: ValidateTemplateInput
+ ) -> ValidateTemplateOutput:
+ try:
+ # TODO implement actual validation logic
+ template_body = api_utils.get_template_body(request)
+ valid_template = json.loads(template_preparer.template_to_json(template_body))
+
+ parameters = [
+ TemplateParameter(
+ ParameterKey=k,
+ DefaultValue=v.get("Default", ""),
+ NoEcho=v.get("NoEcho", False),
+ Description=v.get("Description", ""),
+ )
+ for k, v in valid_template.get("Parameters", {}).items()
+ ]
+
+ return ValidateTemplateOutput(
+ Description=valid_template.get("Description"), Parameters=parameters
+ )
+ except Exception as e:
+ LOG.exception("Error validating template")
+ raise ValidationError("Template Validation Error") from e
+
+ # =======================================
+ # ============= Stack Set =============
+ # =======================================
+
+ @handler("CreateStackSet", expand=False)
+ def create_stack_set(
+ self, context: RequestContext, request: CreateStackSetInput
+ ) -> CreateStackSetOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ stack_set = StackSet(request)
+ stack_set_id = f"{stack_set.stack_set_name}:{long_uid()}"
+ stack_set.metadata["StackSetId"] = stack_set_id
+ state.stack_sets[stack_set_id] = stack_set
+
+ return CreateStackSetOutput(StackSetId=stack_set_id)
+
+ @handler("DescribeStackSetOperation")
+ def describe_stack_set_operation(
+ self,
+ context: RequestContext,
+ stack_set_name: StackSetName,
+ operation_id: ClientRequestToken,
+ call_as: CallAs = None,
+ **kwargs,
+ ) -> DescribeStackSetOperationOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ set_name = stack_set_name
+
+ stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
+ if not stack_set:
+ return not_found_error(f'Unable to find stack set "{set_name}"')
+ stack_set = stack_set[0]
+ result = stack_set.operations.get(operation_id)
+ if not result:
+ LOG.debug(
+ 'Unable to find operation ID "%s" for stack set "%s" in list: %s',
+ operation_id,
+ set_name,
+ list(stack_set.operations.keys()),
+ )
+ return not_found_error(
+ f'Unable to find operation ID "{operation_id}" for stack set "{set_name}"'
+ )
+
+ return DescribeStackSetOperationOutput(StackSetOperation=result)
+
+ @handler("DescribeStackSet")
+ def describe_stack_set(
+ self,
+ context: RequestContext,
+ stack_set_name: StackSetName,
+ call_as: CallAs = None,
+ **kwargs,
+ ) -> DescribeStackSetOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ result = [
+ sset.metadata
+ for sset in state.stack_sets.values()
+ if sset.stack_set_name == stack_set_name
+ ]
+ if not result:
+ return not_found_error(f'Unable to find stack set "{stack_set_name}"')
+
+ return DescribeStackSetOutput(StackSet=result[0])
+
+ @handler("ListStackSets", expand=False)
+ def list_stack_sets(
+ self, context: RequestContext, request: ListStackSetsInput
+ ) -> ListStackSetsOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ result = [sset.metadata for sset in state.stack_sets.values()]
+ return ListStackSetsOutput(Summaries=result)
+
+ @handler("UpdateStackSet", expand=False)
+ def update_stack_set(
+ self, context: RequestContext, request: UpdateStackSetInput
+ ) -> UpdateStackSetOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ set_name = request.get("StackSetName")
+ stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
+ if not stack_set:
+ return not_found_error(f'Stack set named "{set_name}" does not exist')
+ stack_set = stack_set[0]
+ stack_set.metadata.update(request)
+ op_id = request.get("OperationId") or short_uid()
+ operation = {
+ "OperationId": op_id,
+ "StackSetId": stack_set.metadata["StackSetId"],
+ "Action": "UPDATE",
+ "Status": "SUCCEEDED",
+ }
+ stack_set.operations[op_id] = operation
+ return UpdateStackSetOutput(OperationId=op_id)
+
+ @handler("DeleteStackSet")
+ def delete_stack_set(
+ self,
+ context: RequestContext,
+ stack_set_name: StackSetName,
+ call_as: CallAs = None,
+ **kwargs,
+ ) -> DeleteStackSetOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+ stack_set = [
+ sset for sset in state.stack_sets.values() if sset.stack_set_name == stack_set_name
+ ]
+
+ if not stack_set:
+ return not_found_error(f'Stack set named "{stack_set_name}" does not exist')
+
+ # TODO: add a check for remaining stack instances
+
+ for instance in stack_set[0].stack_instances:
+ deployer = template_deployer.TemplateDeployer(
+ context.account_id, context.region, instance.stack
+ )
+ deployer.delete_stack()
+ return DeleteStackSetOutput()
+
+ @handler("CreateStackInstances", expand=False)
+ def create_stack_instances(
+ self,
+ context: RequestContext,
+ request: CreateStackInstancesInput,
+ ) -> CreateStackInstancesOutput:
+ state = get_cloudformation_store(context.account_id, context.region)
+
+ set_name = request.get("StackSetName")
+ stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
+
+ if not stack_set:
+ return not_found_error(f'Stack set named "{set_name}" does not exist')
+
+ stack_set = stack_set[0]
+ op_id = request.get("OperationId") or short_uid()
+ sset_meta = stack_set.metadata
+ accounts = request["Accounts"]
+ regions = request["Regions"]
+
+ stacks_to_await = []
+ for account in accounts:
+ for region in regions:
+ # deploy new stack
+ LOG.debug(
+ 'Deploying instance for stack set "%s" in account: %s region %s',
+ set_name,
+ account,
+ region,
+ )
+ cf_client = connect_to(aws_access_key_id=account, region_name=region).cloudformation
+ kwargs = select_attributes(sset_meta, ["TemplateBody"]) or select_attributes(
+ sset_meta, ["TemplateURL"]
+ )
+ stack_name = f"sset-{set_name}-{account}"
+
+ # skip creation of existing stacks
+ if find_stack(context.account_id, context.region, stack_name):
+ continue
+
+ result = cf_client.create_stack(StackName=stack_name, **kwargs)
+ stacks_to_await.append((stack_name, account, region))
+ # store stack instance
+ instance = {
+ "StackSetId": sset_meta["StackSetId"],
+ "OperationId": op_id,
+ "Account": account,
+ "Region": region,
+ "StackId": result["StackId"],
+ "Status": "CURRENT",
+ "StackInstanceStatus": {"DetailedStatus": "SUCCEEDED"},
+ }
+ instance = StackInstance(instance)
+ stack_set.stack_instances.append(instance)
+
+ # wait for completion of stack
+ for stack_name, account_id, region_name in stacks_to_await:
+ client = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).cloudformation
+ client.get_waiter("stack_create_complete").wait(StackName=stack_name)
+
+ # record operation
+ operation = {
+ "OperationId": op_id,
+ "StackSetId": stack_set.metadata["StackSetId"],
+ "Action": "CREATE",
+ "Status": "SUCCEEDED",
+ }
+ stack_set.operations[op_id] = operation
+
+ return CreateStackInstancesOutput(OperationId=op_id)
+
+ @handler("ListStackInstances", expand=False)
+ def list_stack_instances(
+ self,
+ context: RequestContext,
+ request: ListStackInstancesInput,
+ ) -> ListStackInstancesOutput:
+ set_name = request.get("StackSetName")
+ state = get_cloudformation_store(context.account_id, context.region)
+ stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
+ if not stack_set:
+ return not_found_error(f'Stack set named "{set_name}" does not exist')
+
+ stack_set = stack_set[0]
+ result = [inst.metadata for inst in stack_set.stack_instances]
+ return ListStackInstancesOutput(Summaries=result)
+
+ @handler("DeleteStackInstances", expand=False)
+ def delete_stack_instances(
+ self,
+ context: RequestContext,
+ request: DeleteStackInstancesInput,
+ ) -> DeleteStackInstancesOutput:
+ op_id = request.get("OperationId") or short_uid()
+
+ accounts = request["Accounts"]
+ regions = request["Regions"]
+
+ state = get_cloudformation_store(context.account_id, context.region)
+ stack_sets = state.stack_sets.values()
+
+ set_name = request.get("StackSetName")
+ stack_set = next((sset for sset in stack_sets if sset.stack_set_name == set_name), None)
+
+ if not stack_set:
+ return not_found_error(f'Stack set named "{set_name}" does not exist')
+
+ for account in accounts:
+ for region in regions:
+ instance = find_stack_instance(stack_set, account, region)
+ if instance:
+ stack_set.stack_instances.remove(instance)
+
+ # record operation
+ operation = {
+ "OperationId": op_id,
+ "StackSetId": stack_set.metadata["StackSetId"],
+ "Action": "DELETE",
+ "Status": "SUCCEEDED",
+ }
+ stack_set.operations[op_id] = operation
+
+ return DeleteStackInstancesOutput(OperationId=op_id)
+
+ @handler("RegisterType", expand=False)
+ def register_type(
+ self,
+ context: RequestContext,
+ request: RegisterTypeInput,
+ ) -> RegisterTypeOutput:
+ return RegisterTypeOutput()
+
+ def list_types(
+ self, context: RequestContext, request: ListTypesInput, **kwargs
+ ) -> ListTypesOutput:
+ def is_list_overridden(child_class, parent_class):
+ if hasattr(child_class, "list"):
+ import inspect
+
+ child_method = child_class.list
+ parent_method = parent_class.list
+ return inspect.unwrap(child_method) is not inspect.unwrap(parent_method)
+ return False
+
+ def get_listable_types_summaries(plugin_manager):
+ plugins = plugin_manager.list_names()
+ type_summaries = []
+ for plugin in plugins:
+ type_summary = TypeSummary(
+ Type=RegistryType.RESOURCE,
+ TypeName=plugin,
+ )
+ provider = plugin_manager.load(plugin)
+ if is_list_overridden(provider.factory, ResourceProvider):
+ type_summaries.append(type_summary)
+ return type_summaries
+
+ from localstack.services.cloudformation.resource_provider import (
+ plugin_manager,
+ )
+
+ type_summaries = get_listable_types_summaries(plugin_manager)
+ if PRO_RESOURCE_PROVIDERS:
+ from localstack.services.cloudformation.resource_provider import (
+ pro_plugin_manager,
+ )
+
+ type_summaries.extend(get_listable_types_summaries(pro_plugin_manager))
+
+ return ListTypesOutput(TypeSummaries=type_summaries)
diff --git a/localstack-core/localstack/services/cloudformation/provider_utils.py b/localstack-core/localstack/services/cloudformation/provider_utils.py
new file mode 100644
index 0000000000000..a69c69a17ba7c
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/provider_utils.py
@@ -0,0 +1,233 @@
+"""
+A set of utils for use in resource providers.
+
+Avoid importing from localstack here and keep external imports to a minimum!
+This is because we want to be able to package a resource provider without including localstack code.
+"""
+
+import builtins
+import json
+import re
+import uuid
+from copy import deepcopy
+from pathlib import Path
+from typing import Callable, List, Optional
+
+from botocore.model import Shape, StructureShape
+
+
+def generate_default_name(stack_name: str, logical_resource_id: str):
+ random_id_part = str(uuid.uuid4())[0:8]
+ resource_id_part = logical_resource_id[:24]
+ stack_name_part = stack_name[: 63 - 2 - (len(random_id_part) + len(resource_id_part))]
+ return f"{stack_name_part}-{resource_id_part}-{random_id_part}"
+
+
+def generate_default_name_without_stack(logical_resource_id: str):
+ random_id_part = str(uuid.uuid4())[0:8]
+ resource_id_part = logical_resource_id[: 63 - 1 - len(random_id_part)]
+ return f"{resource_id_part}-{random_id_part}"
+
+
+# ========= Helpers for boto calls ==========
+# (equivalent to the old ones in deployment_utils.py)
+
+
+def deselect_attributes(model: dict, params: list[str]) -> dict:
+ return {k: v for k, v in model.items() if k not in params}
+
+
+def select_attributes(model: dict, params: list[str]) -> dict:
+ return {k: v for k, v in model.items() if k in params}
+
+
+def keys_lower(model: dict) -> dict:
+ return {k.lower(): v for k, v in model.items()}
+
+
+def convert_pascalcase_to_lower_camelcase(item: str) -> str:
+ if len(item) <= 1:
+ return item.lower()
+ else:
+ return f"{item[0].lower()}{item[1:]}"
+
+
+def _recurse_properties(obj: dict | list, fn: Callable) -> dict | list:
+ obj = fn(obj)
+ if isinstance(obj, dict):
+ return {k: _recurse_properties(v, fn) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [_recurse_properties(v, fn) for v in obj]
+ else:
+ return obj
+
+
+def recurse_properties(properties: dict, fn: Callable) -> dict:
+ return _recurse_properties(deepcopy(properties), fn)
+
+
+def keys_pascalcase_to_lower_camelcase(model: dict) -> dict:
+ """Recursively change any dicts keys to lower camelcase"""
+
+ def _keys_pascalcase_to_lower_camelcase(obj):
+ if isinstance(obj, dict):
+ return {convert_pascalcase_to_lower_camelcase(k): v for k, v in obj.items()}
+ else:
+ return obj
+
+ return _recurse_properties(model, _keys_pascalcase_to_lower_camelcase)
+
+
+def transform_list_to_dict(param, key_attr_name="Key", value_attr_name="Value"):
+ result = {}
+ for entry in param:
+ key = entry[key_attr_name]
+ value = entry[value_attr_name]
+ result[key] = value
+ return result
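A one-line illustration of the Key/Value list shape this helper flattens, which is how CloudFormation templates commonly express tags (the sample tags are made up):

```python
from localstack.services.cloudformation.provider_utils import transform_list_to_dict

tags = [{"Key": "Environment", "Value": "dev"}, {"Key": "Team", "Value": "platform"}]

assert transform_list_to_dict(tags) == {"Environment": "dev", "Team": "platform"}
```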
+
+
+def remove_none_values(obj):
+ """Remove None values (recursively) in the given object."""
+ if isinstance(obj, dict):
+ return {k: remove_none_values(v) for k, v in obj.items() if v is not None}
+ elif isinstance(obj, list):
+ return [o for o in obj if o is not None]
+ else:
+ return obj
+
+
+# FIXME: this shouldn't be necessary in the future
+param_validation = re.compile(
+ r"Invalid type for parameter (?P [\w.]+), value: (?P\w+), type: \w+)'>, valid types: \w+)'>"
+)
+
+
+def get_nested(obj: dict, path: str):
+ parts = path.split(".")
+ result = obj
+ for p in parts[:-1]:
+ result = result.get(p, {})
+ return result.get(parts[-1])
+
+
+def set_nested(obj: dict, path: str, value):
+ parts = path.split(".")
+ result = obj
+ for p in parts[:-1]:
+ result = result.get(p, {})
+ result[parts[-1]] = value
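These two helpers use dot-separated paths (as produced by botocore's parameter validation report) rather than JSON pointers; a quick illustration with made-up data:

```python
from localstack.services.cloudformation.provider_utils import get_nested, set_nested

params = {"Tags": {"Environment": "dev"}, "TimeoutInMinutes": "5"}

assert get_nested(params, "Tags.Environment") == "dev"

set_nested(params, "TimeoutInMinutes", 5)
assert params["TimeoutInMinutes"] == 5
```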
+
+
+def fix_boto_parameters_based_on_report(original_params: dict, report: str) -> dict:
+ """
+ Fix invalid type parameter validation errors in boto request parameters
+
+ :param original_params: original boto request parameters that lead to the parameter validation error
+ :param report: error report from botocore ParamValidator
+ :return: a copy of original_params with all values replaced by their correctly cast ones
+ """
+ params = deepcopy(original_params)
+ for found in param_validation.findall(report):
+ param_name, value, wrong_class, valid_class = found
+ cast_class = getattr(builtins, valid_class)
+ old_value = get_nested(params, param_name)
+
+ if cast_class == bool and str(old_value).lower() in ["true", "false"]:
+ new_value = str(old_value).lower() == "true"
+ else:
+ new_value = cast_class(old_value)
+ set_nested(params, param_name, new_value)
+ return params
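A short, hand-written illustration of the round trip: the report string below imitates botocore's ParamValidator wording (and relies on the regex above), and the helper returns a copy with the offending value cast to the reported valid type:

```python
from localstack.services.cloudformation.provider_utils import (
    fix_boto_parameters_based_on_report,
)

# hand-written report imitating botocore's wording for a single bad parameter
report = (
    "Invalid type for parameter TimeoutInMinutes, value: 5, "
    "type: <class 'str'>, valid types: <class 'int'>"
)
original = {"StackName": "demo", "TimeoutInMinutes": "5"}

fixed = fix_boto_parameters_based_on_report(original, report)
assert fixed["TimeoutInMinutes"] == 5        # cast to int based on the report
assert original["TimeoutInMinutes"] == "5"   # the input dict is not mutated
```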
+
+
+def convert_request_kwargs(parameters: dict, input_shape: StructureShape) -> dict:
+ """
+ Transform a dict of request kwargs for a boto3 request by making sure the keys in the structure recursively conform to the specified input shape.
+ :param parameters: the kwargs that would be passed to the boto3 client call, e.g. boto3.client("s3").create_bucket(**parameters)
+ :param input_shape: The botocore input shape of the operation that you want to call later with the fixed inputs
+ :return: a transformed dictionary with the correct casing recursively applied
+ """
+
+ def get_fixed_key(key: str, members: dict[str, Shape]) -> str:
+ """return the case-insensitively matched key from the shape or default to the current key"""
+ for k in members:
+ if k.lower() == key.lower():
+ return k
+ return key
+
+ def transform_value(value, member_shape):
+ if isinstance(value, dict) and hasattr(member_shape, "members"):
+ return convert_request_kwargs(value, member_shape)
+ elif isinstance(value, list) and hasattr(member_shape, "member"):
+ return [transform_value(item, member_shape.member) for item in value]
+
+ # fix the typing of the value
+ match member_shape.type_name:
+ case "string":
+ return str(value)
+ case "integer" | "long":
+ return int(value)
+ case "boolean":
+ if isinstance(value, bool):
+ return value
+ return True if value.lower() == "true" else False
+ case _:
+ return value
+
+ transformed_dict = {}
+ for key, value in parameters.items():
+ correct_key = get_fixed_key(key, input_shape.members)
+ member_shape = input_shape.members.get(correct_key)
+
+ if member_shape is None:
+ continue # skipping this entry, so it's not included in the transformed dict
+ elif isinstance(value, dict) and hasattr(member_shape, "members"):
+ transformed_dict[correct_key] = convert_request_kwargs(value, member_shape)
+ elif isinstance(value, list) and hasattr(member_shape, "member"):
+ transformed_dict[correct_key] = [
+ transform_value(item, member_shape.member) for item in value
+ ]
+ else:
+ transformed_dict[correct_key] = transform_value(value, member_shape)
+
+ return transformed_dict
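A quick sketch of the key-and-type normalization, using the real botocore input shape for S3's `CreateBucket` (loading the shape via `botocore.session` is an assumption about how a caller obtains it; the helper itself only receives the shape):

```python
import botocore.session

from localstack.services.cloudformation.provider_utils import convert_request_kwargs

service_model = botocore.session.get_session().get_service_model("s3")
input_shape = service_model.operation_model("CreateBucket").input_shape

raw = {
    "bucket": "demo-bucket",  # wrong casing on purpose
    "createBucketConfiguration": {"locationConstraint": "eu-west-1"},
}
print(convert_request_kwargs(raw, input_shape))
# {'Bucket': 'demo-bucket', 'CreateBucketConfiguration': {'LocationConstraint': 'eu-west-1'}}
```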
+
+
+def convert_values_to_numbers(input_dict: dict, keys_to_skip: Optional[List[str]] = None):
+ """
+ Recursively converts all string values that represent valid integers
+ in a dictionary (including nested dictionaries and lists) to integers.
+
+ Example:
+ original_dict = {'Gid': '1322', 'SecondaryGids': ['1344', '1452'], 'Uid': '13234'}
+ output_dict = {'Gid': 1322, 'SecondaryGids': [1344, 1452], 'Uid': 13234}
+
+ :param input_dict: input dict with values to convert
+ :param keys_to_skip: keys whose values should not be converted
+ :return: the converted dict
+ """
+
+ keys_to_skip = keys_to_skip or []
+
+ def recursive_convert(obj):
+ if isinstance(obj, dict):
+ return {
+ key: recursive_convert(value) if key not in keys_to_skip else value
+ for key, value in obj.items()
+ }
+ elif isinstance(obj, list):
+ return [recursive_convert(item) for item in obj]
+ elif isinstance(obj, str) and obj.isdigit():
+ return int(obj)
+ else:
+ return obj
+
+ return recursive_convert(input_dict)
+
+
+# LocalStack specific utilities
+def get_schema_path(file_path: Path) -> dict:
+ file_name_base = file_path.name.removesuffix(".py").removesuffix(".py.enc")
+ with Path(file_path).parent.joinpath(f"{file_name_base}.schema.json").open() as fd:
+ return json.load(fd)
diff --git a/localstack-core/localstack/services/cloudformation/resource_provider.py b/localstack-core/localstack/services/cloudformation/resource_provider.py
new file mode 100644
index 0000000000000..fa8744324d437
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_provider.py
@@ -0,0 +1,648 @@
+from __future__ import annotations
+
+import copy
+import logging
+import re
+import time
+import uuid
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from logging import Logger
+from math import ceil
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Type, TypedDict, TypeVar
+
+import botocore
+from botocore.client import BaseClient
+from botocore.exceptions import ClientError
+from botocore.model import OperationModel
+from plux import Plugin, PluginManager
+
+from localstack import config
+from localstack.aws.connect import InternalClientFactory, ServiceLevelClientFactory
+from localstack.services.cloudformation import usage
+from localstack.services.cloudformation.deployment_utils import (
+ check_not_found_exception,
+ convert_data_types,
+ fix_account_id_in_arns,
+ fix_boto_parameters_based_on_report,
+ log_not_available_message,
+ remove_none_values,
+)
+from localstack.services.cloudformation.engine.quirks import PHYSICAL_RESOURCE_ID_SPECIAL_CASES
+from localstack.services.cloudformation.provider_utils import convert_request_kwargs
+from localstack.services.cloudformation.service_models import KEY_RESOURCE_STATE
+
+PRO_RESOURCE_PROVIDERS = False
+try:
+ from localstack.pro.core.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPluginExt,
+ )
+
+ PRO_RESOURCE_PROVIDERS = True
+except ImportError:
+ pass
+
+if TYPE_CHECKING:
+ from localstack.services.cloudformation.engine.types import (
+ FuncDetails,
+ FuncDetailsValue,
+ ResourceDefinition,
+ )
+
+LOG = logging.getLogger(__name__)
+
+Properties = TypeVar("Properties")
+
+PUBLIC_REGISTRY: dict[str, Type[ResourceProvider]] = {}
+
+PROVIDER_DEFAULTS = {} # TODO: remove this after removing patching in -ext
+
+
+class OperationStatus(Enum):
+ PENDING = auto()
+ IN_PROGRESS = auto()
+ SUCCESS = auto()
+ FAILED = auto()
+
+
+@dataclass
+class ProgressEvent(Generic[Properties]):
+ status: OperationStatus
+ resource_model: Optional[Properties] = None
+ resource_models: Optional[list[Properties]] = None
+
+ message: str = ""
+ result: Optional[str] = None
+ error_code: Optional[str] = None # TODO: enum
+ custom_context: dict = field(default_factory=dict)
+
+
+class Credentials(TypedDict):
+ accessKeyId: str
+ secretAccessKey: str
+ sessionToken: str
+
+
+class ResourceProviderPayloadRequestData(TypedDict):
+ logicalResourceId: str
+ resourceProperties: Properties
+ previousResourceProperties: Optional[Properties]
+ callerCredentials: Credentials
+ providerCredentials: Credentials
+ systemTags: dict[str, str]
+ previousSystemTags: dict[str, str]
+ stackTags: dict[str, str]
+ previousStackTags: dict[str, str]
+
+
+class ResourceProviderPayload(TypedDict):
+ callbackContext: dict
+ stackId: str
+ requestData: ResourceProviderPayloadRequestData
+ resourceType: str
+ resourceTypeVersion: str
+ awsAccountId: str
+ bearerToken: str
+ region: str
+ action: str
+
+
+ResourceProperties = TypeVar("ResourceProperties")
+
+
+def _handler_provide_client_params(event_name: str, params: dict, model: OperationModel, **kwargs):
+ """
+ A botocore hook handler that will try to convert the passed parameters according to the given operation model
+ """
+ return convert_request_kwargs(params, model.input_shape)
+
+
+class ConvertingInternalClientFactory(InternalClientFactory):
+ def _get_client_post_hook(self, client: BaseClient) -> BaseClient:
+ """
+ Register handlers that modify the passed properties to make them compatible with the API structure
+ """
+
+ client.meta.events.register(
+ "provide-client-params.*.*", handler=_handler_provide_client_params
+ )
+
+ return super()._get_client_post_hook(client)
+
+
+_cfn_resource_client_factory = ConvertingInternalClientFactory(use_ssl=config.DISTRIBUTED_MODE)
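The factory above hooks every internally created client so that resource-provider kwargs are normalized right before botocore validates them. A standalone sketch of the same botocore event mechanism on a plain boto3 client (the client, event name, and logging handler are illustrative, not part of the diff):

```python
import boto3

def log_params(params, model, **kwargs):
    # "provide-client-params" handlers run before parameter validation and may
    # return a replacement dict for the request parameters.
    print(f"about to call {model.name} with {params}")
    return params

s3 = boto3.client("s3", region_name="us-east-1")
s3.meta.events.register("provide-client-params.s3.CreateBucket", handler=log_params)
```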
+
+
+def convert_payload(
+ stack_name: str, stack_id: str, payload: ResourceProviderPayload
+) -> ResourceRequest[Properties]:
+ client_factory = _cfn_resource_client_factory(
+ aws_access_key_id=payload["requestData"]["callerCredentials"]["accessKeyId"],
+ aws_session_token=payload["requestData"]["callerCredentials"]["sessionToken"],
+ aws_secret_access_key=payload["requestData"]["callerCredentials"]["secretAccessKey"],
+ region_name=payload["region"],
+ )
+ desired_state = payload["requestData"]["resourceProperties"]
+ rr = ResourceRequest(
+ _original_payload=desired_state,
+ aws_client_factory=client_factory,
+ request_token=str(uuid.uuid4()), # TODO: not actually a UUID
+ stack_name=stack_name,
+ stack_id=stack_id,
+ account_id=payload["awsAccountId"],
+ region_name=payload["region"],
+ desired_state=desired_state,
+ logical_resource_id=payload["requestData"]["logicalResourceId"],
+ resource_type=payload["resourceType"],
+ logger=logging.getLogger("abc"),
+ custom_context=payload["callbackContext"],
+ action=payload["action"],
+ )
+
+ if previous_properties := payload["requestData"].get("previousResourceProperties"):
+ rr.previous_state = previous_properties
+
+ return rr
+
+
+@dataclass
+class ResourceRequest(Generic[Properties]):
+ _original_payload: Properties
+
+ aws_client_factory: ServiceLevelClientFactory
+ request_token: str
+ stack_name: str
+ stack_id: str
+ account_id: str
+ region_name: str
+ action: str
+
+ desired_state: Properties
+
+ logical_resource_id: str
+ resource_type: str
+
+ logger: Logger
+
+ custom_context: dict = field(default_factory=dict)
+
+ previous_state: Optional[Properties] = None
+ previous_tags: Optional[dict[str, str]] = None
+ tags: dict[str, str] = field(default_factory=dict)
+
+
+class CloudFormationResourceProviderPlugin(Plugin):
+ """
+ Base class for resource provider plugins.
+ """
+
+ namespace = "localstack.cloudformation.resource_providers"
+
+
+class ResourceProvider(Generic[Properties]):
+ """
+ This provides a base class onto which service-specific resource providers are built.
+ """
+
+ SCHEMA: dict
+
+ def create(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ raise NotImplementedError
+
+ def update(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ raise NotImplementedError
+
+ def delete(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ raise NotImplementedError
+
+ def read(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ raise NotImplementedError
+
+ def list(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ raise NotImplementedError
+
+
+# legacy helpers
+def get_resource_type(resource: dict) -> str:
+ """this is currently overwritten in PRO to add support for custom resources"""
+ if isinstance(resource, str):
+ raise ValueError(f"Invalid argument: {resource}")
+ try:
+ resource_type: str = resource["Type"]
+
+ if resource_type.startswith("Custom::"):
+ return "AWS::CloudFormation::CustomResource"
+ return resource_type
+ except Exception:
+ LOG.warning(
+ "Failed to retrieve resource type %s",
+ resource.get("Type"),
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+
+def invoke_function(
+ account_id: str,
+ region_name: str,
+ function: Callable,
+ params: dict,
+ resource_type: str,
+ func_details: FuncDetails,
+ action_name: str,
+ resource: Any,
+) -> Any:
+ try:
+ LOG.debug(
+ 'Request for resource type "%s" in account %s region %s: %s %s',
+ resource_type,
+ account_id,
+ region_name,
+ func_details["function"],
+ params,
+ )
+ try:
+ result = function(**params)
+ except botocore.exceptions.ParamValidationError as e:
+ # alternatively we could also use the ParamValidator directly
+ report = e.kwargs.get("report")
+ if not report:
+ raise
+
+ LOG.debug("Converting parameters to allowed types")
+ LOG.debug("Report: %s", report)
+ converted_params = fix_boto_parameters_based_on_report(params, report)
+ LOG.debug("Original parameters: %s", params)
+ LOG.debug("Converted parameters: %s", converted_params)
+
+ result = function(**converted_params)
+ except Exception as e:
+ if action_name == "Remove" and check_not_found_exception(e, resource_type, resource):
+ return
+ log_method = LOG.warning
+ if config.CFN_VERBOSE_ERRORS:
+ log_method = LOG.exception
+ log_method("Error calling %s with params: %s for resource: %s", function, params, resource)
+ raise e
+
+ return result
+
+
+def get_service_name(resource):
+ res_type = resource["Type"]
+ parts = res_type.split("::")
+ if len(parts) == 1:
+ return None
+ if "Cognito::IdentityPool" in res_type:
+ return "cognito-identity"
+ if res_type.endswith("Cognito::UserPool"):
+ return "cognito-idp"
+ if parts[-2] == "Cognito":
+ return "cognito-idp"
+ if parts[-2] == "Elasticsearch":
+ return "es"
+ if parts[-2] == "OpenSearchService":
+ return "opensearch"
+ if parts[-2] == "KinesisFirehose":
+ return "firehose"
+ if parts[-2] == "ResourceGroups":
+ return "resource-groups"
+ if parts[-2] == "CertificateManager":
+ return "acm"
+ if "ElasticLoadBalancing::" in res_type:
+ return "elb"
+ if "ElasticLoadBalancingV2::" in res_type:
+ return "elbv2"
+ if "ApplicationAutoScaling::" in res_type:
+ return "application-autoscaling"
+ if "MSK::" in res_type:
+ return "kafka"
+ if "Timestream::" in res_type:
+ return "timestream-write"
+ return parts[1].lower()
+
+
+def resolve_resource_parameters(
+ account_id_: str,
+ region_name_: str,
+ stack_name: str,
+ resource_definition: ResourceDefinition,
+ resources: dict[str, ResourceDefinition],
+ resource_id: str,
+ func_details: FuncDetailsValue,
+) -> dict | None:
+ params = func_details.get("parameters") or (
+ lambda account_id, region_name, properties, logical_resource_id, *args, **kwargs: properties
+ )
+ resource_props = resource_definition["Properties"] = resource_definition.get("Properties", {})
+ resource_props = dict(resource_props)
+ resource_state = resource_definition.get(KEY_RESOURCE_STATE, {})
+ last_deployed_state = resource_definition.get("_last_deployed_state", {})
+
+ if callable(params):
+ # resolve parameter map via custom function
+ params = params(
+ account_id_, region_name_, resource_props, resource_id, resource_definition, stack_name
+ )
+ else:
+ # it could be a list like ['param1', 'param2', {'apiCallParamName': 'cfResourcePropName'}]
+ if isinstance(params, list):
+ _params = {}
+ for param in params:
+ if isinstance(param, dict):
+ _params.update(param)
+ else:
+ _params[param] = param
+ params = _params
+
+ params = dict(params)
+ # TODO(srw): mutably mapping params :(
+ for param_key, prop_keys in dict(params).items():
+ params.pop(param_key, None)
+ if not isinstance(prop_keys, list):
+ prop_keys = [prop_keys]
+ for prop_key in prop_keys:
+ if callable(prop_key):
+ prop_value = prop_key(
+ account_id_,
+ region_name_,
+ resource_props,
+ resource_id,
+ resource_definition,
+ stack_name,
+ )
+ else:
+ prop_value = resource_props.get(
+ prop_key,
+ resource_definition.get(
+ prop_key,
+ resource_state.get(prop_key, last_deployed_state.get(prop_key)),
+ ),
+ )
+ if prop_value is not None:
+ params[param_key] = prop_value
+ break
+
+ # this is an indicator that we should skip this resource deployment, and return
+ if params is None:
+ return
+
+ # FIXME: move this to a single place after template processing is finished
+ # convert any moto account IDs (123456789012) in ARNs to our format (000000000000)
+ params = fix_account_id_in_arns(params, account_id_)
+ # convert data types (e.g., boolean strings to bool)
+ # TODO: this might not be needed anymore
+ params = convert_data_types(func_details.get("types", {}), params)
+ # remove None values, as they usually raise boto3 errors
+ params = remove_none_values(params)
+
+ return params
+
+
+class NoResourceProvider(Exception):
+ pass
+
+
+def resolve_json_pointer(resource_props: Properties, primary_id_path: str) -> str:
+ primary_id_path = primary_id_path.replace("/properties", "")
+ parts = [p for p in primary_id_path.split("/") if p]
+
+ resolved_part = resource_props.copy()
+ for i in range(len(parts)):
+ part = parts[i]
+ resolved_part = resolved_part.get(part)
+ if i == len(parts) - 1:
+ # last part
+ return resolved_part
+
+ raise Exception(f"Resource properties is missing field: {part}")
+
+
+class ResourceProviderExecutor:
+ """
+ Point of abstraction between our integration with generic base models, and the new providers.
+ """
+
+ def __init__(
+ self,
+ *,
+ stack_name: str,
+ stack_id: str,
+ ):
+ self.stack_name = stack_name
+ self.stack_id = stack_id
+
+ def deploy_loop(
+ self,
+ resource_provider: ResourceProvider,
+ resource: dict,
+ raw_payload: ResourceProviderPayload,
+ max_timeout: int = config.CFN_PER_RESOURCE_TIMEOUT,
+ sleep_time: float = 5,
+ ) -> ProgressEvent[Properties]:
+ payload = copy.deepcopy(raw_payload)
+
+ max_iterations = max(ceil(max_timeout / sleep_time), 2)
+
+ for current_iteration in range(max_iterations):
+ resource_type = get_resource_type(
+ {"Type": raw_payload["resourceType"]}
+ ) # TODO: simplify signature of get_resource_type to just take the type
+ resource["SpecifiedProperties"] = raw_payload["requestData"]["resourceProperties"]
+
+ try:
+ event = self.execute_action(resource_provider, payload)
+ except ClientError:
+ LOG.error(
+ "client error invoking '%s' handler for resource '%s' (type '%s')",
+ raw_payload["action"],
+ raw_payload["requestData"]["logicalResourceId"],
+ resource_type,
+ )
+ raise
+
+ match event.status:
+ case OperationStatus.FAILED:
+ return event
+ case OperationStatus.SUCCESS:
+ if not hasattr(resource_provider, "SCHEMA"):
+ raise Exception(
+ "A ResourceProvider should always have a SCHEMA property defined."
+ )
+ resource_type_schema = resource_provider.SCHEMA
+ physical_resource_id = self.extract_physical_resource_id_from_model_with_schema(
+ event.resource_model,
+ raw_payload["resourceType"],
+ resource_type_schema,
+ )
+
+ resource["PhysicalResourceId"] = physical_resource_id
+ resource["Properties"] = event.resource_model
+ resource["_last_deployed_state"] = copy.deepcopy(event.resource_model)
+ return event
+ case OperationStatus.IN_PROGRESS:
+ # update the shared state
+ context = {**payload["callbackContext"], **event.custom_context}
+ payload["callbackContext"] = context
+ payload["requestData"]["resourceProperties"] = event.resource_model
+ resource["Properties"] = event.resource_model
+
+ if current_iteration == 0:
+ time.sleep(0)
+ else:
+ time.sleep(sleep_time)
+ case OperationStatus.PENDING:
+ # come back to this resource in another iteration
+ return event
+ case invalid_status:
+ raise ValueError(
+ f"Invalid OperationStatus ({invalid_status}) returned for resource {raw_payload['requestData']['logicalResourceId']} (type {raw_payload['resourceType']})"
+ )
+
+ else:
+ raise TimeoutError(
+ f"Resource deployment for resource {raw_payload['requestData']['logicalResourceId']} (type {raw_payload['resourceType']}) timed out."
+ )
+
+ def execute_action(
+ self, resource_provider: ResourceProvider, raw_payload: ResourceProviderPayload
+ ) -> ProgressEvent[Properties]:
+ change_type = raw_payload["action"]
+ request = convert_payload(
+ stack_name=self.stack_name, stack_id=self.stack_id, payload=raw_payload
+ )
+
+ match change_type:
+ case "Add":
+ # replicate previous event emitting behaviour
+ usage.resource_type.record(request.resource_type)
+
+ return resource_provider.create(request)
+ case "Dynamic" | "Modify":
+ try:
+ return resource_provider.update(request)
+ except NotImplementedError:
+ LOG.warning(
+ 'Unable to update resource type "%s", id "%s"',
+ request.resource_type,
+ request.logical_resource_id,
+ )
+ if request.previous_state is None:
+ # this is an issue with our update detection. We should never be in this state.
+ request.action = "Add"
+ return resource_provider.create(request)
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS, resource_model=request.previous_state
+ )
+ except Exception as e:
+ # FIXME: this fallback should be removed after fixing updates in general (order/dependencies)
+ # catch-all for any exception that looks like a not found exception
+ if check_not_found_exception(e, request.resource_type, request.desired_state):
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS, resource_model=request.previous_state
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model={},
+ message=f"Failed to delete resource with id {request.logical_resource_id} of type {request.resource_type}",
+ )
+ case "Remove":
+ try:
+ return resource_provider.delete(request)
+ except Exception as e:
+ # catch-all for any exception that looks like a not found exception
+ if check_not_found_exception(e, request.resource_type, request.desired_state):
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model={},
+ message=f"Failed to delete resource with id {request.logical_resource_id} of type {request.resource_type}",
+ )
+ case _:
+ raise NotImplementedError(change_type) # TODO: change error type
+
+ @staticmethod
+ def try_load_resource_provider(resource_type: str) -> ResourceProvider | None:
+ # TODO: unify namespace of plugins
+
+ # 1. try to load pro resource provider
+ # prioritise pro resource providers
+ if PRO_RESOURCE_PROVIDERS:
+ try:
+ plugin = pro_plugin_manager.load(resource_type)
+ return plugin.factory()
+ except ValueError:
+ # could not find a plugin for that name
+ pass
+ except Exception:
+ LOG.warning(
+ "Failed to load PRO resource type %s as a ResourceProvider.",
+ resource_type,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ # 2. try to load community resource provider
+ try:
+ plugin = plugin_manager.load(resource_type)
+ return plugin.factory()
+ except ValueError:
+ # could not find a plugin for that name
+ pass
+ except Exception:
+ if config.CFN_VERBOSE_ERRORS:
+ LOG.warning(
+ "Failed to load community resource type %s as a ResourceProvider.",
+ resource_type,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ # 3. we could not find the resource provider so log the missing resource provider
+ log_not_available_message(
+ resource_type,
+ f'No resource provider found for "{resource_type}"',
+ )
+
+ usage.missing_resource_types.record(resource_type)
+
+ if config.CFN_IGNORE_UNSUPPORTED_RESOURCE_TYPES:
+ # TODO: figure out a better way to handle non-implemented here?
+ return None
+ else:
+ raise NoResourceProvider
+
+ def extract_physical_resource_id_from_model_with_schema(
+ self, resource_model: Properties, resource_type: str, resource_type_schema: dict
+ ) -> str:
+ if resource_type in PHYSICAL_RESOURCE_ID_SPECIAL_CASES:
+ primary_id_path = PHYSICAL_RESOURCE_ID_SPECIAL_CASES[resource_type]
+
+ if "<" in primary_id_path:
+ # composite quirk, e.g. something like MyRef|MyName
+ # try to extract parts
+ physical_resource_id = primary_id_path
+ find_results = re.findall("<([^>]+)>", primary_id_path)
+ for found_part in find_results:
+ resolved_part = resolve_json_pointer(resource_model, found_part)
+ physical_resource_id = physical_resource_id.replace(
+ f"<{found_part}>", resolved_part
+ )
+ else:
+ physical_resource_id = resolve_json_pointer(resource_model, primary_id_path)
+ else:
+ primary_id_paths = resource_type_schema["primaryIdentifier"]
+ if len(primary_id_paths) > 1:
+ # TODO: auto-merge. Verify logic here with AWS
+ physical_resource_id = "-".join(
+ [resolve_json_pointer(resource_model, pip) for pip in primary_id_paths]
+ )
+ else:
+ physical_resource_id = resolve_json_pointer(resource_model, primary_id_paths[0])
+
+ return physical_resource_id
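A small sketch of how the schema-driven branch above behaves, using a hypothetical resource type that is assumed to have no entry in `PHYSICAL_RESOURCE_ID_SPECIAL_CASES`; multiple `primaryIdentifier` paths are joined with `-`:

```python
from localstack.services.cloudformation.resource_provider import ResourceProviderExecutor

executor = ResourceProviderExecutor(stack_name="demo-stack", stack_id="stack-1")

# single identifier path: the resolved value is used directly
schema = {"primaryIdentifier": ["/properties/Id"]}
model = {"Id": "example-123"}
assert (
    executor.extract_physical_resource_id_from_model_with_schema(
        model, "AWS::Example::Resource", schema
    )
    == "example-123"
)

# composite identifier: the resolved parts are joined with "-"
schema = {"primaryIdentifier": ["/properties/ClusterName", "/properties/ServiceName"]}
model = {"ClusterName": "demo-cluster", "ServiceName": "demo-service"}
assert (
    executor.extract_physical_resource_id_from_model_with_schema(
        model, "AWS::Example::Resource", schema
    )
    == "demo-cluster-demo-service"
)
```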
+
+
+plugin_manager = PluginManager(CloudFormationResourceProviderPlugin.namespace)
+if PRO_RESOURCE_PROVIDERS:
+ pro_plugin_manager = PluginManager(CloudFormationResourceProviderPluginExt.namespace)
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/__init__.py b/localstack-core/localstack/services/cloudformation/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.py
new file mode 100644
index 0000000000000..8f17b3d36368e
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.py
@@ -0,0 +1,102 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.services.cloudformation.stores import get_cloudformation_store
+
+
+class CloudFormationMacroProperties(TypedDict):
+ FunctionName: Optional[str]
+ Name: Optional[str]
+ Description: Optional[str]
+ Id: Optional[str]
+ LogGroupName: Optional[str]
+ LogRoleARN: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudFormationMacroProvider(ResourceProvider[CloudFormationMacroProperties]):
+ TYPE = "AWS::CloudFormation::Macro" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudFormationMacroProperties],
+ ) -> ProgressEvent[CloudFormationMacroProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - FunctionName
+ - Name
+
+ Create-only properties:
+ - /properties/Name
+
+ Read-only properties:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+
+ # TODO: fix or validate that we want to keep this here.
+ # AWS::CloudFormation:: resources need special handling since they seem to require access to internal APIs
+ store = get_cloudformation_store(request.account_id, request.region_name)
+ store.macros[model["Name"]] = model
+ model["Id"] = model["Name"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[CloudFormationMacroProperties],
+ ) -> ProgressEvent[CloudFormationMacroProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudFormationMacroProperties],
+ ) -> ProgressEvent[CloudFormationMacroProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+
+ store = get_cloudformation_store(request.account_id, request.region_name)
+ store.macros.pop(model["Name"], None)
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[CloudFormationMacroProperties],
+ ) -> ProgressEvent[CloudFormationMacroProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.schema.json b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.schema.json
new file mode 100644
index 0000000000000..a04056992eb09
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro.schema.json
@@ -0,0 +1,38 @@
+{
+ "typeName": "AWS::CloudFormation::Macro",
+ "description": "Resource Type definition for AWS::CloudFormation::Macro",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "FunctionName": {
+ "type": "string"
+ },
+ "LogGroupName": {
+ "type": "string"
+ },
+ "LogRoleARN": {
+ "type": "string"
+ },
+ "Name": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "FunctionName",
+ "Name"
+ ],
+ "createOnlyProperties": [
+ "/properties/Name"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro_plugin.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro_plugin.py
new file mode 100644
index 0000000000000..9c6572792fc21
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_macro_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudFormationMacroProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudFormation::Macro"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudformation.resource_providers.aws_cloudformation_macro import (
+ CloudFormationMacroProvider,
+ )
+
+ self.factory = CloudFormationMacroProvider
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.py
new file mode 100644
index 0000000000000..b30c629682cc6
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.py
@@ -0,0 +1,220 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CloudFormationStackProperties(TypedDict):
+ TemplateURL: Optional[str]
+ Id: Optional[str]
+ NotificationARNs: Optional[list[str]]
+ Parameters: Optional[dict]
+ Tags: Optional[list[Tag]]
+ TimeoutInMinutes: Optional[int]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudFormationStackProvider(ResourceProvider[CloudFormationStackProperties]):
+ TYPE = "AWS::CloudFormation::Stack" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudFormationStackProperties],
+ ) -> ProgressEvent[CloudFormationStackProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - TemplateURL
+
+ Read-only properties:
+ - /properties/Id
+ """
+ model = request.desired_state
+
+ # TODO: validations
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ if not model.get("StackName"):
+ model["StackName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+
+ create_params = util.select_attributes(
+ model,
+ [
+ "StackName",
+ "Parameters",
+ "NotificationARNs",
+ "TemplateURL",
+ "TimeoutInMinutes",
+ "Tags",
+ ],
+ )
+
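+ # always pass the full set of capabilities so the nested stack can create IAM resources and expand transforms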
+ create_params["Capabilities"] = [
+ "CAPABILITY_IAM",
+ "CAPABILITY_NAMED_IAM",
+ "CAPABILITY_AUTO_EXPAND",
+ ]
+
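+ # convert the Parameters dict into the CreateStack list format; boolean values are serialized as "true"/"false"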
+ create_params["Parameters"] = [
+ {
+ "ParameterKey": k,
+ "ParameterValue": str(v).lower() if isinstance(v, bool) else str(v),
+ }
+ for k, v in create_params.get("Parameters", {}).items()
+ ]
+
+ result = request.aws_client_factory.cloudformation.create_stack(**create_params)
+ model["Id"] = result["StackId"]
+
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
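+ # subsequent invocations: poll the nested stack until it reaches a terminal status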
+ stack = request.aws_client_factory.cloudformation.describe_stacks(StackName=model["Id"])[
+ "Stacks"
+ ][0]
+ match stack["StackStatus"]:
+ case "CREATE_COMPLETE":
+ # only store nested stack outputs when we know the deploy has completed
+ model["Outputs"] = {
+ o["OutputKey"]: o["OutputValue"] for o in stack.get("Outputs", [])
+ }
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case "CREATE_IN_PROGRESS":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case "CREATE_FAILED":
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case _:
+ raise Exception(f"Unexpected status: {stack['StackStatus']}")
+
+ def read(
+ self,
+ request: ResourceRequest[CloudFormationStackProperties],
+ ) -> ProgressEvent[CloudFormationStackProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudFormationStackProperties],
+ ) -> ProgressEvent[CloudFormationStackProperties]:
+ """
+ Delete a resource
+ """
+
+ model = request.desired_state
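+ # first invocation: kick off the stack deletion; later invocations poll until the stack is gone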
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.aws_client_factory.cloudformation.delete_stack(StackName=model["Id"])
+
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ try:
+ stack = request.aws_client_factory.cloudformation.describe_stacks(
+ StackName=model["Id"]
+ )["Stacks"][0]
+ except Exception as e:
+ if "does not exist" in str(e):
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ raise e
+
+ match stack["StackStatus"]:
+ case "DELETE_COMPLETE":
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case "DELETE_IN_PROGRESS":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case "DELETE_FAILED":
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case _:
+ raise Exception(f"Unexpected status: {stack['StackStatus']}")
+
+ def update(
+ self,
+ request: ResourceRequest[CloudFormationStackProperties],
+ ) -> ProgressEvent[CloudFormationStackProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[CloudFormationStackProperties],
+ ) -> ProgressEvent[CloudFormationStackProperties]:
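+ # report every existing stack, identified by its StackId (the primary identifier of this resource type)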
+ resources = request.aws_client_factory.cloudformation.describe_stacks()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ CloudFormationStackProperties(Id=resource["StackId"])
+ for resource in resources["Stacks"]
+ ],
+ )
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.schema.json b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.schema.json
new file mode 100644
index 0000000000000..a26835e77ba10
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack.schema.json
@@ -0,0 +1,65 @@
+{
+ "typeName": "AWS::CloudFormation::Stack",
+ "description": "Resource Type definition for AWS::CloudFormation::Stack",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "NotificationARNs": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Parameters": {
+ "type": "object",
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "TemplateURL": {
+ "type": "string"
+ },
+ "TimeoutInMinutes": {
+ "type": "integer"
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "TemplateURL"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack_plugin.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack_plugin.py
new file mode 100644
index 0000000000000..9dc020a564aa4
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_stack_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudFormationStackProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudFormation::Stack"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudformation.resource_providers.aws_cloudformation_stack import (
+ CloudFormationStackProvider,
+ )
+
+ self.factory = CloudFormationStackProvider
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.py
new file mode 100644
index 0000000000000..051c901e425d9
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.py
@@ -0,0 +1,83 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import uuid
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CloudFormationWaitConditionProperties(TypedDict):
+ Count: Optional[int]
+ Data: Optional[dict]
+ Handle: Optional[str]
+ Id: Optional[str]
+ Timeout: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudFormationWaitConditionProvider(ResourceProvider[CloudFormationWaitConditionProperties]):
+ TYPE = "AWS::CloudFormation::WaitCondition" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Read-only properties:
+ - /properties/Data
+ - /properties/Id
+
+ """
+ model = request.desired_state
+ model["Data"] = {} # TODO
+ model["Id"] = f"{request.stack_id}/{uuid.uuid4()}/{request.logical_resource_id}"
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionProperties]:
+ """
+ Delete a resource
+
+
+ """
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={}) # NO-OP
+
+ def update(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.schema.json b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.schema.json
new file mode 100644
index 0000000000000..232d5c012e745
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition.schema.json
@@ -0,0 +1,29 @@
+{
+ "typeName": "AWS::CloudFormation::WaitCondition",
+ "description": "Resource Type definition for AWS::CloudFormation::WaitCondition",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Data": {
+ "type": "object"
+ },
+ "Count": {
+ "type": "integer"
+ },
+ "Handle": {
+ "type": "string"
+ },
+ "Timeout": {
+ "type": "string"
+ }
+ },
+ "readOnlyProperties": [
+ "/properties/Data",
+ "/properties/Id"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition_plugin.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition_plugin.py
new file mode 100644
index 0000000000000..bdc8b49fd2e6d
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitcondition_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudFormationWaitConditionProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudFormation::WaitCondition"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudformation.resource_providers.aws_cloudformation_waitcondition import (
+ CloudFormationWaitConditionProvider,
+ )
+
+ self.factory = CloudFormationWaitConditionProvider
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.py
new file mode 100644
index 0000000000000..f2b5237876fe0
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.py
@@ -0,0 +1,94 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CloudFormationWaitConditionHandleProperties(TypedDict):
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudFormationWaitConditionHandleProvider(
+ ResourceProvider[CloudFormationWaitConditionHandleProperties]
+):
+ TYPE = "AWS::CloudFormation::WaitConditionHandle" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionHandleProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionHandleProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Read-only properties:
+ - /properties/Id
+ """
+ # TODO: properly test this and fix s3 bucket usage
+ model = request.desired_state
+
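+ # the handle's physical ID is a presigned S3 URL that wait-condition signals can be PUT to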
+ s3 = request.aws_client_factory.s3
+ region = s3.meta.region_name
+
+ bucket = f"cloudformation-waitcondition-{region}"
+ waitcondition_url = s3.generate_presigned_url(
+ "put_object", Params={"Bucket": bucket, "Key": request.stack_id}
+ )
+ model["Id"] = waitcondition_url
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionHandleProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionHandleProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionHandleProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionHandleProperties]:
+ """
+ Delete a resource
+
+
+ """
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[CloudFormationWaitConditionHandleProperties],
+ ) -> ProgressEvent[CloudFormationWaitConditionHandleProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.schema.json b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.schema.json
new file mode 100644
index 0000000000000..34c317b900bf4
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle.schema.json
@@ -0,0 +1,16 @@
+{
+ "typeName": "AWS::CloudFormation::WaitConditionHandle",
+ "description": "Resource Type definition for AWS::CloudFormation::WaitConditionHandle",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ }
+ },
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle_plugin.py b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle_plugin.py
new file mode 100644
index 0000000000000..f5888171517ab
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/resource_providers/aws_cloudformation_waitconditionhandle_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudFormationWaitConditionHandleProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudFormation::WaitConditionHandle"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudformation.resource_providers.aws_cloudformation_waitconditionhandle import (
+ CloudFormationWaitConditionHandleProvider,
+ )
+
+ self.factory = CloudFormationWaitConditionHandleProvider
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/CloudformationSchema.zip b/localstack-core/localstack/services/cloudformation/scaffolding/CloudformationSchema.zip
new file mode 100644
index 0000000000000..f9c8e2f6dbf4d
Binary files /dev/null and b/localstack-core/localstack/services/cloudformation/scaffolding/CloudformationSchema.zip differ
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/__main__.py b/localstack-core/localstack/services/cloudformation/scaffolding/__main__.py
new file mode 100644
index 0000000000000..d6eb97f8dbbf1
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/__main__.py
@@ -0,0 +1,824 @@
+from __future__ import annotations
+
+import json
+import os
+import zipfile
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import reduce
+from pathlib import Path
+from typing import Any, Generator, Literal, Optional, TypedDict, TypeVar
+
+import click
+from jinja2 import Environment, FileSystemLoader
+from yaml import safe_dump
+
+from .propgen import generate_ir_for_type
+
+try:
+ from rich.console import Console
+ from rich.syntax import Syntax
+except ImportError:
+
+ class Console:
+ def print(self, text: str):
+ print("# " + text.replace("[underline]", "").replace("[/underline]", ""))
+
+ def rule(self, title: str = ""):
+ # minimal fallback for rich's Console.rule(), which the CLI below calls unconditionally
+ print(f"# --- {title} ---")
+
+ def Syntax(text: str, *args, **kwargs) -> str:
+ return text
+
+
+# increase when any major changes are done to the scaffolding,
+# so that we can reason better about previously scaffolded resources in the future
+SCAFFOLDING_VERSION = 2
+
+# Some service names need to be rewritten, since we refer to these services by different names
+SERVICE_NAME_MAP = {
+ "OpenSearchService": "OpenSearch",
+ "Lambda": "lambda_",
+}
+
+
+class Property(TypedDict):
+ type: Optional[Literal["str"]]
+ items: Optional[dict]
+
+
+class HandlerDefinition(TypedDict):
+ permissions: Optional[list[str]]
+
+
+class HandlersDefinition(TypedDict):
+ create: HandlerDefinition
+ read: HandlerDefinition
+ update: HandlerDefinition
+ delete: HandlerDefinition
+ list: HandlerDefinition
+
+
+class ResourceSchema(TypedDict):
+ typeName: str
+ description: Optional[str]
+ required: Optional[list[str]]
+ properties: dict[str, Property]
+ handlers: HandlersDefinition
+
+
+def resolve_ref(schema: ResourceSchema, target: str) -> dict:
+ """
+ Given a schema {"a": {"b": "c"}} and the ref "#/a/b" return "c"
+ """
+ target_path = filter(None, (elem.strip() for elem in target.lstrip("#").split("/")))
+
+ T = TypeVar("T")
+
+ def lookup(d: dict[str, T], key: str) -> dict | T:
+ return d[key]
+
+ return reduce(lookup, target_path, schema)
+
+
+@dataclass
+class ResourceName:
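+ """Parsed representation of a CloudFormation resource type name such as 'AWS::SNS::Topic'."""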
+ full_name: str
+ namespace: str
+ service: str
+ resource: str
+ python_compatible_service_name: str
+
+ def provider_name(self) -> str:
+ return f"{self.service}{self.resource}"
+
+ def schema_filename(self) -> str:
+ return f"{self.namespace.lower()}-{self.service.lower()}-{self.resource.lower()}.json"
+
+ def path_compatible_full_name(self) -> str:
+ return f"{self.namespace.lower()}_{self.service.lower()}_{self.resource.lower()}"
+
+ @classmethod
+ def from_name(cls, name: str) -> ResourceName:
+ parts = name.split("::")
+ if len(parts) != 3 or parts[0] != "AWS":
+ raise ValueError(f"Invalid CloudFormation resource name {name}")
+
+ raw_service_name = parts[1].strip()
+ renamed_service = SERVICE_NAME_MAP.get(raw_service_name, raw_service_name)
+
+ return ResourceName(
+ full_name=name,
+ namespace=parts[0],
+ service=raw_service_name,
+ python_compatible_service_name=renamed_service,
+ resource=parts[2].strip(),
+ )
+
+
+def get_formatted_template_output(
+ env: Environment, template_name: str, *render_args, **render_kwargs
+) -> str:
+ template = env.get_template(template_name)
+ return template.render(*render_args, **render_kwargs)
+
+
+class SchemaProvider:
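+ """Loads every resource type schema from the bundled CloudFormation schema zip, indexed by type name."""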
+ def __init__(self, zipfile_path: Path):
+ self.schemas = {}
+ with zipfile.ZipFile(zipfile_path) as infile:
+ for filename in infile.namelist():
+ with infile.open(filename) as schema_file:
+ schema = json.load(schema_file)
+ typename = schema["typeName"]
+ self.schemas[typename] = schema
+
+ def schema(self, resource_name: ResourceName) -> ResourceSchema:
+ try:
+ return self.schemas[resource_name.full_name]
+ except KeyError as e:
+ raise click.ClickException(
+ f"Could not find schema for CloudFormation resource type: {resource_name.full_name}"
+ ) from e
+
+
+LOCALSTACK_ROOT_DIR = Path(__file__).parent.joinpath("../../../../..").resolve()
+LOCALSTACK_PRO_ROOT_DIR = LOCALSTACK_ROOT_DIR.joinpath("../localstack-ext").resolve()
+TESTS_ROOT_DIR = LOCALSTACK_ROOT_DIR.joinpath(
+ "tests/aws/services/cloudformation/resource_providers"
+)
+TESTS_PRO_ROOT_DIR = LOCALSTACK_PRO_ROOT_DIR.joinpath(
+ "localstack-pro-core/tests/aws/services/cloudformation/resource_providers"
+)
+
+assert LOCALSTACK_ROOT_DIR.is_dir(), f"{LOCALSTACK_ROOT_DIR} does not exist"
+assert LOCALSTACK_PRO_ROOT_DIR.is_dir(), f"{LOCALSTACK_PRO_ROOT_DIR} does not exist"
+assert TESTS_ROOT_DIR.is_dir(), f"{TESTS_ROOT_DIR} does not exist"
+assert TESTS_PRO_ROOT_DIR.is_dir(), f"{TESTS_PRO_ROOT_DIR} does not exist"
+
+
+def root_dir(pro: bool = False) -> Path:
+ if pro:
+ return LOCALSTACK_PRO_ROOT_DIR
+ else:
+ return LOCALSTACK_ROOT_DIR
+
+
+def tests_root_dir(pro: bool = False) -> Path:
+ if pro:
+ return TESTS_PRO_ROOT_DIR
+ else:
+ return TESTS_ROOT_DIR
+
+
+def template_path(
+ resource_name: ResourceName,
+ file_type: FileType,
+ root: Optional[Path] = None,
+ pro: bool = False,
+) -> Path:
+ """
+ Given a resource name and file type, return the path of the template file; if root is set, return it relative to the corresponding test directory instead.
+ """
+ match file_type:
+ case FileType.minimal_template:
+ stub = "basic.yaml"
+ case FileType.attribute_template:
+ stub = "getatt_exploration.yaml"
+ case FileType.update_without_replacement_template:
+ stub = "update.yaml"
+ case FileType.autogenerated_template:
+ stub = "basic_autogenerated.yaml"
+ case _:
+ raise ValueError(f"File type {file_type} is not a template")
+
+ output_path = (
+ tests_root_dir(pro)
+ .joinpath(
+ f"{resource_name.python_compatible_service_name.lower()}/{resource_name.path_compatible_full_name()}/templates/{stub}"
+ )
+ .resolve()
+ )
+
+ if root:
+ test_path = (
+ root_dir(pro)
+ .joinpath(
+ f"tests/aws/cloudformation/resource_providers/{resource_name.python_compatible_service_name.lower()}/{resource_name.path_compatible_full_name()}"
+ )
+ .resolve()
+ )
+
+ common_root = os.path.relpath(output_path, test_path)
+ return Path(common_root)
+ else:
+ return output_path
+
+
+class FileType(Enum):
+ # service code
+ plugin = auto()
+ provider = auto()
+
+ # test files
+ integration_test = auto()
+ getatt_test = auto()
+ # cloudcontrol_test = auto()
+ parity_test = auto()
+
+ # templates
+ attribute_template = auto()
+ minimal_template = auto()
+ update_without_replacement_template = auto()
+ autogenerated_template = auto()
+
+ # schema
+ schema = auto()
+
+
+class TemplateRenderer:
+ def __init__(self, schema: ResourceSchema, environment: Environment, pro: bool = False):
+ self.schema = schema
+ self.environment = environment
+ self.pro = pro
+
+ def render(
+ self,
+ file_type: FileType,
+ resource_name: ResourceName,
+ ) -> str:
+ # Generated outputs (template, schema)
+ # templates
+ if file_type == FileType.attribute_template:
+ return self.render_attribute_template(resource_name)
+ elif file_type == FileType.minimal_template:
+ return self.render_minimal_template(resource_name)
+ elif file_type == FileType.update_without_replacement_template:
+ return self.render_update_without_replacement_template(resource_name)
+ elif file_type == FileType.autogenerated_template:
+ return self.render_autogenerated_template(resource_name)
+ # schema
+ elif file_type == FileType.schema:
+ return json.dumps(self.schema, indent=2)
+
+ template_mapping = {
+ FileType.plugin: "plugin_template.py.j2",
+ FileType.provider: "provider_template.py.j2",
+ FileType.getatt_test: "test_getatt_template.py.j2",
+ FileType.integration_test: "test_integration_template.py.j2",
+ # FileType.cloudcontrol_test: "test_cloudcontrol_template.py.j2",
+ FileType.parity_test: "test_parity_template.py.j2",
+ }
+ kwargs = dict(
+ name=resource_name.full_name, # AWS::SNS::Topic
+ resource=resource_name.provider_name(), # SNSTopic
+ scaffolding_version=f"v{SCAFFOLDING_VERSION}",
+ )
+ # TODO: we might want to segregate each provider in its own directory
+ # e.g. .../resource_providers/aws_iam_role/test_X.py vs. .../resource_providers/iam/test_X.py
+ # add extra parameters
+ tests_output_path = root_dir(self.pro).joinpath(
+ f"tests/aws/cloudformation/resource_providers/{resource_name.python_compatible_service_name.lower()}/{resource_name.full_name.lower()}"
+ )
+ match file_type:
+ case FileType.getatt_test:
+ kwargs["getatt_targets"] = list(self.get_getatt_targets())
+ kwargs["service"] = resource_name.service.lower()
+ kwargs["resource"] = resource_name.resource.lower()
+ kwargs["template_path"] = str(
+ template_path(resource_name, FileType.attribute_template, tests_output_path)
+ )
+ case FileType.provider:
+ property_ir = generate_ir_for_type(
+ [self.schema],
+ resource_name.full_name,
+ provider_prefix=resource_name.provider_name(),
+ )
+ kwargs["provider_properties"] = property_ir
+ kwargs["required_properties"] = self.schema.get("required")
+ kwargs["create_only_properties"] = self.schema.get("createOnlyProperties")
+ kwargs["read_only_properties"] = self.schema.get("readOnlyProperties")
+ kwargs["primary_identifier"] = self.schema.get("primaryIdentifier")
+ kwargs["create_permissions"] = (
+ self.schema.get("handlers", {}).get("create", {}).get("permissions")
+ )
+ kwargs["delete_permissions"] = (
+ self.schema.get("handlers", {}).get("delete", {}).get("permissions")
+ )
+ kwargs["read_permissions"] = (
+ self.schema.get("handlers", {}).get("read", {}).get("permissions")
+ )
+ kwargs["update_permissions"] = (
+ self.schema.get("handlers", {}).get("update", {}).get("permissions")
+ )
+ kwargs["list_permissions"] = (
+ self.schema.get("handlers", {}).get("list", {}).get("permissions")
+ )
+ case FileType.plugin:
+ kwargs["service"] = resource_name.python_compatible_service_name.lower()
+ kwargs["lower_resource"] = resource_name.resource.lower()
+ kwargs["pro"] = self.pro
+ case FileType.integration_test:
+ kwargs["black_box_template_path"] = str(
+ template_path(resource_name, FileType.minimal_template, tests_output_path)
+ )
+ kwargs["update_template_path"] = str(
+ template_path(
+ resource_name,
+ FileType.update_without_replacement_template,
+ tests_output_path,
+ )
+ )
+ kwargs["autogenerated_template_path"] = str(
+ template_path(resource_name, FileType.autogenerated_template, tests_output_path)
+ )
+ # case FileType.cloudcontrol_test:
+ case FileType.parity_test:
+ kwargs["parity_test_filename"] = "test_parity.py"
+ case _:
+ raise NotImplementedError(f"Rendering template of type {file_type}")
+
+ return get_formatted_template_output(
+ self.environment, template_mapping[file_type], **kwargs
+ )
+
+ def get_getatt_targets(self) -> Generator[str, None, None]:
+ for name, defn in self.schema["properties"].items():
+ if "type" in defn and defn["type"] in ["string"]:
+ yield name
+
+ def render_minimal_template(self, resource_name: ResourceName) -> str:
+ template = {
+ "AWSTemplateFormatVersion": "2010-09-09",
+ "Description": f"Template to exercise create and delete operations for {resource_name.full_name}",
+ "Resources": {
+ "MyResource": {
+ "Type": resource_name.full_name,
+ "Properties": {},
+ },
+ },
+ "Outputs": {
+ "MyRef": {
+ "Value": {
+ "Ref": "MyResource",
+ },
+ },
+ },
+ }
+
+ return safe_dump(template, sort_keys=False)
+
+ def render_update_without_replacement_template(self, resource_name: ResourceName) -> str:
+ template = {
+ "AWSTemplateFormatVersion": "2010-09-09",
+ "Description": f"Template to exercise updating {resource_name.full_name}",
+ "Parameters": {
+ "AttributeValue": {
+ "Type": "String",
+ "Description": "Value of property to change to force an update",
+ },
+ },
+ "Resources": {
+ "MyResource": {
+ "Type": resource_name.full_name,
+ "Properties": {
+ "SomeProperty": "!Ref AttributeValue",
+ },
+ },
+ },
+ "Outputs": {
+ "MyRef": {
+ "Value": {
+ "Ref": "MyResource",
+ },
+ },
+ "MyOutput": {
+ "Value": "# TODO: the value to verify",
+ },
+ },
+ }
+ return safe_dump(template, sort_keys=False)
+
+ def render_autogenerated_template(self, resource_name: ResourceName) -> str:
+ template = {
+ "AWSTemplateFormatVersion": "2010-09-09",
+ "Description": f"Template to exercise updating autogenerated properties of {resource_name.full_name}",
+ "Resources": {
+ "MyResource": {
+ "Type": resource_name.full_name,
+ },
+ },
+ "Outputs": {
+ "MyRef": {
+ "Value": {
+ "Ref": "MyResource",
+ },
+ },
+ },
+ }
+ return safe_dump(template, sort_keys=False)
+
+ def render_attribute_template(self, resource_name: ResourceName) -> str:
+ template = {
+ "AWSTemplateFormatVersion": "2010-09-09",
+ "Description": f"Template to exercise getting attributes of {resource_name.full_name}",
+ "Parameters": {
+ "AttributeName": {
+ "Type": "String",
+ "Description": "Name of the attribute to fetch from the resource",
+ },
+ },
+ "Resources": {
+ "MyResource": {
+ "Type": resource_name.full_name,
+ "Properties": {},
+ },
+ },
+ "Outputs": self.render_outputs(),
+ }
+
+ return safe_dump(template, sort_keys=False)
+
+ def required_properties(self) -> dict[str, Property]:
+ return PropertyRenderer(self.schema).properties()
+
+ def render_outputs(self) -> dict:
+ """
+ Generate the template outputs: a Ref to the resource and a GetAtt driven by the AttributeName parameter
+ """
+ outputs = {}
+
+ # ref
+ outputs["MyRef"] = {"Value": {"Ref": "MyResource"}}
+
+ # getatt
+ outputs["MyOutput"] = {"Value": {"Fn::GetAtt": ["MyResource", {"Ref": "AttributeName"}]}}
+
+ return outputs
+
+
+class PropertyRenderer:
+ def __init__(self, schema: ResourceSchema):
+ self.schema = schema
+
+ def properties(self) -> dict:
+ required_properties = self.schema.get("required", [])
+
+ result = {}
+ for name, defn in self.schema["properties"].items():
+ if name not in required_properties:
+ continue
+
+ value = self.render_property(defn)
+ result[name] = value
+
+ return result
+
+ def render_property(self, property: Property) -> str | dict | list:
+ if prop_type := property.get("type"):
+ if prop_type in {"string"}:
+ return self._render_basic(prop_type)
+ elif prop_type == "array":
+ return [self.render_property(item) for item in property["items"]]
+ elif oneof := property.get("oneOf"):
+ return self._render_one_of(oneof)
+ else:
+ raise NotImplementedError(property)
+
+ def _render_basic(self, type: str) -> str:
+ return "CHANGEME"
+
+ def _render_one_of(self, options: list[Property]) -> Any:
+ return self.render_property(options[0])
+
+
+class FileWriter:
+ destination_files: dict[FileType, Path]
+
+ def __init__(
+ self, resource_name: ResourceName, console: Console, overwrite: bool, pro: bool = False
+ ):
+ self.resource_name = resource_name
+ self.console = console
+ self.overwrite = overwrite
+ self.pro = pro
+
+ base_path = (
+ ["localstack-pro-core", "localstack", "pro", "core"]
+ if self.pro
+ else ["localstack-core", "localstack"]
+ )
+
+ self.destination_files = {
+ FileType.provider: root_dir(self.pro).joinpath(
+ *base_path,
+ "services",
+ self.resource_name.python_compatible_service_name.lower(),
+ "resource_providers",
+ f"{self.resource_name.namespace.lower()}_{self.resource_name.service.lower()}_{self.resource_name.resource.lower()}.py",
+ ),
+ FileType.plugin: root_dir(self.pro).joinpath(
+ *base_path,
+ "services",
+ self.resource_name.python_compatible_service_name.lower(),
+ "resource_providers",
+ f"{self.resource_name.namespace.lower()}_{self.resource_name.service.lower()}_{self.resource_name.resource.lower()}_plugin.py",
+ ),
+ FileType.schema: root_dir(self.pro).joinpath(
+ *base_path,
+ "services",
+ self.resource_name.python_compatible_service_name.lower(),
+ "resource_providers",
+ f"aws_{self.resource_name.service.lower()}_{self.resource_name.resource.lower()}.schema.json",
+ ),
+ FileType.integration_test: tests_root_dir(self.pro).joinpath(
+ self.resource_name.python_compatible_service_name.lower(),
+ self.resource_name.path_compatible_full_name(),
+ "test_basic.py",
+ ),
+ FileType.getatt_test: tests_root_dir(self.pro).joinpath(
+ self.resource_name.python_compatible_service_name.lower(),
+ self.resource_name.path_compatible_full_name(),
+ "test_exploration.py",
+ ),
+ # FileType.cloudcontrol_test: tests_root_dir(self.pro).joinpath(
+ # self.resource_name.python_compatible_service_name.lower(),
+ # f"test_aws_{self.resource_name.service.lower()}_{self.resource_name.resource.lower()}_cloudcontrol.py",
+ # ),
+ FileType.parity_test: tests_root_dir(self.pro).joinpath(
+ self.resource_name.python_compatible_service_name.lower(),
+ self.resource_name.path_compatible_full_name(),
+ "test_parity.py",
+ ),
+ }
+
+ # output files that are templates
+ templates = [
+ FileType.attribute_template,
+ FileType.minimal_template,
+ FileType.update_without_replacement_template,
+ FileType.autogenerated_template,
+ ]
+ for template_type in templates:
+ self.destination_files[template_type] = template_path(self.resource_name, template_type)
+
+ def write(self, file_type: FileType, contents: str):
+ file_destination = self.destination_files[file_type]
+ destination_path = file_destination.parent
+ destination_path.mkdir(parents=True, exist_ok=True)
+
+ if file_destination.exists():
+ should_overwrite = self.confirm_overwrite(file_destination)
+ if not should_overwrite:
+ self.console.print(f"Skipping {file_destination}")
+ return
+
+ match file_type:
+ # provider
+ case FileType.provider:
+ self.ensure_python_init_files(destination_path)
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written provider to {file_destination}")
+ case FileType.plugin:
+ self.ensure_python_init_files(destination_path)
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written plugin to {file_destination}")
+
+ # tests
+ case FileType.integration_test:
+ self.ensure_python_init_files(destination_path)
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written integration test to {file_destination}")
+ case FileType.getatt_test:
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written getatt tests to {file_destination}")
+ # case FileType.cloudcontrol_test:
+ # self.write_text(contents, file_destination)
+ # self.console.print(f"Written cloudcontrol tests to {file_destination}")
+ case FileType.parity_test:
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written parity tests to {file_destination}")
+
+ # templates
+ case FileType.attribute_template:
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written attribute template to {file_destination}")
+ case FileType.minimal_template:
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written minimal template to {file_destination}")
+ case FileType.update_without_replacement_template:
+ self.write_text(contents, file_destination)
+ self.console.print(
+ f"Written update without replacement template to {file_destination}"
+ )
+ case FileType.autogenerated_template:
+ self.write_text(contents, file_destination)
+ self.console.print(
+ f"Written autogenerated properties template to {file_destination}"
+ )
+
+ # schema
+ case FileType.schema:
+ self.write_text(contents, file_destination)
+ self.console.print(f"Written schema to {file_destination}")
+ case _:
+ raise NotImplementedError(f"Writing {file_type}")
+
+ def confirm_overwrite(self, destination_file: Path) -> bool:
+ """
+ Decide whether an existing destination file should be overwritten or skipped.
+
+ :return: True if the file should be (over-)written, False otherwise
+ """
+ return self.overwrite or click.confirm("Destination files already exist, overwrite?")
+
+ @staticmethod
+ def write_text(contents: str, destination: Path):
+ with destination.open("wt") as outfile:
+ print(contents, file=outfile)
+
+ @staticmethod
+ def ensure_python_init_files(path: Path):
+ """
+ Make sure an __init__.py file exists in every package directory between the project root and the given path
+ """
+ project_root = path.parent.parent.parent.parent
+ path_relative_to_root = path.relative_to(project_root)
+ dir = project_root
+ for part in path_relative_to_root.parts:
+ dir = dir / part
+ test_path = dir.joinpath("__init__.py")
+ if not test_path.is_file():
+ # touch file
+ with test_path.open("w"):
+ pass
+
+
+class OutputFactory:
+ def __init__(
+ self,
+ template_renderer: TemplateRenderer,
+ printer: Console,
+ writer: FileWriter,
+ ):
+ self.template_renderer = template_renderer
+ self.printer = printer
+ self.writer = writer
+
+ def get(self, file_type: FileType, resource_name: ResourceName) -> Output:
+ contents = self.template_renderer.render(file_type, resource_name)
+ return Output(contents, file_type, self.printer, self.writer, resource_name)
+
+
+class Output:
+ def __init__(
+ self,
+ contents: str,
+ file_type: FileType,
+ printer: Console,
+ writer: FileWriter,
+ resource_name: ResourceName,
+ ):
+ self.contents = contents
+ self.file_type = file_type
+ self.printer = printer
+ self.writer = writer
+ self.resource_name = resource_name
+
+ def handle(self, should_write: bool = False):
+ if should_write:
+ self.write()
+ else:
+ self.print()
+
+ def write(self):
+ self.writer.write(self.file_type, self.contents)
+
+ def print(self):
+ match self.file_type:
+ # service code
+ case FileType.provider:
+ self.printer.print("\n[underline]Provider template[/underline]\n")
+ self.printer.print(Syntax(self.contents, "python"))
+ case FileType.plugin:
+ self.printer.print("\n[underline]Plugin[/underline]\n")
+ self.printer.print(Syntax(self.contents, "python"))
+ # tests
+ case FileType.integration_test:
+ self.printer.print("\n[underline]Integration test file[/underline]\n")
+ self.printer.print(Syntax(self.contents, "python"))
+ case FileType.getatt_test:
+ self.printer.print("\n[underline]GetAtt test file[/underline]\n")
+ self.printer.print(Syntax(self.contents, "python"))
+ # case FileType.cloudcontrol_test:
+ # self.printer.print("\n[underline]CloudControl test[/underline]\n")
+ # self.printer.print(Syntax(self.contents, "python"))
+ case FileType.parity_test:
+ self.printer.print("\n[underline]Parity test[/underline]\n")
+ self.printer.print(Syntax(self.contents, "python"))
+
+ # templates
+ case FileType.attribute_template:
+ self.printer.print("\n[underline]Attribute Test Template[/underline]\n")
+ self.printer.print(Syntax(self.contents, "yaml"))
+ case FileType.minimal_template:
+ self.printer.print("\n[underline]Minimal template[/underline]\n")
+ self.printer.print(Syntax(self.contents, "yaml"))
+ case FileType.update_without_replacement_template:
+ self.printer.print("\n[underline]Update test template[/underline]\n")
+ self.printer.print(Syntax(self.contents, "yaml"))
+ case FileType.autogenerated_template:
+ self.printer.print("\n[underline]Autogenerated properties template[/underline]\n")
+ self.printer.print(Syntax(self.contents, "yaml"))
+
+ # schema
+ case FileType.schema:
+ self.printer.print("\n[underline]Schema[/underline]\n")
+ self.printer.print(Syntax(self.contents, "json"))
+ case _:
+ raise NotImplementedError(self.file_type)
+
+
+@click.group()
+def cli():
+ pass
+
+
+@cli.command()
+@click.option(
+ "-r",
+ "--resource-type",
+ required=True,
+ help="CloudFormation resource type (e.g. 'AWS::SSM::Parameter') to generate",
+)
+@click.option("-w", "--write/--no-write", default=False)
+@click.option("--overwrite", is_flag=True, default=False)
+@click.option("-t", "--write-tests/--no-write-tests", default=False)
+@click.option("--pro", is_flag=True, default=False)
+def generate(
+ resource_type: str,
+ write: bool,
+ write_tests: bool,
+ overwrite: bool,
+ pro: bool,
+):
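+ """
+ Scaffold provider code, schema, templates, and tests for a CloudFormation resource type.
+
+ Example invocation (assuming the localstack package is importable):
+ python -m localstack.services.cloudformation.scaffolding generate -r AWS::SSM::Parameter --write
+ """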
+ console = Console()
+ console.rule(title=resource_type)
+
+ schema_provider = SchemaProvider(
+ zipfile_path=Path(__file__).parent.joinpath("CloudformationSchema.zip")
+ )
+
+ template_root = Path(__file__).parent.joinpath("templates")
+ env = Environment(
+ loader=FileSystemLoader(template_root),
+ )
+
+ parts = resource_type.rpartition("::")
+ if parts[-1] == "*":
+ # generate all resource types for that service
+ matching_resources = [x for x in schema_provider.schemas.keys() if x.startswith(parts[0])]
+ else:
+ matching_resources = [resource_type]
+
+ for matching_resource in matching_resources:
+ console.rule(title=matching_resource)
+ resource_name = ResourceName.from_name(matching_resource)
+ schema = schema_provider.schema(resource_name)
+
+ template_renderer = TemplateRenderer(schema, env, pro)
+ writer = FileWriter(resource_name, console, overwrite, pro)
+ output_factory = OutputFactory(template_renderer, console, writer) # noqa
+ for file_type in FileType:
+ if not write_tests and file_type in {
+ FileType.integration_test,
+ FileType.getatt_test,
+ FileType.parity_test,
+ FileType.minimal_template,
+ FileType.update_without_replacement_template,
+ FileType.attribute_template,
+ FileType.autogenerated_template,
+ }:
+ # skip test generation
+ continue
+ output_factory.get(file_type, resource_name).handle(should_write=write)
+
+ console.rule(title="Resources & Instructions")
+ console.print(
+ "Resource types: https://docs.aws.amazon.com/cloudformation-cli/latest/userguide/resource-types.html"
+ )
+ # TODO: print for every resource
+ for matching_resource in matching_resources:
+ resource_name = ResourceName.from_name(matching_resource)
+ console.print(
+ # lambda_ should become lambda (re-use the same list we use for generating the models)
+ f"{matching_resource}: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-{resource_name.service.lower()}-{resource_name.resource.lower()}.html"
+ )
+ console.print("\nWondering where to get started?")
+ console.print(
+ "First run `make entrypoints` to make sure your resource provider plugin is actually registered."
+ )
+ console.print(
+ 'Then start off by finalizing the generated minimal ("basic") template and get it to deploy against AWS.'
+ )
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/propgen.py b/localstack-core/localstack/services/cloudformation/scaffolding/propgen.py
new file mode 100644
index 0000000000000..6a7e90166b490
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/propgen.py
@@ -0,0 +1,227 @@
+"""
+Implementation of generating the types for a provider from the schema
+"""
+
+from __future__ import annotations
+
+import logging
+import textwrap
+from dataclasses import dataclass
+from typing import Optional, TypedDict
+
+LOG = logging.getLogger(__name__)
+
+
+@dataclass
+class Item:
+ """An Item is a single field definition"""
+
+ name: str
+ type: str
+ required: bool
+
+ def __str__(self) -> str:
+ return f"{self.name}: {self.type}"
+
+ @classmethod
+ def new(cls, name: str, type: str, required: bool = False) -> Item:
+ if required:
+ return cls(name=name, type=type, required=required)
+ else:
+ return cls(name=name, type=f"Optional[{type}]", required=required)
+
+
+@dataclass
+class PrimitiveStruct:
+ name: str
+ primitive_type: str
+
+ def __str__(self) -> str:
+ return f"""
+{self.name} = {self.primitive_type}
+"""
+
+
+@dataclass
+class Struct:
+ """A struct represents a single rendered class"""
+
+ name: str
+ items: list[Item]
+
+ def __str__(self) -> str:
+ if self.items:
+ raw_text = "\n".join(map(str, self.sorted_items))
+ else:
+ raw_text = "pass"
+ formatted_items = textwrap.indent(raw_text, " ")
+ return f"""
+class {self.name}(TypedDict):
+{formatted_items}
+"""
+
+ @property
+ def sorted_items(self) -> list[Item]:
+ required_items = sorted(
+ [item for item in self.items if item.required], key=lambda item: item.name
+ )
+ optional_items = sorted(
+ [item for item in self.items if not item.required], key=lambda item: item.name
+ )
+ return required_items + optional_items
+
+
+@dataclass
+class IR:
+ structs: list[Struct]
+
+ def __str__(self) -> str:
+ """
+ Pretty print the IR
+ """
+ return "\n\n".join(map(str, self.structs))
+
+
+class Schema(TypedDict):
+ properties: dict
+ definitions: dict
+ typeName: str
+ required: Optional[list[str]]
+
+
+TYPE_MAP = {
+ "string": "str",
+ "boolean": "bool",
+ "integer": "int",
+ "number": "float",
+ "object": "dict",
+ "array": "list",
+}
+
+
+class PropertyTypeScaffolding:
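+ """Generates the TypedDict structs (the provider "Properties" types) for a single resource type schema."""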
+ resource_type: str
+ provider_prefix: str
+ schema: Schema
+
+ structs: list[Struct]
+
+ required_properties: list[str]
+
+ def __init__(self, resource_type: str, provider_prefix: str, schema: Schema):
+ self.resource_type = resource_type
+ self.provider_prefix = provider_prefix
+ self.schema = schema
+ self.structs = []
+ self.required_properties = schema.get("required", [])
+
+ def get_structs(self) -> list[Struct]:
+ root_struct = Struct(f"{self.provider_prefix}Properties", items=[])
+ self._add_struct(root_struct)
+
+ for property_name, property_def in self.schema["properties"].items():
+ is_required = property_name in self.required_properties
+ item = self.property_to_item(property_name, property_def, is_required)
+ root_struct.items.append(item)
+
+ return self.structs
+
+ def _add_struct(self, struct: Struct):
+ if struct.name in [s.name for s in self.structs]:
+ return
+ else:
+ self.structs.append(struct)
+
+ def get_ref_definition(self, property_ref: str) -> dict:
+ property_ref_name = property_ref.lstrip("#").rpartition("/")[-1]
+ return self.schema["definitions"][property_ref_name]
+
+ def resolve_type_of_property(self, property_def: dict) -> str:
+ if property_ref := property_def.get("$ref"):
+ ref_definition = self.get_ref_definition(property_ref)
+ ref_type = ref_definition.get("type")
+ if ref_type not in ["object", "array"]:
+ # in this case we simply flatten it (instead of for example creating a type alias)
+ resolved_type = TYPE_MAP.get(ref_type)
+ if resolved_type is None:
+ LOG.warning(
+ "Type for %s not found in the TYPE_MAP. Using `Any` as fallback.", ref_type
+ )
+ resolved_type = "Any"
+ else:
+ if ref_type == "object":
+ # the object might only have a pattern defined and no actual properties
+ if "properties" not in ref_definition:
+ resolved_type = "dict"
+ else:
+ nested_struct = self.ref_to_struct(property_ref)
+ resolved_type = nested_struct.name
+ self._add_struct(nested_struct)
+ elif ref_type == "array":
+ item_def = ref_definition["items"]
+ item_type = self.resolve_type_of_property(item_def)
+ resolved_type = f"list[{item_type}]"
+ else:
+ raise Exception(f"Unknown property type encountered: {ref_type}")
+ else:
+ match property_type := property_def.get("type"):
+ # primitives
+ case "string":
+ resolved_type = "str"
+ case "boolean":
+ resolved_type = "bool"
+ case "integer":
+ resolved_type = "int"
+ case "number":
+ resolved_type = "float"
+ # complex objects
+ case "object":
+ resolved_type = "dict" # TODO: any cases where we need to continue here?
+ case "array":
+ try:
+ item_type = self.resolve_type_of_property(property_def["items"])
+ resolved_type = f"list[{item_type}]"
+ except RecursionError:
+ resolved_type = "list[Any]"
+ case _:
+ # TODO: allOf, anyOf, patternProperties (?)
+ # AWS::ApiGateway::RestApi passes a ["object", "string"] here for the "Body" property
+ # it probably makes sense to assume this behaves the same as a "oneOf"
+ if one_of := property_def.get("oneOf"):
+ resolved_type = "|".join([self.resolve_type_of_property(o) for o in one_of])
+ elif isinstance(property_type, list):
+ resolved_type = "|".join([TYPE_MAP[pt] for pt in property_type])
+ else:
+ raise Exception(f"Unknown property type: {property_type}")
+ return resolved_type
+
+ def property_to_item(self, property_name: str, property_def: dict, required: bool) -> Item:
+ resolved_type = self.resolve_type_of_property(property_def)
+ return Item(name=property_name, type=f"Optional[{resolved_type}]", required=required)
+
+ def ref_to_struct(self, property_ref: str) -> Struct:
+ property_ref_name = property_ref.lstrip("#").rpartition("/")[-1]
+ resolved_def = self.schema["definitions"][property_ref_name]
+ nested_struct = Struct(name=property_ref_name, items=[])
+ if resolved_properties := resolved_def.get("properties"):
+ required_props = resolved_def.get("required", [])
+ for k, v in resolved_properties.items():
+ is_required = k in required_props
+ item = self.property_to_item(k, v, is_required)
+ nested_struct.items.append(item)
+ else:
+ raise Exception("Unknown resource format. Expected properties on object")
+
+ return nested_struct
+
+
+def generate_ir_for_type(schema: list[Schema], type_name: str, provider_prefix: str = "") -> IR:
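+ """Build the IR (the set of TypedDict structs) for the given resource type from the provided schemas."""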
+ try:
+ resource_schema = [every for every in schema if every["typeName"] == type_name][0]
+ except IndexError:
+ raise ValueError(f"could not find schema for type {type_name}")
+
+ structs = PropertyTypeScaffolding(
+ resource_type=type_name, provider_prefix=provider_prefix, schema=resource_schema
+ ).get_structs()
+ return IR(structs=structs)
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/templates/plugin_template.py.j2 b/localstack-core/localstack/services/cloudformation/scaffolding/templates/plugin_template.py.j2
new file mode 100644
index 0000000000000..0a9a530cdfccc
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/templates/plugin_template.py.j2
@@ -0,0 +1,22 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import ResourceProvider
+{%- if pro %}
+{%- set base_class = "CloudFormationResourceProviderPluginExt" %}
+{%- set root_module = "localstack.pro.core" %}
+{%- else %}
+{%- set base_class = "CloudFormationResourceProviderPlugin" %}
+{%- set root_module = "localstack" %}
+{%- endif %}
+from {{ root_module }}.services.cloudformation.resource_provider import {{ base_class }}
+
+class {{ resource }}ProviderPlugin({{ base_class }}):
+ name = "{{ name }}"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from {{ root_module }}.services.{{ service }}.resource_providers.aws_{{ service }}_{{ lower_resource }} import {{ resource }}Provider
+
+ self.factory = {{ resource }}Provider
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/templates/provider_template.py.j2 b/localstack-core/localstack/services/cloudformation/scaffolding/templates/provider_template.py.j2
new file mode 100644
index 0000000000000..3d52dbd6b7a83
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/templates/provider_template.py.j2
@@ -0,0 +1,138 @@
+# LocalStack Resource Provider Scaffolding {{ scaffolding_version }}
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+{{ provider_properties }}
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+class {{ resource }}Provider(ResourceProvider[{{ resource }}Properties]):
+
+ TYPE = "{{ name }}" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[{{ resource }}Properties],
+ ) -> ProgressEvent[{{ resource }}Properties]:
+ """
+ Create a new resource.
+
+ {% if primary_identifier -%}
+ Primary identifier fields:
+ {%- for property in primary_identifier %}
+ - {{ property }}
+ {%- endfor %}
+ {%- endif %}
+
+ {% if required_properties -%}
+ Required properties:
+ {%- for property in required_properties %}
+ - {{ property }}
+ {%- endfor %}
+ {%- endif %}
+
+ {% if create_only_properties -%}
+ Create-only properties:
+ {%- for property in create_only_properties %}
+ - {{ property }}
+ {%- endfor %}
+ {%- endif %}
+
+ {% if read_only_properties -%}
+ Read-only properties:
+ {%- for property in read_only_properties %}
+ - {{ property }}
+ {%- endfor %}
+ {%- endif %}
+
+ {% if create_permissions -%}
+ IAM permissions required:
+ {%- for permission in create_permissions %}
+ - {{ permission }}
+ {%- endfor -%}
+ {%- endif %}
+
+ """
+ model = request.desired_state
+
+ # TODO: validations
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: defaults
+ # TODO: idempotency
+ # TODO: actually create the resource
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ # TODO: check the status of the resource
+ # - if finished, update the model with all fields and return success event:
+ # return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+ # - else
+ # return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+
+ raise NotImplementedError
+
+ def read(
+ self,
+ request: ResourceRequest[{{ resource }}Properties],
+ ) -> ProgressEvent[{{ resource }}Properties]:
+ """
+ Fetch resource information
+
+ {% if read_permissions -%}
+ IAM permissions required:
+ {%- for permission in read_permissions %}
+ - {{ permission }}
+ {%- endfor %}
+ {%- endif %}
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[{{ resource }}Properties],
+ ) -> ProgressEvent[{{ resource }}Properties]:
+ """
+ Delete a resource
+
+ {% if delete_permissions -%}
+ IAM permissions required:
+ {%- for permission in delete_permissions %}
+ - {{ permission }}
+ {%- endfor %}
+ {%- endif %}
+ """
+ raise NotImplementedError
+
+ def update(
+ self,
+ request: ResourceRequest[{{ resource }}Properties],
+ ) -> ProgressEvent[{{ resource }}Properties]:
+ """
+ Update a resource
+
+ {% if update_permissions -%}
+ IAM permissions required:
+ {%- for permission in update_permissions %}
+ - {{ permission }}
+ {%- endfor %}
+ {%- endif %}
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_getatt_template.py.j2 b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_getatt_template.py.j2
new file mode 100644
index 0000000000000..24f59945903b5
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_getatt_template.py.j2
@@ -0,0 +1,41 @@
+# LocalStack Resource Provider Scaffolding {{ scaffolding_version }}
+import os
+
+import pytest
+
+from localstack.testing.aws.util import is_aws_cloud
+
+
+RESOURCE_GETATT_TARGETS = {{getatt_targets}}
+
+
+class TestAttributeAccess:
+ @pytest.mark.parametrize("attribute", RESOURCE_GETATT_TARGETS)
+ @pytest.mark.skipif(condition=not is_aws_cloud(), reason="Exploratory test only")
+ def test_getatt(
+ self,
+ aws_client,
+ deploy_cfn_template,
+ attribute,
+ snapshot,
+ ):
+ """
+ Use this test to find out which properties support GetAtt access
+
+ Fn::GetAtt documentation: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/intrinsic-function-reference-getatt.html
+ """
+
+ stack = deploy_cfn_template(
+ template_path=os.path.join(
+ os.path.dirname(__file__),
+ "{{ template_path }}",
+ ),
+ parameters={"AttributeName": attribute},
+ )
+ snapshot.match("stack_outputs", stack.outputs)
+
+ # check physical resource id
+ res = aws_client.cloudformation.describe_stack_resource(
+ StackName=stack.stack_name, LogicalResourceId="MyResource"
+ )["StackResourceDetail"]
+ snapshot.match("physical_resource_id", res.get("PhysicalResourceId"))
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_integration_template.py.j2 b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_integration_template.py.j2
new file mode 100644
index 0000000000000..98bd596be3b89
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_integration_template.py.j2
@@ -0,0 +1,93 @@
+# LocalStack Resource Provider Scaffolding {{ scaffolding_version }}
+import os
+
+import pytest
+# from botocore.exceptions import ClientError
+
+
+class TestBasicCRD:
+
+ def test_black_box(self, deploy_cfn_template, aws_client, snapshot):
+ """
+ Simple test that
+ - deploys a stack containing the resource
+ - verifies that the resource has been created correctly by querying the service directly
+ - deletes the stack ensuring that the delete operation has been implemented correctly
+ - verifies that the resource no longer exists by querying the service directly
+ """
+ stack = deploy_cfn_template(
+ template_path=os.path.join(
+ os.path.dirname(__file__),
+ "{{ black_box_template_path }}",
+ ),
+ )
+ snapshot.match("stack-outputs", stack.outputs)
+
+ # TODO: fetch the resource and perform any required validations here
+ # e.g.
+ # parameter_name = stack.outputs["MyRef"]
+ # snapshot.add_transformer(snapshot.transform.regex(parameter_name, ""))
+
+ # res = aws_client.ssm.get_parameter(Name=stack.outputs["MyRef"])
+ # - this snapshot also asserts that the value set in the template is correct
+ # snapshot.match("describe-resource", res)
+
+ # verify that the delete operation works
+ stack.destroy()
+
+ # TODO: fetch the resource again and assert that it no longer exists
+ # e.g.
+ # with pytest.raises(ClientError):
+ # aws_client.ssm.get_parameter(Name=stack.outputs["MyRef"])
+
+ def test_autogenerated_values(self, aws_client, deploy_cfn_template, snapshot):
+ stack = deploy_cfn_template(
+ template_path=os.path.join(
+ os.path.dirname(__file__),
+ "{{ autogenerated_template_path }}",
+ ),
+ )
+ snapshot.match("stack_outputs", stack.outputs)
+
+ # user_name = stack.outputs["MyRef"]
+
+ # verify resource has been correctly deployed with the autogenerated field
+ # e.g. aws_client.iam.get_user(UserName=user_name)
+
+ # check the auto-generated pattern
+ # TODO: add a sample of the auto-generated value here for reference, e.g. "TestStack-CustomUser-13AA838"
+
+
+class TestUpdates:
+ @pytest.mark.skip(reason="TODO")
+ def test_update_without_replacement(self, deploy_cfn_template, aws_client, snapshot):
+ """
+ Test an UPDATE of a simple property that does not require replacing the entire resource.
+ Check out the official resource documentation at https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-template-resource-type-ref.html to see if a property needs replacement
+ """
+ stack = deploy_cfn_template(
+ template_path=os.path.join(
+ os.path.dirname(__file__),
+ "{{ update_template_path }}",
+ ),
+ parameters={"AttributeValue": "first"},
+ )
+
+ # TODO: implement fetching the resource and performing any required validations here
+ res = aws_client.ssm.get_parameter(Name=stack.outputs["MyRef"])
+ snapshot.match("describe-resource-before-update", res)
+
+ # TODO: update the stack
+ deploy_cfn_template(
+ stack_name=stack.stack_name,
+ template_path=os.path.join(
+ os.path.dirname(__file__),
+ "{{ update_template_path }}",
+ ),
+ parameters={"AttributeValue": "second"},
+ is_update=True,
+ )
+
+ # TODO: check the value has changed
+ res = aws_client.ssm.get_parameter(Name=stack.outputs["MyRef"])
+ snapshot.match("describe-resource-after-update", res)
diff --git a/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_parity_template.py.j2 b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_parity_template.py.j2
new file mode 100644
index 0000000000000..6cf269aa392db
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/scaffolding/templates/test_parity_template.py.j2
@@ -0,0 +1,30 @@
+# ruff: noqa
+# LocalStack Resource Provider Scaffolding {{ scaffolding_version }}
+
+import pytest
+
+
+@pytest.mark.skip(reason="TODO")
+class TestParity:
+ """
+ Pro-active parity-focused tests that go into more detail than the basic test skeleton
+
+ TODO: add more focused detailed tests for updates, different combinations, etc.
+ Use snapshots here to capture detailed parity with AWS
+
+ Other ideas for tests in here:
+ - Negative test: invalid combination of properties
+ - Negative test: missing required properties
+ """
+
+ def test_create_with_full_properties(self, aws_client, deploy_cfn_template):
+ """ A sort of smoke test that simply covers as many properties as possible """
+ ...
+
+
+
+
+@pytest.mark.skip(reason="TODO")
+class TestSamples:
+ """ User-provided samples and other reactively added scenarios (e.g. reported and reproduced GitHub issues) """
+ ...
diff --git a/localstack-core/localstack/services/cloudformation/service_models.py b/localstack-core/localstack/services/cloudformation/service_models.py
new file mode 100644
index 0000000000000..aeadbeb85f305
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/service_models.py
@@ -0,0 +1,128 @@
+import logging
+from typing import TypedDict
+
+from localstack.services.cloudformation.deployment_utils import check_not_found_exception
+
+LOG = logging.getLogger(__name__)
+
+# dict key used to store the deployment state of a resource
+KEY_RESOURCE_STATE = "_state_"
+
+
+class DependencyNotYetSatisfied(Exception):
+ """Exception indicating that a resource dependency is not (yet) deployed/available."""
+
+ def __init__(self, resource_ids, message=None):
+ message = message or "Unresolved dependencies: %s" % resource_ids
+ super(DependencyNotYetSatisfied, self).__init__(message)
+ resource_ids = resource_ids if isinstance(resource_ids, list) else [resource_ids]
+ self.resource_ids = resource_ids
+
+
+class ResourceJson(TypedDict):
+ Type: str
+ Properties: dict
+
+
+class GenericBaseModel:
+ """Abstract base class representing a resource model class in LocalStack.
+ This class keeps references to a combination of (1) the CF resource
+ properties (as defined in the template), and (2) the current deployment
+ state of a resource.
+
+ Concrete subclasses will implement convenience methods to manage resources,
+ e.g., fetching the latest deployment state, getting the resource name, etc.
+ """
+
+ def __init__(self, account_id: str, region_name: str, resource_json: dict, **params):
+ # self.stack_name = stack_name # TODO: add stack name to params
+ self.account_id = account_id
+ self.region_name = region_name
+ self.resource_json = resource_json
+ self.resource_type = resource_json["Type"]
+ # Properties, as defined in the resource template
+ self.properties = resource_json["Properties"] = resource_json.get("Properties") or {}
+ # State, as determined from the deployed resource; use a special dict key here to keep
+ # track of state changes within resource_json (this way we encapsulate all state details
+ # in `resource_json` and the changes will survive creation of multiple instances of this class)
+ self.state = resource_json[KEY_RESOURCE_STATE] = resource_json.get(KEY_RESOURCE_STATE) or {}
+
+ # ----------------------
+ # ABSTRACT BASE METHODS
+ # ----------------------
+
+ def fetch_state(self, stack_name, resources):
+ """Fetch the latest deployment state of this resource, or return None if not currently deployed (NOTE: THIS IS NOT ALWAYS TRUE)."""
+ return None
+
+ def update_resource(self, new_resource, stack_name, resources):
+ """Update the deployment of this resource, using the updated properties (implemented by subclasses)."""
+ raise NotImplementedError
+
+ def is_updatable(self) -> bool:
+ return type(self).update_resource != GenericBaseModel.update_resource
+
+ @classmethod
+ def cloudformation_type(cls):
+ """Return the CloudFormation resource type name, e.g., "AWS::S3::Bucket" (implemented by subclasses)."""
+ pass
+
+ @staticmethod
+ def get_deploy_templates():
+ """Return template configurations used to create the final API requests (implemented by subclasses)."""
+ pass
+
+ # TODO: rework to normal instance method when resources aren't mutated in different place anymore
+ @staticmethod
+ def add_defaults(resource, stack_name: str):
+ """Set any defaults required, including auto-generating names. Must be called before deploying the resource"""
+ pass
+
+ # ---------------------
+ # GENERIC UTIL METHODS
+ # ---------------------
+
+ # TODO: remove
+ def fetch_and_update_state(self, *args, **kwargs):
+ if self.physical_resource_id is None:
+ return None
+
+ try:
+ state = self.fetch_state(*args, **kwargs)
+ self.update_state(state)
+ return state
+ except Exception as e:
+ if not check_not_found_exception(e, self.resource_type, self.properties):
+ LOG.warning(
+ "Unable to fetch state for resource %s: %s",
+ self,
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ # TODO: remove
+ def update_state(self, details):
+ """Update the deployment state of this resource (existing attributes will be overwritten)."""
+ details = details or {}
+ self.state.update(details)
+
+ @property
+ def physical_resource_id(self) -> str | None:
+ """Return the (cached) physical resource ID."""
+ return self.resource_json.get("PhysicalResourceId")
+
+ @property
+ def logical_resource_id(self) -> str:
+ """Return the logical resource ID."""
+ return self.resource_json["LogicalResourceId"]
+
+ # TODO: rename? make it clearer what props are in comparison with state, properties and resource_json
+ @property
+ def props(self) -> dict:
+ """Return a copy of (1) the resource properties (from the template), combined with
+ (2) the current deployment state properties of the resource."""
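+ # note: the deployment state and the last deployed state take precedence over the template properties,
+ # since they are applied last via dict.update() below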
+ result = dict(self.properties)
+ result.update(self.state or {})
+ last_state = self.resource_json.get("_last_deployed_state", {})
+ result.update(last_state)
+ return result
diff --git a/localstack-core/localstack/services/cloudformation/stores.py b/localstack-core/localstack/services/cloudformation/stores.py
new file mode 100644
index 0000000000000..11c8fa0cbb879
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/stores.py
@@ -0,0 +1,129 @@
+import logging
+from typing import Optional
+
+from localstack.aws.api.cloudformation import StackStatus
+from localstack.services.cloudformation.engine.entities import Stack, StackChangeSet, StackSet
+from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute
+
+LOG = logging.getLogger(__name__)
+
+
+class CloudFormationStore(BaseStore):
+ # maps stack ID to stack details
+ stacks: dict[str, Stack] = LocalAttribute(default=dict)
+
+ # maps stack set ID to stack set details
+ stack_sets: dict[str, StackSet] = LocalAttribute(default=dict)
+
+ # maps macro ID to macros
+ macros: dict[str, dict] = LocalAttribute(default=dict)
+
+ # exports: dict[str, str]
+ @property
+ def exports(self):
+ exports = []
+ output_keys = {}
+ for stack_id, stack in self.stacks.items():
+ for output in stack.resolved_outputs:
+ export_name = output.get("ExportName")
+ if not export_name:
+ continue
+ if export_name in output_keys:
+ # TODO: raise exception on stack creation in case of duplicate exports
+ LOG.warning(
+ "Found duplicate export name %s in stacks: %s %s",
+ export_name,
+ output_keys[export_name],
+ stack.stack_id,
+ )
+ entry = {
+ "ExportingStackId": stack.stack_id,
+ "Name": export_name,
+ "Value": output["OutputValue"],
+ }
+ exports.append(entry)
+ output_keys[export_name] = stack.stack_id
+ return exports
+
+
+cloudformation_stores = AccountRegionBundle("cloudformation", CloudFormationStore)
+
+
+def get_cloudformation_store(account_id: str, region_name: str) -> CloudFormationStore:
+ return cloudformation_stores[account_id][region_name]
+
+
+# TODO: rework / fix usage of this
+def find_stack(account_id: str, region_name: str, stack_name: str) -> Stack | None:
+ # Warning: This function may not return the correct stack if multiple stacks with the same name exist.
+ state = get_cloudformation_store(account_id, region_name)
+ return (
+ [s for s in state.stacks.values() if stack_name in [s.stack_name, s.stack_id]] or [None]
+ )[0]
+
+
+def find_stack_by_id(account_id: str, region_name: str, stack_id: str) -> Stack | None:
+ """
+ Find the stack by id.
+
+ :param account_id: account of the stack
+ :param region_name: region of the stack
+ :param stack_id: stack id
+ :return: Stack if it is found, None otherwise
+ """
+ state = get_cloudformation_store(account_id, region_name)
+ for stack in state.stacks.values():
+ # there can only be one stack with an id
+ if stack_id == stack.stack_id:
+ return stack
+ return None
+
+
+def find_active_stack_by_name_or_id(
+ account_id: str, region_name: str, stack_name_or_id: str
+) -> Stack | None:
+ """
+ Find the active stack by name. Some CloudFormation operations only allow referencing by stack name if the stack is
+ "active", which we currently interpret as not DELETE_COMPLETE.
+
+ :param account_id: account of the stack
+ :param region_name: region of the stack
+ :param stack_name_or_id: stack name or stack id
+ :return: Stack if it is found, None otherwise
+ """
+ state = get_cloudformation_store(account_id, region_name)
+ for stack in state.stacks.values():
+ # there can only be one stack where this condition is true for each region
+ # as there can only be one active stack with a given name
+ if (
+ stack_name_or_id in [stack.stack_name, stack.stack_id]
+ and stack.status != "DELETE_COMPLETE"
+ ):
+ return stack
+ return None
+
+
+def find_change_set(
+ account_id: str,
+ region_name: str,
+ cs_name: str,
+ stack_name: Optional[str] = None,
+ active_only: bool = False,
+) -> Optional[StackChangeSet]:
+ store = get_cloudformation_store(account_id, region_name)
+ for stack in store.stacks.values():
+ if active_only and stack.status == StackStatus.DELETE_COMPLETE:
+ continue
+ if stack_name in (stack.stack_name, stack.stack_id, None):
+ for change_set in stack.change_sets:
+ if cs_name in (change_set.change_set_id, change_set.change_set_name):
+ return change_set
+ return None
+
+
+def exports_map(account_id: str, region_name: str):
+ result = {}
+ store = get_cloudformation_store(account_id, region_name)
+ for export in store.exports:
+ result[export["Name"]] = export
+ return result
diff --git a/localstack-core/localstack/services/cloudformation/usage.py b/localstack-core/localstack/services/cloudformation/usage.py
new file mode 100644
index 0000000000000..44ef5d43eb3ce
--- /dev/null
+++ b/localstack-core/localstack/services/cloudformation/usage.py
@@ -0,0 +1,4 @@
+from localstack.utils.analytics.usage import UsageSetCounter
+
+resource_type = UsageSetCounter("cloudformation:resourcetype")
+missing_resource_types = UsageSetCounter("cloudformation:missingresourcetypes")
diff --git a/localstack-core/localstack/services/cloudwatch/__init__.py b/localstack-core/localstack/services/cloudwatch/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/cloudwatch/alarm_scheduler.py b/localstack-core/localstack/services/cloudwatch/alarm_scheduler.py
new file mode 100644
index 0000000000000..2b0675f121450
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/alarm_scheduler.py
@@ -0,0 +1,395 @@
+import json
+import logging
+import math
+import threading
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, List, Optional
+
+from localstack.aws.api.cloudwatch import MetricAlarm, MetricDataQuery, MetricStat, StateValue
+from localstack.aws.connect import connect_to
+from localstack.utils.aws import arns, aws_stack
+from localstack.utils.scheduler import Scheduler
+
+if TYPE_CHECKING:
+ from mypy_boto3_cloudwatch import CloudWatchClient
+
+LOG = logging.getLogger(__name__)
+
+# TODO currently not supported, used for anomaly detection models:
+# LessThanLowerOrGreaterThanUpperThreshold
+# LessThanLowerThreshold
+# GreaterThanUpperThreshold
+COMPARISON_OPS = {
+ "GreaterThanOrEqualToThreshold": (lambda value, threshold: value >= threshold),
+ "GreaterThanThreshold": (lambda value, threshold: value > threshold),
+ "LessThanThreshold": (lambda value, threshold: value < threshold),
+ "LessThanOrEqualToThreshold": (lambda value, threshold: value <= threshold),
+}
+
+DEFAULT_REASON = "Alarm Evaluation"
+THRESHOLD_CROSSED = "Threshold Crossed"
+INSUFFICIENT_DATA = "Insufficient Data"
+
+
+class AlarmScheduler:
+ def __init__(self) -> None:
+ """
+ Creates a new AlarmScheduler, with a Scheduler, that will be started in a new thread
+ """
+ super().__init__()
+ self.scheduler = Scheduler()
+ self.thread = threading.Thread(target=self.scheduler.run, name="cloudwatch-scheduler")
+ self.thread.start()
+ self.scheduled_alarms = {}
+
+ def shutdown_scheduler(self) -> None:
+ """
+ Shuts down the scheduler; must be called before the application stops
+ """
+ self.scheduler.close()
+ self.thread.join(10)
+
+ def schedule_metric_alarm(self, alarm_arn: str) -> None:
+ """(Re-)schedules the alarm, if the alarm is re-scheduled, the running alarm scheduler will be cancelled before
+ starting a new one"""
+ alarm_details = get_metric_alarm_details_for_alarm_arn(alarm_arn)
+ self.delete_scheduler_for_alarm(alarm_arn)
+ if not alarm_details:
+ LOG.warning("Scheduling alarm failed: could not find alarm %s", alarm_arn)
+ return
+
+ if not self._is_alarm_supported(alarm_details):
+ LOG.warning(
+ "Given alarm configuration not yet supported, alarm state will not be evaluated."
+ )
+ return
+
+ period = alarm_details["Period"]
+ evaluation_periods = alarm_details["EvaluationPeriods"]
+ schedule_period = evaluation_periods * period
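+ # illustrative example (hypothetical values): with Period=60 and EvaluationPeriods=3,
+ # the alarm state is re-evaluated every 180 seconds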
+
+ def on_error(e):
+ LOG.exception("Error executing scheduled alarm", exc_info=e)
+
+ task = self.scheduler.schedule(
+ func=calculate_alarm_state,
+ period=schedule_period,
+ fixed_rate=True,
+ args=[alarm_arn],
+ on_error=on_error,
+ )
+
+ self.scheduled_alarms[alarm_arn] = task
+
+ def delete_scheduler_for_alarm(self, alarm_arn: str) -> None:
+ """
+ Deletes the recurring scheduler for an alarm
+
+ :param alarm_arn: the arn of the alarm to be removed
+ """
+ task = self.scheduled_alarms.pop(alarm_arn, None)
+ if task:
+ task.cancel()
+
+ def restart_existing_alarms(self) -> None:
+ """
+ Only used to re-create persistent state. Reschedules alarms that already exist.
+ """
+ for region in aws_stack.get_valid_regions_for_service("cloudwatch"):
+ client = connect_to(region_name=region).cloudwatch
+ result = client.describe_alarms()
+ for metric_alarm in result["MetricAlarms"]:
+ arn = metric_alarm["AlarmArn"]
+ self.schedule_metric_alarm(alarm_arn=arn)
+
+ def _is_alarm_supported(self, alarm_details: MetricAlarm) -> bool:
+ required_parameters = ["Period", "Statistic", "MetricName", "Threshold"]
+ for param in required_parameters:
+ if param not in alarm_details:
+ LOG.debug(
+ "Currently only simple MetricAlarm are supported. Alarm is missing '%s'. ExtendedStatistic is not yet supported.",
+ param,
+ )
+ return False
+ if alarm_details["ComparisonOperator"] not in COMPARISON_OPS:
+ LOG.debug(
+ "ComparisonOperator '%s' not yet supported.",
+ alarm_details["ComparisonOperator"],
+ )
+ return False
+ return True
+
+
+def get_metric_alarm_details_for_alarm_arn(alarm_arn: str) -> Optional[MetricAlarm]:
+ alarm_name = arns.extract_resource_from_arn(alarm_arn).split(":", 1)[1]
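+ # illustrative example (hypothetical ARN): for "arn:aws:cloudwatch:us-east-1:000000000000:alarm:my-alarm",
+ # the extracted resource is "alarm:my-alarm", and splitting on the first ":" yields the name "my-alarm"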
+ client = get_cloudwatch_client_for_region_of_alarm(alarm_arn)
+ metric_alarms = client.describe_alarms(AlarmNames=[alarm_name])["MetricAlarms"]
+ return metric_alarms[0] if metric_alarms else None
+
+
+def get_cloudwatch_client_for_region_of_alarm(alarm_arn: str) -> "CloudWatchClient":
+ parsed_arn = arns.parse_arn(alarm_arn)
+ region = parsed_arn["region"]
+ access_key_id = parsed_arn["account"]
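+ # LocalStack convention: the account ID from the ARN is passed as the access key ID, so the
+ # internal client is routed to the correct account namespace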
+ return connect_to(region_name=region, aws_access_key_id=access_key_id).cloudwatch
+
+
+def generate_metric_query(alarm_details: MetricAlarm) -> MetricDataQuery:
+ """Creates the dict with the required data for MetricDataQueries when calling client.get_metric_data"""
+
+ metric = {
+ "MetricName": alarm_details["MetricName"],
+ }
+ if alarm_details.get("Namespace"):
+ metric["Namespace"] = alarm_details["Namespace"]
+ if alarm_details.get("Dimensions"):
+ metric["Dimensions"] = alarm_details["Dimensions"]
+ return MetricDataQuery(
+ Id=alarm_details["AlarmName"],
+ MetricStat=MetricStat(
+ Metric=metric,
+ Period=alarm_details["Period"],
+ Stat=alarm_details["Statistic"],
+ ),
+ # TODO other fields might be required in the future
+ )
+
+
+def is_threshold_exceeded(metric_values: List[float], alarm_details: MetricAlarm) -> bool:
+ """Evaluates if the threshold is exceeded for the configured alarm and given metric values
+
+ :param metric_values: values to compare against threshold
+ :param alarm_details: Alarm Description, as returned from describe_alarms
+
+ :return: True if threshold is exceeded, else False
+ """
+ threshold = alarm_details["Threshold"]
+ comparison_operator = alarm_details["ComparisonOperator"]
+ treat_missing_data = alarm_details.get("TreatMissingData", "missing")
+ evaluation_periods = alarm_details.get("EvaluationPeriods")
+ datapoints_to_alarm = alarm_details.get("DatapointsToAlarm", evaluation_periods)
+ evaluated_datapoints = []
+ for value in metric_values:
+ if value is None:
+ if treat_missing_data == "breaching":
+ evaluated_datapoints.append(True)
+ elif treat_missing_data == "notBreaching":
+ evaluated_datapoints.append(False)
+ # else we can ignore the data
+ else:
+ evaluated_datapoints.append(COMPARISON_OPS.get(comparison_operator)(value, threshold))
+
+ sum_breaching = evaluated_datapoints.count(True)
+ if sum_breaching >= datapoints_to_alarm:
+ return True
+ return False
+
+
+def is_triggering_premature_alarm(metric_values: List[float], alarm_details: MetricAlarm) -> bool:
+ """
+ Checks if a premature alarm should be triggered.
+ https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/AlarmThatSendsEmail.html#CloudWatch-alarms-avoiding-premature-transition:
+
+ [...] alarms are designed to always go into ALARM state when the oldest available breaching datapoint during the Evaluation
+ Periods number of data points is at least as old as the value of Datapoints to Alarm, and all other more recent data
+ points are breaching or missing. In this case, the alarm goes into ALARM state even if the total number of datapoints
+ available is lower than M (Datapoints to Alarm).
+ This alarm logic applies to M out of N alarms as well.
+ """
+ treat_missing_data = alarm_details.get("TreatMissingData", "missing")
+ if treat_missing_data not in ("missing", "ignore"):
+ return False
+
+ datapoints_to_alarm = alarm_details.get("DatapointsToAlarm", 1)
+ if datapoints_to_alarm > 1:
+ comparison_operator = alarm_details["ComparisonOperator"]
+ threshold = alarm_details["Threshold"]
+ oldest_datapoints = metric_values[:-datapoints_to_alarm]
+ if oldest_datapoints.count(None) == len(oldest_datapoints):
+ if metric_values[-datapoints_to_alarm] and COMPARISON_OPS.get(comparison_operator)(
+ metric_values[-datapoints_to_alarm], threshold
+ ):
+ values = list(filter(None, metric_values[len(oldest_datapoints) :]))
+ if all(
+ COMPARISON_OPS.get(comparison_operator)(value, threshold) for value in values
+ ):
+ return True
+ return False
+
+
+def collect_metric_data(alarm_details: MetricAlarm, client: "CloudWatchClient") -> List[float]:
+ """
+ Collects the metric data for the evaluation interval.
+
+ :param alarm_details: the alarm details as returned by describe_alarms
+ :param client: the cloudwatch client
+ :return: list with data points
+ """
+ metric_values = []
+
+ evaluation_periods = alarm_details["EvaluationPeriods"]
+ period = alarm_details["Period"]
+
+ # From the docs: "Whenever an alarm evaluates whether to change state, CloudWatch attempts to retrieve a higher number of data
+ # points than the number specified as Evaluation Periods."
+ # No other indication, try to calculate a reasonable value:
+ magic_number = max(math.floor(evaluation_periods / 3), 2)
+ collected_periods = evaluation_periods + magic_number
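+ # e.g. (illustrative): EvaluationPeriods=6 -> 2 extra periods are queried, i.e. 8 data points in total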
+
+ now = datetime.utcnow().replace(tzinfo=timezone.utc)
+ metric_query = generate_metric_query(alarm_details)
+
+ # get_metric_data needs to be run in a loop, so that we also collect empty data points at the right position
+ for i in range(0, collected_periods):
+ start_time = now - timedelta(seconds=period)
+ end_time = now
+ metric_data = client.get_metric_data(
+ MetricDataQueries=[metric_query], StartTime=start_time, EndTime=end_time
+ )["MetricDataResults"][0]
+ val = metric_data["Values"]
+ # oldest datapoint should be at the beginning of the list
+ metric_values.insert(0, val[0] if val else None)
+ now = start_time
+ return metric_values
+
+
+def update_alarm_state(
+ client: "CloudWatchClient",
+ alarm_name: str,
+ current_state: str,
+ desired_state: str,
+ reason: str = DEFAULT_REASON,
+ state_reason_data: dict = None,
+) -> None:
+ """Updates the alarm state, if the current_state is different than the desired_state
+
+ :param client: the cloudwatch client
+ :param alarm_name: the name of the alarm
+ :param current_state: the state the alarm is currently in
+ :param desired_state: the state the alarm should have after updating
+ :param reason: reason why the state is set; will be used for set_alarm_state
+ :param state_reason_data: data associated with the state change, optional
+ """
+ if current_state == desired_state:
+ return
+ client.set_alarm_state(
+ AlarmName=alarm_name,
+ StateValue=desired_state,
+ StateReason=reason,
+ StateReasonData=json.dumps(state_reason_data),
+ )
+
+
+def calculate_alarm_state(alarm_arn: str) -> None:
+ """
+ Calculates and updates the state of the alarm
+
+ :param alarm_arn: the arn of the alarm to be evaluated
+ """
+ alarm_details = get_metric_alarm_details_for_alarm_arn(alarm_arn)
+ if not alarm_details:
+ LOG.warning("Could not find alarm %s", alarm_arn)
+ return
+
+ client = get_cloudwatch_client_for_region_of_alarm(alarm_arn)
+
+ query_date = datetime.utcnow().strftime(format="%Y-%m-%dT%H:%M:%S+0000")
+ metric_values = collect_metric_data(alarm_details, client)
+
+ state_reason_data = {
+ "version": "1.0",
+ "queryDate": query_date,
+ "period": alarm_details["Period"],
+ "recentDatapoints": [v for v in metric_values if v is not None],
+ "threshold": alarm_details["Threshold"],
+ }
+ if alarm_details.get("Statistic"):
+ state_reason_data["statistic"] = alarm_details["Statistic"]
+ if alarm_details.get("Unit"):
+ state_reason_data["unit"] = alarm_details["Unit"]
+
+ alarm_name = alarm_details["AlarmName"]
+ alarm_state = alarm_details["StateValue"]
+ treat_missing_data = alarm_details.get("TreatMissingData", "missing")
+
+ empty_datapoints = metric_values.count(None)
+ if empty_datapoints == len(metric_values):
+ evaluation_periods = alarm_details["EvaluationPeriods"]
+ details_msg = (
+ f"no datapoints were received for {evaluation_periods} period{'s' if evaluation_periods > 1 else ''} and "
+ f"{evaluation_periods} missing datapoint{'s were' if evaluation_periods > 1 else ' was'} treated as"
+ )
+ if treat_missing_data == "missing":
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.INSUFFICIENT_DATA,
+ f"{INSUFFICIENT_DATA}: {details_msg} [{treat_missing_data.capitalize()}].",
+ state_reason_data=state_reason_data,
+ )
+ elif treat_missing_data == "breaching":
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.ALARM,
+ f"{THRESHOLD_CROSSED}: {details_msg} [{treat_missing_data.capitalize()}].",
+ state_reason_data=state_reason_data,
+ )
+ elif treat_missing_data == "notBreaching":
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.OK,
+ f"{THRESHOLD_CROSSED}: {details_msg} [NonBreaching].",
+ state_reason_data=state_reason_data,
+ )
+ # 'ignore': keep the same state
+ return
+
+ if is_triggering_premature_alarm(metric_values, alarm_details):
+ if treat_missing_data == "missing":
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.ALARM,
+ f"{THRESHOLD_CROSSED}: premature alarm for missing datapoints",
+ state_reason_data=state_reason_data,
+ )
+ # for 'ignore' the state should be retained
+ return
+
+ # collect all non-empty datapoints from the evaluation interval
+ collected_datapoints = [val for val in reversed(metric_values) if val is not None]
+
+ # add empty data points until the number of data points == "evaluation periods"
+ evaluation_periods = alarm_details["EvaluationPeriods"]
+ while len(collected_datapoints) < evaluation_periods and treat_missing_data in (
+ "breaching",
+ "notBreaching",
+ ):
+ # breaching/non-breaching datapoints will be evaluated
+ # ignore/missing are not relevant
+ collected_datapoints.append(None)
+
+ if is_threshold_exceeded(collected_datapoints, alarm_details):
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.ALARM,
+ THRESHOLD_CROSSED,
+ state_reason_data=state_reason_data,
+ )
+ else:
+ update_alarm_state(
+ client,
+ alarm_name,
+ alarm_state,
+ StateValue.OK,
+ THRESHOLD_CROSSED,
+ state_reason_data=state_reason_data,
+ )
diff --git a/localstack-core/localstack/services/cloudwatch/cloudwatch_database_helper.py b/localstack-core/localstack/services/cloudwatch/cloudwatch_database_helper.py
new file mode 100644
index 0000000000000..43383cf2782ad
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/cloudwatch_database_helper.py
@@ -0,0 +1,460 @@
+import logging
+import os
+import sqlite3
+import threading
+from datetime import datetime, timezone
+from typing import Dict, List, Optional
+
+from localstack import config
+from localstack.aws.api.cloudwatch import MetricData, MetricDataQuery, ScanBy
+from localstack.utils.files import mkdir
+
+LOG = logging.getLogger(__name__)
+
+STAT_TO_SQLITE_AGGREGATION_FUNC = {
+ "Sum": "SUM(value)",
+ "Average": "SUM(value)", # we need to calculate the avg manually as we have also a table with aggregated data
+ "Minimum": "MIN(value)",
+ "Maximum": "MAX(value)",
+ "SampleCount": "Sum(count)",
+}
+
+STAT_TO_SQLITE_COL_NAME_HELPER = {
+ "Sum": "sum",
+ "Average": "sum",
+ "Minimum": "min",
+ "Maximum": "max",
+ "SampleCount": "sample_count",
+}
+
+
+class CloudwatchDatabase:
+ DB_NAME = "metrics.db"
+ CLOUDWATCH_DATA_ROOT: str = os.path.join(config.dirs.data, "cloudwatch")
+ METRICS_DB: str = os.path.join(CLOUDWATCH_DATA_ROOT, DB_NAME)
+ METRICS_DB_READ_ONLY: str = f"file:{METRICS_DB}?mode=ro"
+ TABLE_SINGLE_METRICS = "SINGLE_METRICS"
+ TABLE_AGGREGATED_METRICS = "AGGREGATED_METRICS"
+ DATABASE_LOCK: threading.RLock
+
+ def __init__(self):
+ self.DATABASE_LOCK = threading.RLock()
+ if os.path.exists(self.METRICS_DB):
+ LOG.debug("database for metrics already exists (%s)", self.METRICS_DB)
+ return
+
+ mkdir(self.CLOUDWATCH_DATA_ROOT)
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB) as conn:
+ cur = conn.cursor()
+ common_columns = """
+ "id" INTEGER,
+ "account_id" TEXT,
+ "region" TEXT,
+ "metric_name" TEXT,
+ "namespace" TEXT,
+ "timestamp" NUMERIC,
+ "dimensions" TEXT,
+ "unit" TEXT,
+ "storage_resolution" INTEGER
+ """
+ cur.execute(
+ f"""
+ CREATE TABLE "{self.TABLE_SINGLE_METRICS}" (
+ {common_columns},
+ "value" NUMERIC,
+ PRIMARY KEY("id")
+ );
+ """
+ )
+
+ cur.execute(
+ f"""
+ CREATE TABLE "{self.TABLE_AGGREGATED_METRICS}" (
+ {common_columns},
+ "sample_count" NUMERIC,
+ "sum" NUMERIC,
+ "min" NUMERIC,
+ "max" NUMERIC,
+ PRIMARY KEY("id")
+ );
+ """
+ )
+ # create indexes
+ cur.executescript(
+ """
+ CREATE INDEX idx_single_metrics_comp ON SINGLE_METRICS (metric_name, namespace);
+ CREATE INDEX idx_aggregated_metrics_comp ON AGGREGATED_METRICS (metric_name, namespace);
+ """
+ )
+ conn.commit()
+
+ def add_metric_data(
+ self, account_id: str, region: str, namespace: str, metric_data: MetricData
+ ):
+ def _get_current_unix_timestamp_utc():
+ now = datetime.utcnow().replace(tzinfo=timezone.utc)
+ return int(now.timestamp())
+
+ for metric in metric_data:
+ unix_timestamp = (
+ self._convert_timestamp_to_unix(metric.get("Timestamp"))
+ if metric.get("Timestamp")
+ else _get_current_unix_timestamp_utc()
+ )
+
+ inserts = []
+ if metric.get("Value") is not None:
+ inserts.append({"Value": metric.get("Value"), "TimesToInsert": 1})
+ elif metric.get("Values"):
+ counts = metric.get("Counts", [1] * len(metric.get("Values")))
+ inserts = [
+ {"Value": value, "TimesToInsert": int(counts[indexValue])}
+ for indexValue, value in enumerate(metric.get("Values"))
+ ]
+ all_data = []
+ for insert in inserts:
+ times_to_insert = insert.get("TimesToInsert")
+
+ data = (
+ account_id,
+ region,
+ metric.get("MetricName"),
+ namespace,
+ unix_timestamp,
+ self._get_ordered_dimensions_with_separator(metric.get("Dimensions")),
+ metric.get("Unit"),
+ metric.get("StorageResolution"),
+ insert.get("Value"),
+ )
+ all_data.extend([data] * times_to_insert)
+
+ if all_data:
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB) as conn:
+ cur = conn.cursor()
+ query = f"INSERT INTO {self.TABLE_SINGLE_METRICS} (account_id, region, metric_name, namespace, timestamp, dimensions, unit, storage_resolution, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
+ cur.executemany(query, all_data)
+ conn.commit()
+
+ if statistic_values := metric.get("StatisticValues"):
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB) as conn:
+ cur = conn.cursor()
+ cur.execute(
+ f"""INSERT INTO {self.TABLE_AGGREGATED_METRICS}
+ ("account_id", "region", "metric_name", "namespace", "timestamp", "dimensions", "unit", "storage_resolution", "sample_count", "sum", "min", "max")
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
+ (
+ account_id,
+ region,
+ metric.get("MetricName"),
+ namespace,
+ unix_timestamp,
+ self._get_ordered_dimensions_with_separator(metric.get("Dimensions")),
+ metric.get("Unit"),
+ metric.get("StorageResolution"),
+ statistic_values.get("SampleCount"),
+ statistic_values.get("Sum"),
+ statistic_values.get("Minimum"),
+ statistic_values.get("Maximum"),
+ ),
+ )
+
+ conn.commit()
+
+ def get_units_for_metric_data_stat(
+ self,
+ account_id: str,
+ region: str,
+ start_time: datetime,
+ end_time: datetime,
+ metric_name: str,
+ namespace: str,
+ ):
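+ # returns the distinct units stored for the given metric within the time range; rows without a unit
+ # are represented by the sentinel string "NULL_VALUE" (see the COALESCE in the query below)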
+ # prepare SQL query
+ start_time_unix = self._convert_timestamp_to_unix(start_time)
+ end_time_unix = self._convert_timestamp_to_unix(end_time)
+
+ data = (
+ account_id,
+ region,
+ namespace,
+ metric_name,
+ start_time_unix,
+ end_time_unix,
+ )
+
+ sql_query = f"""
+ SELECT GROUP_CONCAT(unit) AS unit_values
+ FROM(
+ SELECT
+ DISTINCT COALESCE(unit, 'NULL_VALUE') AS unit
+ FROM (
+ SELECT
+ account_id, region, metric_name, namespace, timestamp, unit
+ FROM {self.TABLE_SINGLE_METRICS}
+ UNION ALL
+ SELECT
+ account_id, region, metric_name, namespace, timestamp, unit
+ FROM {self.TABLE_AGGREGATED_METRICS}
+ ) AS combined
+ WHERE account_id = ? AND region = ?
+ AND namespace = ? AND metric_name = ?
+ AND timestamp >= ? AND timestamp < ?
+ ) AS subquery
+ """
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB_READ_ONLY, uri=True) as conn:
+ cur = conn.cursor()
+ cur.execute(
+ sql_query,
+ data,
+ )
+ result_row = cur.fetchone()
+ return result_row[0].split(",") if result_row[0] else ["NULL_VALUE"]
+
+ def get_metric_data_stat(
+ self,
+ account_id: str,
+ region: str,
+ query: MetricDataQuery,
+ start_time: datetime,
+ end_time: datetime,
+ scan_by: str,
+ ) -> Dict[str, List]:
+ metric_stat = query.get("MetricStat")
+ metric = metric_stat.get("Metric")
+ period = metric_stat.get("Period")
+ stat = metric_stat.get("Stat")
+ dimensions = metric.get("Dimensions", [])
+ unit = metric_stat.get("Unit")
+
+ # prepare SQL query
+ start_time_unix = self._convert_timestamp_to_unix(start_time)
+ end_time_unix = self._convert_timestamp_to_unix(end_time)
+
+ data = (
+ account_id,
+ region,
+ metric.get("Namespace"),
+ metric.get("MetricName"),
+ )
+
+ dimension_filter = "AND dimensions is null " if not dimensions else "AND dimensions LIKE ? "
+ if dimensions:
+ data = data + (
+ self._get_ordered_dimensions_with_separator(dimensions, for_search=True),
+ )
+
+ unit_filter = ""
+ if unit:
+ if unit == "NULL_VALUE":
+ unit_filter = "AND unit IS NULL"
+ else:
+ unit_filter = "AND unit = ? "
+ data += (unit,)
+
+ sql_query = f"""
+ SELECT
+ {STAT_TO_SQLITE_AGGREGATION_FUNC[stat]},
+ SUM(count)
+ FROM (
+ SELECT
+ value, 1 as count,
+ account_id, region, metric_name, namespace, timestamp, dimensions, unit, storage_resolution
+ FROM {self.TABLE_SINGLE_METRICS}
+ UNION ALL
+ SELECT
+ {STAT_TO_SQLITE_COL_NAME_HELPER[stat]} as value, sample_count as count,
+ account_id, region, metric_name, namespace, timestamp, dimensions, unit, storage_resolution
+ FROM {self.TABLE_AGGREGATED_METRICS}
+ ) AS combined
+ WHERE account_id = ? AND region = ?
+ AND namespace = ? AND metric_name = ?
+ {dimension_filter}
+ {unit_filter}
+ AND timestamp >= ? AND timestamp < ?
+ ORDER BY timestamp ASC
+ """
+
+ timestamps = []
+ values = []
+ query_params = []
+
+ # Prepare all the query parameters
+ while start_time_unix < end_time_unix:
+ next_start_time = start_time_unix + period
+ query_params.append(data + (start_time_unix, next_start_time))
+ start_time_unix = next_start_time
+
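+ # execute the per-period queries in batches of 500, combined via UNION ALL; batching presumably keeps us
+ # below SQLite's default limit on compound SELECT terms (500)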
+ all_results = []
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB_READ_ONLY, uri=True) as conn:
+ cur = conn.cursor()
+ batch_size = 500
+ for i in range(0, len(query_params), batch_size):
+ batch = query_params[i : i + batch_size]
+ cur.execute(
+ f"""
+ SELECT * FROM (
+ {" UNION ALL ".join(["SELECT * FROM (" + sql_query + ")"] * len(batch))}
+ )
+ """,
+ sum(batch, ()), # flatten the list of tuples in batch into a single tuple
+ )
+ all_results.extend(cur.fetchall())
+
+ # Process results outside the lock
+ for i, result_row in enumerate(all_results):
+ if result_row[1]:
+ calculated_result = (
+ result_row[0] / result_row[1] if stat == "Average" else result_row[0]
+ )
+ timestamps.append(query_params[i][-2]) # start_time_unix
+ values.append(calculated_result)
+
+ # The while loop will always give us the timestamps in ascending order, as we start with the start_time
+ # and increase it by the period until we reach the end_time.
+ # If we want the timestamps in descending order, we need to reverse the list.
+ if scan_by is None or scan_by == ScanBy.TimestampDescending:
+ timestamps = timestamps[::-1]
+ values = values[::-1]
+
+ return {
+ "timestamps": timestamps,
+ "values": values,
+ }
+
+ def list_metrics(
+ self,
+ account_id: str,
+ region: str,
+ namespace: str,
+ metric_name: str,
+ dimensions: list[dict[str, str]],
+ ) -> dict:
+ data = (account_id, region)
+
+ namespace_filter = ""
+ if namespace:
+ namespace_filter = " AND namespace = ?"
+ data = data + (namespace,)
+
+ metric_name_filter = ""
+ if metric_name:
+ metric_name_filter = " AND metric_name = ?"
+ data = data + (metric_name,)
+
+ dimension_filter = "" if not dimensions else " AND dimensions LIKE ? "
+ if dimensions:
+ data = data + (
+ self._get_ordered_dimensions_with_separator(dimensions, for_search=True),
+ )
+
+ query = f"""
+ SELECT DISTINCT metric_name, namespace, dimensions
+ FROM (
+ SELECT metric_name, namespace, dimensions, account_id, region, timestamp
+ FROM SINGLE_METRICS
+ UNION
+ SELECT metric_name, namespace, dimensions, account_id, region, timestamp
+ FROM AGGREGATED_METRICS
+ ) AS combined
+ WHERE account_id = ? AND region = ?
+ {namespace_filter}
+ {metric_name_filter}
+ {dimension_filter}
+ ORDER BY timestamp DESC
+ """
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB_READ_ONLY, uri=True) as conn:
+ cur = conn.cursor()
+
+ cur.execute(
+ query,
+ data,
+ )
+ metrics_result = [
+ {
+ "metric_name": r[0],
+ "namespace": r[1],
+ "dimensions": self._restore_dimensions_from_string(r[2]),
+ }
+ for r in cur.fetchall()
+ ]
+
+ return {"metrics": metrics_result}
+
+ def clear_tables(self):
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB) as conn:
+ cur = conn.cursor()
+ cur.execute(f"DELETE FROM {self.TABLE_SINGLE_METRICS}")
+ cur.execute(f"DELETE FROM {self.TABLE_AGGREGATED_METRICS}")
+ conn.commit()
+ cur.execute("VACUUM")
+ conn.commit()
+
+ def _get_ordered_dimensions_with_separator(self, dims: Optional[List[Dict]], for_search=False):
+ """
+ Returns a string with the dimensions in the format "Name=Value\tName=Value\tName=Value" in order to store the metric
+ with the dimensions in a single column in the database
+
+ :param dims: List of dimensions in the format [{"Name": "name", "Value": "value"}, ...]
+ :param for_search: If True, the dimensions will be formatted in a way that can be used in a LIKE query to search. Default is False. Example: " %{Name}={Value}% "
+ :return: String with the dimensions in the format "Name=Value\tName=Value\tName=Value"
+ """
+ if not dims:
+ return None
+ dims.sort(key=lambda d: d["Name"])
+ dimensions = ""
+ if not for_search:
+ for d in dims:
+ dimensions += f"{d['Name']}={d['Value']}\t" # aws does not allow ascii control characters, we can use it a sa separator
+ else:
+ for d in dims:
+ dimensions += f"%{d.get('Name')}={d.get('Value', '')}%"
+
+ return dimensions
+
+ def _restore_dimensions_from_string(self, dimensions: str):
+ if not dimensions:
+ return None
+ dims = []
+ for d in dimensions.split("\t"):
+ if not d:
+ continue
+ name, value = d.split("=")
+ dims.append({"Name": name, "Value": value})
+
+ return dims
+
+ def _convert_timestamp_to_unix(
+ self, timestamp: datetime
+ ): # TODO verify if this is the standard format, might need to convert
+ return int(timestamp.timestamp())
+
+ def get_all_metric_data(self):
+ with self.DATABASE_LOCK, sqlite3.connect(self.METRICS_DB_READ_ONLY, uri=True) as conn:
+ cur = conn.cursor()
+ """ shape for each data entry:
+ {
+ "ns": r.namespace,
+ "n": r.name,
+ "v": r.value,
+ "t": r.timestamp,
+ "d": [{"n": d.name, "v": d.value} for d in r.dimensions],
+ "account": account-id, # new for v2
+ "region": region_name, # new for v2
+ }
+ """
+ query = f"SELECT namespace, metric_name, value, timestamp, dimensions, account_id, region from {self.TABLE_SINGLE_METRICS}"
+ cur.execute(query)
+ metrics_result = [
+ {
+ "ns": r[0],
+ "n": r[1],
+ "v": r[2],
+ "t": r[3],
+ "d": r[4],
+ "account": r[5],
+ "region": r[6],
+ }
+ for r in cur.fetchall()
+ ]
+ # TODO add aggregated metrics (was not handled by v1 either)
+ return metrics_result
diff --git a/localstack-core/localstack/services/cloudwatch/models.py b/localstack-core/localstack/services/cloudwatch/models.py
new file mode 100644
index 0000000000000..a1246569f4f97
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/models.py
@@ -0,0 +1,109 @@
+import datetime
+from datetime import timezone
+from typing import Dict, List
+
+from localstack.aws.api.cloudwatch import CompositeAlarm, DashboardBody, MetricAlarm, StateValue
+from localstack.services.stores import (
+ AccountRegionBundle,
+ BaseStore,
+ CrossRegionAttribute,
+ LocalAttribute,
+)
+from localstack.utils.aws import arns
+from localstack.utils.tagging import TaggingService
+
+
+class LocalStackMetricAlarm:
+ region: str
+ account_id: str
+ alarm: MetricAlarm
+
+ def __init__(self, account_id: str, region: str, alarm: MetricAlarm):
+ self.account_id = account_id
+ self.region = region
+ self.alarm = alarm
+ self.set_default_attributes()
+
+ def set_default_attributes(self):
+ current_time = datetime.datetime.now(timezone.utc)
+ self.alarm["AlarmArn"] = arns.cloudwatch_alarm_arn(
+ self.alarm["AlarmName"], account_id=self.account_id, region_name=self.region
+ )
+ self.alarm["AlarmConfigurationUpdatedTimestamp"] = current_time
+ self.alarm.setdefault("ActionsEnabled", True)
+ self.alarm.setdefault("OKActions", [])
+ self.alarm.setdefault("AlarmActions", [])
+ self.alarm.setdefault("InsufficientDataActions", [])
+ self.alarm["StateValue"] = StateValue.INSUFFICIENT_DATA
+ self.alarm["StateReason"] = "Unchecked: Initial alarm creation"
+ self.alarm["StateUpdatedTimestamp"] = current_time
+ self.alarm.setdefault("Dimensions", [])
+ self.alarm["StateTransitionedTimestamp"] = current_time
+
+
+class LocalStackCompositeAlarm:
+ region: str
+ account_id: str
+ alarm: CompositeAlarm
+
+ def __init__(self, account_id: str, region: str, alarm: CompositeAlarm):
+ self.account_id = account_id
+ self.region = region
+ self.alarm = alarm
+ self.set_default_attributes()
+
+ def set_default_attributes(self):
+ current_time = datetime.datetime.now(timezone.utc)
+ self.alarm["AlarmArn"] = arns.cloudwatch_alarm_arn(
+ self.alarm["AlarmName"], account_id=self.account_id, region_name=self.region
+ )
+ self.alarm["AlarmConfigurationUpdatedTimestamp"] = current_time
+ self.alarm.setdefault("ActionsEnabled", True)
+ self.alarm.setdefault("OKActions", [])
+ self.alarm.setdefault("AlarmActions", [])
+ self.alarm.setdefault("InsufficientDataActions", [])
+ self.alarm["StateValue"] = StateValue.INSUFFICIENT_DATA
+ self.alarm["StateReason"] = "Unchecked: Initial alarm creation"
+ self.alarm["StateUpdatedTimestamp"] = current_time
+ self.alarm["StateTransitionedTimestamp"] = current_time
+
+
+class LocalStackDashboard:
+ region: str
+ account_id: str
+ dashboard_name: str
+ dashboard_arn: str
+ dashboard_body: DashboardBody
+
+ def __init__(
+ self, account_id: str, region: str, dashboard_name: str, dashboard_body: DashboardBody
+ ):
+ self.account_id = account_id
+ self.region = region
+ self.dashboard_name = dashboard_name
+ self.dashboard_arn = arns.cloudwatch_dashboard_arn(
+ self.dashboard_name, account_id=self.account_id, region_name=self.region
+ )
+ self.dashboard_body = dashboard_body
+ self.last_modified = datetime.datetime.now()
+ self.size = 225 # TODO: calculate size
+
+
+LocalStackAlarm = LocalStackMetricAlarm | LocalStackCompositeAlarm
+
+
+class CloudWatchStore(BaseStore):
+ # maps resource ARN to tags
+ TAGS: TaggingService = CrossRegionAttribute(default=TaggingService)
+
+ # maps resource ARN to alarms
+ alarms: Dict[str, LocalStackAlarm] = LocalAttribute(default=dict)
+
+ # Contains all the Alarm Histories. Per documentation, an alarm history is retained even if the alarm is deleted,
+ # making it necessary to save this at store level
+ histories: List[Dict] = LocalAttribute(default=list)
+
+ dashboards: Dict[str, LocalStackDashboard] = LocalAttribute(default=dict)
+
+
+cloudwatch_stores = AccountRegionBundle("cloudwatch", CloudWatchStore)
diff --git a/localstack-core/localstack/services/cloudwatch/provider.py b/localstack-core/localstack/services/cloudwatch/provider.py
new file mode 100644
index 0000000000000..42e4b5fe94e58
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/provider.py
@@ -0,0 +1,530 @@
+import json
+import logging
+import uuid
+from typing import Any, Optional
+from xml.sax.saxutils import escape
+
+from moto.cloudwatch import cloudwatch_backends
+from moto.cloudwatch.models import CloudWatchBackend, FakeAlarm, MetricDatum
+
+from localstack.aws.accounts import get_account_id_from_access_key_id
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.cloudwatch import (
+ AlarmNames,
+ AmazonResourceName,
+ CloudwatchApi,
+ DescribeAlarmsInput,
+ DescribeAlarmsOutput,
+ GetMetricDataInput,
+ GetMetricDataOutput,
+ GetMetricStatisticsInput,
+ GetMetricStatisticsOutput,
+ ListTagsForResourceOutput,
+ PutCompositeAlarmInput,
+ PutMetricAlarmInput,
+ StateValue,
+ TagKeyList,
+ TagList,
+ TagResourceOutput,
+ UntagResourceOutput,
+)
+from localstack.aws.connect import connect_to
+from localstack.constants import DEFAULT_AWS_ACCOUNT_ID
+from localstack.http import Request
+from localstack.services import moto
+from localstack.services.cloudwatch.alarm_scheduler import AlarmScheduler
+from localstack.services.edge import ROUTER
+from localstack.services.plugins import SERVICE_PLUGINS, ServiceLifecycleHook
+from localstack.utils.aws import arns
+from localstack.utils.aws.arns import extract_account_id_from_arn, lambda_function_name
+from localstack.utils.aws.request_context import (
+ extract_access_key_id_from_auth_header,
+ extract_region_from_auth_header,
+)
+from localstack.utils.patch import patch
+from localstack.utils.strings import camel_to_snake_case
+from localstack.utils.sync import poll_condition
+from localstack.utils.tagging import TaggingService
+from localstack.utils.threads import start_worker_thread
+
+PATH_GET_RAW_METRICS = "/_aws/cloudwatch/metrics/raw"
+DEPRECATED_PATH_GET_RAW_METRICS = "/cloudwatch/metrics/raw"
+MOTO_INITIAL_UNCHECKED_REASON = "Unchecked: Initial alarm creation"
+
+LOG = logging.getLogger(__name__)
+
+
+@patch(target=FakeAlarm.update_state)
+def update_state(target, self, reason, reason_data, state_value):
+ if reason_data is None:
+ reason_data = ""
+ if self.state_reason == MOTO_INITIAL_UNCHECKED_REASON:
+ old_state = StateValue.INSUFFICIENT_DATA
+ else:
+ old_state = self.state_value
+
+ old_state_reason = self.state_reason
+ old_state_update_timestamp = self.state_updated_timestamp
+ target(self, reason, reason_data, state_value)
+
+ # check the state and trigger required actions
+ if not self.actions_enabled or old_state == self.state_value:
+ return
+ if self.state_value == "OK":
+ actions = self.ok_actions
+ elif self.state_value == "ALARM":
+ actions = self.alarm_actions
+ else:
+ actions = self.insufficient_data_actions
+ for action in actions:
+ data = arns.parse_arn(action)
+ if data["service"] == "sns":
+ service = connect_to(region_name=data["region"], aws_access_key_id=data["account"]).sns
+ subject = f"""{self.state_value}: "{self.name}" in {self.region_name}"""
+ message = create_message_response_update_state_sns(self, old_state)
+ service.publish(TopicArn=action, Subject=subject, Message=message)
+ elif data["service"] == "lambda":
+ service = connect_to(
+ region_name=data["region"], aws_access_key_id=data["account"]
+ ).lambda_
+ message = create_message_response_update_state_lambda(
+ self, old_state, old_state_reason, old_state_update_timestamp
+ )
+ service.invoke(FunctionName=lambda_function_name(action), Payload=message)
+ else:
+ # TODO: support other actions
+ LOG.warning(
+ "Action for service %s not implemented, action '%s' will not be triggered.",
+ data["service"],
+ action,
+ )
+
+
+@patch(target=CloudWatchBackend.put_metric_alarm)
+def put_metric_alarm(
+ target,
+ self,
+ name: str,
+ namespace: str,
+ metric_name: str,
+ comparison_operator: str,
+ evaluation_periods: int,
+ period: int,
+ threshold: float,
+ statistic: str,
+ description: str,
+ dimensions: list[dict[str, str]],
+ alarm_actions: list[str],
+ metric_data_queries: Optional[list[Any]] = None,
+ datapoints_to_alarm: Optional[int] = None,
+ extended_statistic: Optional[str] = None,
+ ok_actions: Optional[list[str]] = None,
+ insufficient_data_actions: Optional[list[str]] = None,
+ unit: Optional[str] = None,
+ actions_enabled: bool = True,
+ treat_missing_data: Optional[str] = None,
+ evaluate_low_sample_count_percentile: Optional[str] = None,
+ threshold_metric_id: Optional[str] = None,
+ rule: Optional[str] = None,
+ tags: Optional[list[dict[str, str]]] = None,
+) -> FakeAlarm:
+ if description:
+ description = escape(description)
+ return target(
+ self,
+ name,
+ namespace,
+ metric_name,
+ comparison_operator,
+ evaluation_periods,
+ period,
+ threshold,
+ statistic,
+ description,
+ dimensions,
+ alarm_actions,
+ metric_data_queries,
+ datapoints_to_alarm,
+ extended_statistic,
+ ok_actions,
+ insufficient_data_actions,
+ unit,
+ actions_enabled,
+ treat_missing_data,
+ evaluate_low_sample_count_percentile,
+ threshold_metric_id,
+ rule,
+ tags,
+ )
+
+
+def create_metric_data_query_from_alarm(alarm: FakeAlarm):
+ # TODO may need to be adapted for other use cases
+ # verified return value with a snapshot test
+ return [
+ {
+ "id": str(uuid.uuid4()),
+ "metricStat": {
+ "metric": {
+ "namespace": alarm.namespace,
+ "name": alarm.metric_name,
+ "dimensions": alarm.dimensions or {},
+ },
+ "period": int(alarm.period),
+ "stat": alarm.statistic,
+ },
+ "returnData": True,
+ }
+ ]
+
+
+def create_message_response_update_state_lambda(
+ alarm: FakeAlarm, old_state, old_state_reason, old_state_timestamp
+):
+ response = {
+ "accountId": extract_account_id_from_arn(alarm.alarm_arn),
+ "alarmArn": alarm.alarm_arn,
+ "alarmData": {
+ "alarmName": alarm.name,
+ "state": {
+ "value": alarm.state_value,
+ "reason": alarm.state_reason,
+ "timestamp": alarm.state_updated_timestamp,
+ },
+ "previousState": {
+ "value": old_state,
+ "reason": old_state_reason,
+ "timestamp": old_state_timestamp,
+ },
+ "configuration": {
+ "description": alarm.description or "",
+ "metrics": alarm.metric_data_queries
+ or create_metric_data_query_from_alarm(
+ alarm
+ ), # TODO: add test with metric_data_queries
+ },
+ },
+ "time": alarm.state_updated_timestamp,
+ "region": alarm.region_name,
+ "source": "aws.cloudwatch",
+ }
+ return json.dumps(response)
+
+
+def create_message_response_update_state_sns(alarm, old_state):
+ response = {
+ "AWSAccountId": extract_account_id_from_arn(alarm.alarm_arn),
+ "OldStateValue": old_state,
+ "AlarmName": alarm.name,
+ "AlarmDescription": alarm.description or "",
+ "AlarmConfigurationUpdatedTimestamp": alarm.configuration_updated_timestamp,
+ "NewStateValue": alarm.state_value,
+ "NewStateReason": alarm.state_reason,
+ "StateChangeTime": alarm.state_updated_timestamp,
+ # the long-name for 'region' should be used - as we don't have it, we use the short name
+ # which needs to be slightly changed to make snapshot tests work
+ "Region": alarm.region_name.replace("-", " ").capitalize(),
+ "AlarmArn": alarm.alarm_arn,
+ "OKActions": alarm.ok_actions or [],
+ "AlarmActions": alarm.alarm_actions or [],
+ "InsufficientDataActions": alarm.insufficient_data_actions or [],
+ }
+
+ # collect trigger details
+ details = {
+ "MetricName": alarm.metric_name or "",
+ "Namespace": alarm.namespace or "",
+ "Unit": alarm.unit or None, # testing with AWS revealed this currently returns None
+ "Period": int(alarm.period) if alarm.period else 0,
+ "EvaluationPeriods": int(alarm.evaluation_periods) if alarm.evaluation_periods else 0,
+ "ComparisonOperator": alarm.comparison_operator or "",
+ "Threshold": float(alarm.threshold) if alarm.threshold else 0.0,
+ "TreatMissingData": alarm.treat_missing_data or "",
+ "EvaluateLowSampleCountPercentile": alarm.evaluate_low_sample_count_percentile or "",
+ }
+
+ # Dimensions not serializable
+ dimensions = []
+ if alarm.dimensions:
+ for d in alarm.dimensions:
+ dimensions.append({"value": d.value, "name": d.name})
+
+ details["Dimensions"] = dimensions or ""
+
+ if alarm.statistic:
+ details["StatisticType"] = "Statistic"
+ details["Statistic"] = camel_to_snake_case(alarm.statistic).upper() # AWS returns uppercase
+ elif alarm.extended_statistic:
+ details["StatisticType"] = "ExtendedStatistic"
+ details["ExtendedStatistic"] = alarm.extended_statistic
+
+ response["Trigger"] = details
+
+ return json.dumps(response)
+
+
+class ValidationError(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__("ValidationError", message, 400, True)
+
+
+def _set_alarm_actions(context, alarm_names, enabled):
+ backend = cloudwatch_backends[context.account_id][context.region]
+ for name in alarm_names:
+ alarm = backend.alarms.get(name)
+ if alarm:
+ alarm.actions_enabled = enabled
+
+
+def _cleanup_describe_output(alarm):
+ if "Metrics" in alarm and len(alarm["Metrics"]) == 0:
+ alarm.pop("Metrics")
+ reason_data = alarm.get("StateReasonData")
+ if reason_data is not None and reason_data in ("{}", ""):
+ alarm.pop("StateReasonData")
+ if (
+ alarm.get("StateReason", "") == MOTO_INITIAL_UNCHECKED_REASON
+ and alarm.get("StateValue") != StateValue.INSUFFICIENT_DATA
+ ):
+ alarm["StateValue"] = StateValue.INSUFFICIENT_DATA
+
+
+class CloudwatchProvider(CloudwatchApi, ServiceLifecycleHook):
+ """
+ Cloudwatch provider.
+
+ LIMITATIONS:
+ - no alarm rule evaluation
+ """
+
+ def __init__(self):
+ self.tags = TaggingService()
+ self.alarm_scheduler = None
+
+ def on_after_init(self):
+ ROUTER.add(PATH_GET_RAW_METRICS, self.get_raw_metrics)
+ self.start_alarm_scheduler()
+
+ def on_before_state_reset(self):
+ self.shutdown_alarm_scheduler()
+
+ def on_after_state_reset(self):
+ self.start_alarm_scheduler()
+
+ def on_before_state_load(self):
+ self.shutdown_alarm_scheduler()
+
+ def on_after_state_load(self):
+ self.start_alarm_scheduler()
+
+ def restart_alarms(*args):
+ poll_condition(lambda: SERVICE_PLUGINS.is_running("cloudwatch"))
+ self.alarm_scheduler.restart_existing_alarms()
+
+ start_worker_thread(restart_alarms)
+
+ def on_before_stop(self):
+ self.shutdown_alarm_scheduler()
+
+ def start_alarm_scheduler(self):
+ if not self.alarm_scheduler:
+ LOG.debug("starting cloudwatch scheduler")
+ self.alarm_scheduler = AlarmScheduler()
+
+ def shutdown_alarm_scheduler(self):
+ LOG.debug("stopping cloudwatch scheduler")
+ self.alarm_scheduler.shutdown_scheduler()
+ self.alarm_scheduler = None
+
+ def delete_alarms(self, context: RequestContext, alarm_names: AlarmNames, **kwargs) -> None:
+ moto.call_moto(context)
+ for alarm_name in alarm_names:
+ arn = arns.cloudwatch_alarm_arn(alarm_name, context.account_id, context.region)
+ self.alarm_scheduler.delete_scheduler_for_alarm(arn)
+
+ def get_raw_metrics(self, request: Request):
+ region = extract_region_from_auth_header(request.headers)
+ account_id = (
+ get_account_id_from_access_key_id(
+ extract_access_key_id_from_auth_header(request.headers)
+ )
+ or DEFAULT_AWS_ACCOUNT_ID
+ )
+ backend = cloudwatch_backends[account_id][region]
+ if backend:
+ result = [m for m in backend.metric_data if isinstance(m, MetricDatum)]
+ # TODO handle aggregated metrics as well (MetricAggregatedDatum)
+ else:
+ result = []
+
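+ # keys are abbreviated: ns=namespace, n=name, v=value, t=timestamp, d=dimensions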
+ result = [
+ {
+ "ns": r.namespace,
+ "n": r.name,
+ "v": r.value,
+ "t": r.timestamp,
+ "d": [{"n": d.name, "v": d.value} for d in r.dimensions],
+ "account": account_id,
+ "region": region,
+ }
+ for r in result
+ ]
+ return {"metrics": result}
+
+ def list_tags_for_resource(
+ self, context: RequestContext, resource_arn: AmazonResourceName, **kwargs
+ ) -> ListTagsForResourceOutput:
+ tags = self.tags.list_tags_for_resource(resource_arn)
+ return ListTagsForResourceOutput(Tags=tags.get("Tags", []))
+
+ def untag_resource(
+ self,
+ context: RequestContext,
+ resource_arn: AmazonResourceName,
+ tag_keys: TagKeyList,
+ **kwargs,
+ ) -> UntagResourceOutput:
+ self.tags.untag_resource(resource_arn, tag_keys)
+ return UntagResourceOutput()
+
+ def tag_resource(
+ self, context: RequestContext, resource_arn: AmazonResourceName, tags: TagList, **kwargs
+ ) -> TagResourceOutput:
+ self.tags.tag_resource(resource_arn, tags)
+ return TagResourceOutput()
+
+ @handler("GetMetricData", expand=False)
+ def get_metric_data(
+ self, context: RequestContext, request: GetMetricDataInput
+ ) -> GetMetricDataOutput:
+ result = moto.call_moto(context)
+ # moto currently uses a hardcoded label of metric_name + stat
+ # parity tests show that the default label is derived from the MetricStat, but a label may also be set explicitly
+ metric_data_queries = request["MetricDataQueries"]
+ for i in range(0, len(metric_data_queries)):
+ metric_query = metric_data_queries[i]
+ label = metric_query.get("Label") or metric_query.get("MetricStat", {}).get(
+ "Metric", {}
+ ).get("MetricName")
+ if label:
+ result["MetricDataResults"][i]["Label"] = label
+ if "Messages" not in result:
+ # parity tests reveal that an empty Messages list is added
+ result["Messages"] = []
+ return result
+
+ @handler("PutMetricAlarm", expand=False)
+ def put_metric_alarm(
+ self,
+ context: RequestContext,
+ request: PutMetricAlarmInput,
+ ) -> None:
+ # 'missing' is the default when TreatMissingData is not set (but it is not explicitly set on the alarm)
+ if request.get("TreatMissingData", "missing") not in [
+ "breaching",
+ "notBreaching",
+ "ignore",
+ "missing",
+ ]:
+ raise ValidationError(
+ f"The value {request['TreatMissingData']} is not supported for TreatMissingData parameter. Supported values are [breaching, notBreaching, ignore, missing]."
+ )
+ # do some sanity checks:
+ if request.get("Period"):
+ # Valid values are 10, 30, and any multiple of 60.
+ value = request.get("Period")
+ if value not in (10, 30):
+ if value % 60 != 0:
+ raise ValidationError("Period must be 10, 30 or a multiple of 60")
+ if request.get("Statistic"):
+ if request.get("Statistic") not in [
+ "SampleCount",
+ "Average",
+ "Sum",
+ "Minimum",
+ "Maximum",
+ ]:
+ raise ValidationError(
+ f"Value '{request.get('Statistic')}' at 'statistic' failed to satisfy constraint: Member must satisfy enum value set: [Maximum, SampleCount, Sum, Minimum, Average]"
+ )
+
+ moto.call_moto(context)
+
+ name = request.get("AlarmName")
+ arn = arns.cloudwatch_alarm_arn(name, context.account_id, context.region)
+ self.tags.tag_resource(arn, request.get("Tags"))
+ self.alarm_scheduler.schedule_metric_alarm(arn)
+
+ @handler("PutCompositeAlarm", expand=False)
+ def put_composite_alarm(
+ self,
+ context: RequestContext,
+ request: PutCompositeAlarmInput,
+ ) -> None:
+ backend = cloudwatch_backends[context.account_id][context.region]
+ backend.put_metric_alarm(
+ name=request.get("AlarmName"),
+ namespace=None,
+ metric_name=None,
+ metric_data_queries=None,
+ comparison_operator=None,
+ evaluation_periods=None,
+ datapoints_to_alarm=None,
+ period=None,
+ threshold=None,
+ statistic=None,
+ extended_statistic=None,
+ description=request.get("AlarmDescription"),
+ dimensions=[],
+ alarm_actions=request.get("AlarmActions", []),
+ ok_actions=request.get("OKActions", []),
+ insufficient_data_actions=request.get("InsufficientDataActions", []),
+ unit=None,
+ actions_enabled=request.get("ActionsEnabled"),
+ treat_missing_data=None,
+ evaluate_low_sample_count_percentile=None,
+ threshold_metric_id=None,
+ rule=request.get("AlarmRule"),
+ tags=request.get("Tags", []),
+ )
+ LOG.warning(
+ "Composite Alarms configuration is not yet supported, alarm state will not be evaluated"
+ )
+
+ @handler("EnableAlarmActions")
+ def enable_alarm_actions(
+ self, context: RequestContext, alarm_names: AlarmNames, **kwargs
+ ) -> None:
+ _set_alarm_actions(context, alarm_names, enabled=True)
+
+ @handler("DisableAlarmActions")
+ def disable_alarm_actions(
+ self, context: RequestContext, alarm_names: AlarmNames, **kwargs
+ ) -> None:
+ _set_alarm_actions(context, alarm_names, enabled=False)
+
+ @handler("DescribeAlarms", expand=False)
+ def describe_alarms(
+ self, context: RequestContext, request: DescribeAlarmsInput
+ ) -> DescribeAlarmsOutput:
+ response = moto.call_moto(context)
+
+ for c in response["CompositeAlarms"]:
+ _cleanup_describe_output(c)
+ for m in response["MetricAlarms"]:
+ _cleanup_describe_output(m)
+
+ return response
+
+ @handler("GetMetricStatistics", expand=False)
+ def get_metric_statistics(
+ self, context: RequestContext, request: GetMetricStatisticsInput
+ ) -> GetMetricStatisticsOutput:
+ response = moto.call_moto(context)
+
+ # cleanup: ExtendedStatistics is not included in the AWS response if it is empty
+ for datapoint in response.get("Datapoints"):
+ if "ExtendedStatistics" in datapoint and not datapoint.get("ExtendedStatistics"):
+ datapoint.pop("ExtendedStatistics")
+
+ return response
diff --git a/localstack-core/localstack/services/cloudwatch/provider_v2.py b/localstack-core/localstack/services/cloudwatch/provider_v2.py
new file mode 100644
index 0000000000000..d2239691d826d
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/provider_v2.py
@@ -0,0 +1,1109 @@
+import datetime
+import json
+import logging
+import re
+import threading
+import uuid
+from datetime import timezone
+from typing import List
+
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.cloudwatch import (
+ AccountId,
+ ActionPrefix,
+ AlarmName,
+ AlarmNamePrefix,
+ AlarmNames,
+ AlarmTypes,
+ AmazonResourceName,
+ CloudwatchApi,
+ DashboardBody,
+ DashboardName,
+ DashboardNamePrefix,
+ DashboardNames,
+ Datapoint,
+ DeleteDashboardsOutput,
+ DescribeAlarmHistoryOutput,
+ DescribeAlarmsForMetricOutput,
+ DescribeAlarmsOutput,
+ DimensionFilters,
+ Dimensions,
+ EntityMetricDataList,
+ ExtendedStatistic,
+ ExtendedStatistics,
+ GetDashboardOutput,
+ GetMetricDataMaxDatapoints,
+ GetMetricDataOutput,
+ GetMetricStatisticsOutput,
+ HistoryItemType,
+ IncludeLinkedAccounts,
+ InvalidParameterCombinationException,
+ InvalidParameterValueException,
+ LabelOptions,
+ ListDashboardsOutput,
+ ListMetricsOutput,
+ ListTagsForResourceOutput,
+ MaxRecords,
+ MetricData,
+ MetricDataQueries,
+ MetricDataQuery,
+ MetricDataResult,
+ MetricDataResultMessages,
+ MetricName,
+ MetricStat,
+ Namespace,
+ NextToken,
+ Period,
+ PutCompositeAlarmInput,
+ PutDashboardOutput,
+ PutMetricAlarmInput,
+ RecentlyActive,
+ ResourceNotFound,
+ ScanBy,
+ StandardUnit,
+ StateReason,
+ StateReasonData,
+ StateValue,
+ Statistic,
+ Statistics,
+ StrictEntityValidation,
+ TagKeyList,
+ TagList,
+ TagResourceOutput,
+ Timestamp,
+ UntagResourceOutput,
+)
+from localstack.aws.connect import connect_to
+from localstack.http import Request
+from localstack.services.cloudwatch.alarm_scheduler import AlarmScheduler
+from localstack.services.cloudwatch.cloudwatch_database_helper import CloudwatchDatabase
+from localstack.services.cloudwatch.models import (
+ CloudWatchStore,
+ LocalStackAlarm,
+ LocalStackCompositeAlarm,
+ LocalStackDashboard,
+ LocalStackMetricAlarm,
+ cloudwatch_stores,
+)
+from localstack.services.edge import ROUTER
+from localstack.services.plugins import SERVICE_PLUGINS, ServiceLifecycleHook
+from localstack.state import AssetDirectory, StateVisitor
+from localstack.utils.aws import arns
+from localstack.utils.aws.arns import extract_account_id_from_arn, lambda_function_name
+from localstack.utils.collections import PaginatedList
+from localstack.utils.json import CustomEncoder as JSONEncoder
+from localstack.utils.strings import camel_to_snake_case
+from localstack.utils.sync import poll_condition
+from localstack.utils.threads import start_worker_thread
+from localstack.utils.time import timestamp_millis
+
+PATH_GET_RAW_METRICS = "/_aws/cloudwatch/metrics/raw"
+MOTO_INITIAL_UNCHECKED_REASON = "Unchecked: Initial alarm creation"
+LIST_METRICS_MAX_RESULTS = 500
+ # If the values of these fields differ between the queries, they are appended when generating labels
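+ # e.g. two queries on "CPUUtilization" (illustrative) with Stat "Sum" vs "Average" get the labels "CPUUtilization Sum" and "CPUUtilization Average"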
+LABEL_DIFFERENTIATORS = ["Stat", "Period"]
+HISTORY_VERSION = "1.0"
+
+LOG = logging.getLogger(__name__)
+_STORE_LOCK = threading.RLock()
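+ # GetMetricStatistics rejects requests that would generate more than 1,440 datapoints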
+AWS_MAX_DATAPOINTS_ACCEPTED: int = 1440
+
+
+class ValidationError(CommonServiceException):
+ # TODO: check this error against AWS (doesn't exist in the API)
+ def __init__(self, message: str):
+ super().__init__("ValidationError", message, 400, True)
+
+
+class InvalidParameterCombination(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__("InvalidParameterCombination", message, 400, True)
+
+
+def _validate_parameters_for_put_metric_data(metric_data: MetricData) -> None:
+ for index, metric_item in enumerate(metric_data):
+ indexplusone = index + 1
+ if metric_item.get("Value") and metric_item.get("Values"):
+ raise InvalidParameterCombinationException(
+ f"The parameters MetricData.member.{indexplusone}.Value and MetricData.member.{indexplusone}.Values are mutually exclusive and you have specified both."
+ )
+
+ if metric_item.get("StatisticValues") and metric_item.get("Value"):
+ raise InvalidParameterCombinationException(
+ f"The parameters MetricData.member.{indexplusone}.Value and MetricData.member.{indexplusone}.StatisticValues are mutually exclusive and you have specified both."
+ )
+
+ if metric_item.get("Values") and metric_item.get("Counts"):
+ values = metric_item.get("Values")
+ counts = metric_item.get("Counts")
+ if len(values) != len(counts):
+ raise InvalidParameterValueException(
+ f"The parameters MetricData.member.{indexplusone}.Values and MetricData.member.{indexplusone}.Counts must be of the same size."
+ )
+
+
+class CloudwatchProvider(CloudwatchApi, ServiceLifecycleHook):
+ """
+ Cloudwatch provider.
+
+ LIMITATIONS:
+ - simplified composite alarm rule evaluation:
+ - only OR operator is supported
+ - only ALARM expression is supported
+ - only metric alarms can be included in the rule and they should be referenced by ARN only
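+
+ Example of a supported rule (ARNs are illustrative):
+ ALARM("arn:aws:cloudwatch:us-east-1:000000000000:alarm:cpu-high") OR ALARM("arn:aws:cloudwatch:us-east-1:000000000000:alarm:errors-high")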
+ """
+
+ def __init__(self):
+ self.alarm_scheduler: AlarmScheduler = None
+ self.store = None
+ self.cloudwatch_database = CloudwatchDatabase()
+
+ @staticmethod
+ def get_store(account_id: str, region: str) -> CloudWatchStore:
+ return cloudwatch_stores[account_id][region]
+
+ def accept_state_visitor(self, visitor: StateVisitor):
+ visitor.visit(cloudwatch_stores)
+ visitor.visit(AssetDirectory(self.service, CloudwatchDatabase.CLOUDWATCH_DATA_ROOT))
+
+ def on_after_init(self):
+ ROUTER.add(PATH_GET_RAW_METRICS, self.get_raw_metrics)
+ self.start_alarm_scheduler()
+
+ def on_before_state_reset(self):
+ self.shutdown_alarm_scheduler()
+ self.cloudwatch_database.clear_tables()
+
+ def on_after_state_reset(self):
+ self.start_alarm_scheduler()
+
+ def on_before_state_load(self):
+ self.shutdown_alarm_scheduler()
+
+ def on_after_state_load(self):
+ self.start_alarm_scheduler()
+
+ def restart_alarms(*args):
+ poll_condition(lambda: SERVICE_PLUGINS.is_running("cloudwatch"))
+ self.alarm_scheduler.restart_existing_alarms()
+
+ start_worker_thread(restart_alarms)
+
+ def on_before_stop(self):
+ self.shutdown_alarm_scheduler()
+
+ def start_alarm_scheduler(self):
+ if not self.alarm_scheduler:
+ LOG.debug("starting cloudwatch scheduler")
+ self.alarm_scheduler = AlarmScheduler()
+
+ def shutdown_alarm_scheduler(self):
+ LOG.debug("stopping cloudwatch scheduler")
+ self.alarm_scheduler.shutdown_scheduler()
+ self.alarm_scheduler = None
+
+ def delete_alarms(self, context: RequestContext, alarm_names: AlarmNames, **kwargs) -> None:
+ """
+ Delete alarms.
+ """
+ with _STORE_LOCK:
+ for alarm_name in alarm_names:
+ alarm_arn = arns.cloudwatch_alarm_arn(
+ alarm_name, account_id=context.account_id, region_name=context.region
+ ) # obtain alarm ARN from alarm name
+ self.alarm_scheduler.delete_scheduler_for_alarm(alarm_arn)
+ store = self.get_store(context.account_id, context.region)
+ store.alarms.pop(alarm_arn, None)
+
+ def put_metric_data(
+ self,
+ context: RequestContext,
+ namespace: Namespace,
+ metric_data: MetricData = None,
+ entity_metric_data: EntityMetricDataList = None,
+ strict_entity_validation: StrictEntityValidation = None,
+ **kwargs,
+ ) -> None:
+ # TODO add support for entity_metric_data and strict_entity_validation
+ _validate_parameters_for_put_metric_data(metric_data)
+
+ self.cloudwatch_database.add_metric_data(
+ context.account_id, context.region, namespace, metric_data
+ )
+
+ def get_metric_data(
+ self,
+ context: RequestContext,
+ metric_data_queries: MetricDataQueries,
+ start_time: Timestamp,
+ end_time: Timestamp,
+ next_token: NextToken = None,
+ scan_by: ScanBy = None,
+ max_datapoints: GetMetricDataMaxDatapoints = None,
+ label_options: LabelOptions = None,
+ **kwargs,
+ ) -> GetMetricDataOutput:
+ results: List[MetricDataResult] = []
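+ # 100,800 is the default MaxDatapoints AWS uses for GetMetricData when none is given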
+ limit = max_datapoints or 100_800
+ messages: MetricDataResultMessages = []
+ nxt = None
+ label_additions = []
+
+ for diff in LABEL_DIFFERENTIATORS:
+ non_unique = []
+ for query in metric_data_queries:
+ non_unique.append(query["MetricStat"][diff])
+ if len(set(non_unique)) > 1:
+ label_additions.append(diff)
+
+ for query in metric_data_queries:
+ query_result = self.cloudwatch_database.get_metric_data_stat(
+ account_id=context.account_id,
+ region=context.region,
+ query=query,
+ start_time=start_time,
+ end_time=end_time,
+ scan_by=scan_by,
+ )
+ if query_result.get("messages"):
+ messages.extend(query_result.get("messages"))
+
+ label = query.get("Label") or f"{query['MetricStat']['Metric']['MetricName']}"
+ # TODO: does this happen even if a label is set in the query?
+ for label_addition in label_additions:
+ label = f"{label} {query['MetricStat'][label_addition]}"
+
+ timestamps = query_result.get("timestamps", {})
+ values = query_result.get("values", {})
+
+ # Paginate
+ timestamp_value_dicts = [
+ {
+ "Timestamp": timestamp,
+ "Value": value,
+ }
+ for timestamp, value in zip(timestamps, values)
+ ]
+
+ pagination = PaginatedList(timestamp_value_dicts)
+ timestamp_page, nxt = pagination.get_page(
+ lambda item: item.get("Timestamp"),
+ next_token=next_token,
+ page_size=limit,
+ )
+
+ timestamps = [item.get("Timestamp") for item in timestamp_page]
+ values = [item.get("Value") for item in timestamp_page]
+
+ metric_data_result = {
+ "Id": query.get("Id"),
+ "Label": label,
+ "StatusCode": "Complete",
+ "Timestamps": timestamps,
+ "Values": values,
+ }
+ results.append(MetricDataResult(**metric_data_result))
+
+ return GetMetricDataOutput(MetricDataResults=results, NextToken=nxt, Messages=messages)
+
+ def set_alarm_state(
+ self,
+ context: RequestContext,
+ alarm_name: AlarmName,
+ state_value: StateValue,
+ state_reason: StateReason,
+ state_reason_data: StateReasonData = None,
+ **kwargs,
+ ) -> None:
+ try:
+ if state_reason_data:
+ state_reason_data = json.loads(state_reason_data)
+ except ValueError:
+ raise InvalidParameterValueException(
+ "TODO: check right error message: Json was not correctly formatted"
+ )
+ with _STORE_LOCK:
+ store = self.get_store(context.account_id, context.region)
+ alarm = store.alarms.get(
+ arns.cloudwatch_alarm_arn(
+ alarm_name, account_id=context.account_id, region_name=context.region
+ )
+ )
+ if not alarm:
+ raise ResourceNotFound()
+
+ old_state = alarm.alarm["StateValue"]
+ if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"):
+ raise ValidationError(
+ f"1 validation error detected: Value '{state_value}' at 'stateValue' failed to satisfy constraint: Member must satisfy enum value set: [INSUFFICIENT_DATA, ALARM, OK]"
+ )
+
+ old_state_reason = alarm.alarm["StateReason"]
+ old_state_update_timestamp = alarm.alarm["StateUpdatedTimestamp"]
+
+ if old_state == state_value:
+ return
+
+ alarm.alarm["StateTransitionedTimestamp"] = datetime.datetime.now(timezone.utc)
+ # update startDate (= last ALARM date) - it should only be updated when a new alarm is triggered
+ # the date is only updated if reason-data is present, which is set by an alarm evaluation
+ if state_reason_data:
+ state_reason_data["startDate"] = state_reason_data.get("queryDate")
+
+ self._update_state(
+ context,
+ alarm,
+ state_value,
+ state_reason,
+ state_reason_data,
+ )
+
+ self._evaluate_composite_alarms(context, alarm)
+
+ if not alarm.alarm["ActionsEnabled"]:
+ return
+ if state_value == "OK":
+ actions = alarm.alarm["OKActions"]
+ elif state_value == "ALARM":
+ actions = alarm.alarm["AlarmActions"]
+ else:
+ actions = alarm.alarm["InsufficientDataActions"]
+ for action in actions:
+ data = arns.parse_arn(action)
+ # test for sns - can this be done in a more generic way?
+ if data["service"] == "sns":
+ service = connect_to(
+ region_name=data["region"], aws_access_key_id=data["account"]
+ ).sns
+ subject = f"""{state_value}: "{alarm_name}" in {context.region}"""
+ message = create_message_response_update_state_sns(alarm, old_state)
+ service.publish(TopicArn=action, Subject=subject, Message=message)
+ elif data["service"] == "lambda":
+ service = connect_to(
+ region_name=data["region"], aws_access_key_id=data["account"]
+ ).lambda_
+ message = create_message_response_update_state_lambda(
+ alarm, old_state, old_state_reason, old_state_update_timestamp
+ )
+ service.invoke(FunctionName=lambda_function_name(action), Payload=message)
+ else:
+ # TODO: support other actions
+ LOG.warning(
+ "Action for service %s not implemented, action '%s' will not be triggered.",
+ data["service"],
+ action,
+ )
+
+ def get_raw_metrics(self, request: Request):
+ """this feature was introduced with https://github.com/localstack/localstack/pull/3535
+ # in the meantime, it required a valid aws-header so that the account-id/region could be extracted
+ # with the new implementation, we want to return all data, but add the account-id/region as additional attributes
+
+ # TODO endpoint should be refactored or deprecated at some point
+ # - result should be paginated
+ # - include aggregated metrics (but we would also need to change/adapt the shape of "metrics" that we return)
+ :returns: json {"metrics": [{"ns": "namespace", "n": "metric_name", "v": value, "t": timestamp,
+ "d": [],"account": account, "region": region}]}
+ """
+ return {"metrics": self.cloudwatch_database.get_all_metric_data() or []}
+
+ @handler("PutMetricAlarm", expand=False)
+ def put_metric_alarm(self, context: RequestContext, request: PutMetricAlarmInput) -> None:
+ # 'missing' is the default when TreatMissingData is not set (but it is not explicitly set on the alarm)
+ if request.get("TreatMissingData", "missing") not in [
+ "breaching",
+ "notBreaching",
+ "ignore",
+ "missing",
+ ]:
+ raise ValidationError(
+ f"The value {request['TreatMissingData']} is not supported for TreatMissingData parameter. Supported values are [breaching, notBreaching, ignore, missing]."
+ )
+ # do some sanity checks:
+ if request.get("Period"):
+ # Valid values are 10, 30, and any multiple of 60.
+ value = request.get("Period")
+ if value not in (10, 30):
+ if value % 60 != 0:
+ raise ValidationError("Period must be 10, 30 or a multiple of 60")
+ if request.get("Statistic"):
+ if request.get("Statistic") not in [
+ "SampleCount",
+ "Average",
+ "Sum",
+ "Minimum",
+ "Maximum",
+ ]:
+ raise ValidationError(
+ f"Value '{request.get('Statistic')}' at 'statistic' failed to satisfy constraint: Member must satisfy enum value set: [Maximum, SampleCount, Sum, Minimum, Average]"
+ )
+
+ extended_statistic = request.get("ExtendedStatistic")
+ if extended_statistic and not extended_statistic.startswith("p"):
+ raise InvalidParameterValueException(
+ f"The value {extended_statistic} for parameter ExtendedStatistic is not supported."
+ )
+ evaluate_low_sample_count_percentile = request.get("EvaluateLowSampleCountPercentile")
+ if evaluate_low_sample_count_percentile and evaluate_low_sample_count_percentile not in (
+ "evaluate",
+ "ignore",
+ ):
+ raise ValidationError(
+ f"Option {evaluate_low_sample_count_percentile} is not supported. "
+ "Supported options for parameter EvaluateLowSampleCountPercentile are evaluate and ignore."
+ )
+ with _STORE_LOCK:
+ store = self.get_store(context.account_id, context.region)
+ metric_alarm = LocalStackMetricAlarm(context.account_id, context.region, {**request})
+ alarm_arn = metric_alarm.alarm["AlarmArn"]
+ store.alarms[alarm_arn] = metric_alarm
+ self.alarm_scheduler.schedule_metric_alarm(alarm_arn)
+
+ @handler("PutCompositeAlarm", expand=False)
+ def put_composite_alarm(self, context: RequestContext, request: PutCompositeAlarmInput) -> None:
+ with _STORE_LOCK:
+ store = self.get_store(context.account_id, context.region)
+ composite_alarm = LocalStackCompositeAlarm(
+ context.account_id, context.region, {**request}
+ )
+
+ alarm_rule = composite_alarm.alarm["AlarmRule"]
+ rule_expression_validation_result = self._validate_alarm_rule_expression(alarm_rule)
+ for validation_warning in rule_expression_validation_result:
+ LOG.warning(validation_warning)
+
+ alarm_arn = composite_alarm.alarm["AlarmArn"]
+ store.alarms[alarm_arn] = composite_alarm
+
+ def describe_alarms(
+ self,
+ context: RequestContext,
+ alarm_names: AlarmNames = None,
+ alarm_name_prefix: AlarmNamePrefix = None,
+ alarm_types: AlarmTypes = None,
+ children_of_alarm_name: AlarmName = None,
+ parents_of_alarm_name: AlarmName = None,
+ state_value: StateValue = None,
+ action_prefix: ActionPrefix = None,
+ max_records: MaxRecords = None,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> DescribeAlarmsOutput:
+ store = self.get_store(context.account_id, context.region)
+ alarms = list(store.alarms.values())
+ if action_prefix:
+ alarms = [
+ a.alarm
+ for a in alarms
+ if any(action.startswith(action_prefix) for action in a.alarm.get("AlarmActions", []))
+ ]
+ elif alarm_name_prefix:
+ alarms = [a.alarm for a in alarms if a.alarm["AlarmName"].startswith(alarm_name_prefix)]
+ elif alarm_names:
+ alarms = [a.alarm for a in alarms if a.alarm["AlarmName"] in alarm_names]
+ elif state_value:
+ alarms = [a.alarm for a in alarms if a.alarm["StateValue"] == state_value]
+ else:
+ alarms = [a.alarm for a in list(store.alarms.values())]
+
+ # TODO: Pagination
+ metric_alarms = [a for a in alarms if a.get("AlarmRule") is None]
+ composite_alarms = [a for a in alarms if a.get("AlarmRule") is not None]
+ return DescribeAlarmsOutput(CompositeAlarms=composite_alarms, MetricAlarms=metric_alarms)
+
+ def describe_alarms_for_metric(
+ self,
+ context: RequestContext,
+ metric_name: MetricName,
+ namespace: Namespace,
+ statistic: Statistic = None,
+ extended_statistic: ExtendedStatistic = None,
+ dimensions: Dimensions = None,
+ period: Period = None,
+ unit: StandardUnit = None,
+ **kwargs,
+ ) -> DescribeAlarmsForMetricOutput:
+ store = self.get_store(context.account_id, context.region)
+ alarms = [
+ a.alarm
+ for a in store.alarms.values()
+ if isinstance(a, LocalStackMetricAlarm)
+ and a.alarm.get("MetricName") == metric_name
+ and a.alarm.get("Namespace") == namespace
+ ]
+
+ if statistic:
+ alarms = [a for a in alarms if a.get("Statistic") == statistic]
+ if dimensions:
+ alarms = [a for a in alarms if a.get("Dimensions") == dimensions]
+ if period:
+ alarms = [a for a in alarms if a.get("Period") == period]
+ if unit:
+ alarms = [a for a in alarms if a.get("Unit") == unit]
+ return DescribeAlarmsForMetricOutput(MetricAlarms=alarms)
+
+ def list_tags_for_resource(
+ self, context: RequestContext, resource_arn: AmazonResourceName, **kwargs
+ ) -> ListTagsForResourceOutput:
+ store = self.get_store(context.account_id, context.region)
+ tags = store.TAGS.list_tags_for_resource(resource_arn)
+ return ListTagsForResourceOutput(Tags=tags.get("Tags", []))
+
+ def untag_resource(
+ self,
+ context: RequestContext,
+ resource_arn: AmazonResourceName,
+ tag_keys: TagKeyList,
+ **kwargs,
+ ) -> UntagResourceOutput:
+ store = self.get_store(context.account_id, context.region)
+ store.TAGS.untag_resource(resource_arn, tag_keys)
+ return UntagResourceOutput()
+
+ def tag_resource(
+ self, context: RequestContext, resource_arn: AmazonResourceName, tags: TagList, **kwargs
+ ) -> TagResourceOutput:
+ store = self.get_store(context.account_id, context.region)
+ store.TAGS.tag_resource(resource_arn, tags)
+ return TagResourceOutput()
+
+ def put_dashboard(
+ self,
+ context: RequestContext,
+ dashboard_name: DashboardName,
+ dashboard_body: DashboardBody,
+ **kwargs,
+ ) -> PutDashboardOutput:
+ pattern = r"^[a-zA-Z0-9_-]+$"
+ if not re.match(pattern, dashboard_name):
+ raise InvalidParameterValueException(
+ "The value for field DashboardName contains invalid characters. "
+ "It can only contain alphanumerics, dash (-) and underscore (_).\n"
+ )
+
+ store = self.get_store(context.account_id, context.region)
+ store.dashboards[dashboard_name] = LocalStackDashboard(
+ context.account_id, context.region, dashboard_name, dashboard_body
+ )
+ return PutDashboardOutput()
+
+ def get_dashboard(
+ self, context: RequestContext, dashboard_name: DashboardName, **kwargs
+ ) -> GetDashboardOutput:
+ store = self.get_store(context.account_id, context.region)
+ dashboard = store.dashboards.get(dashboard_name)
+ if not dashboard:
+ raise InvalidParameterValueException(f"Dashboard {dashboard_name} does not exist.")
+
+ return GetDashboardOutput(
+ DashboardName=dashboard_name,
+ DashboardBody=dashboard.dashboard_body,
+ DashboardArn=dashboard.dashboard_arn,
+ )
+
+ def delete_dashboards(
+ self, context: RequestContext, dashboard_names: DashboardNames, **kwargs
+ ) -> DeleteDashboardsOutput:
+ store = self.get_store(context.account_id, context.region)
+ for dashboard_name in dashboard_names:
+ store.dashboards.pop(dashboard_name, None)
+ return DeleteDashboardsOutput()
+
+ def list_dashboards(
+ self,
+ context: RequestContext,
+ dashboard_name_prefix: DashboardNamePrefix = None,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> ListDashboardsOutput:
+ store = self.get_store(context.account_id, context.region)
+ dashboard_names = list(store.dashboards.keys())
+ dashboard_names = [
+ name for name in dashboard_names if name.startswith(dashboard_name_prefix or "")
+ ]
+
+ entries = [
+ {
+ "DashboardName": name,
+ "DashboardArn": store.dashboards[name].dashboard_arn,
+ "LastModified": store.dashboards[name].last_modified,
+ "Size": store.dashboards[name].size,
+ }
+ for name in dashboard_names
+ ]
+ return ListDashboardsOutput(
+ DashboardEntries=entries,
+ )
+
+ def list_metrics(
+ self,
+ context: RequestContext,
+ namespace: Namespace = None,
+ metric_name: MetricName = None,
+ dimensions: DimensionFilters = None,
+ next_token: NextToken = None,
+ recently_active: RecentlyActive = None,
+ include_linked_accounts: IncludeLinkedAccounts = None,
+ owning_account: AccountId = None,
+ **kwargs,
+ ) -> ListMetricsOutput:
+ result = self.cloudwatch_database.list_metrics(
+ context.account_id,
+ context.region,
+ namespace,
+ metric_name,
+ dimensions or [],
+ )
+
+ metrics = [
+ {
+ "Namespace": metric.get("namespace"),
+ "MetricName": metric.get("metric_name"),
+ "Dimensions": metric.get("dimensions"),
+ }
+ for metric in result.get("metrics", [])
+ ]
+ aliases_list = PaginatedList(metrics)
+ page, nxt = aliases_list.get_page(
+ lambda metric: f"{metric.get('Namespace')}-{metric.get('MetricName')}-{metric.get('Dimensions')}",
+ next_token=next_token,
+ page_size=LIST_METRICS_MAX_RESULTS,
+ )
+ return ListMetricsOutput(Metrics=page, NextToken=nxt)
+
+ def get_metric_statistics(
+ self,
+ context: RequestContext,
+ namespace: Namespace,
+ metric_name: MetricName,
+ start_time: Timestamp,
+ end_time: Timestamp,
+ period: Period,
+ dimensions: Dimensions = None,
+ statistics: Statistics = None,
+ extended_statistics: ExtendedStatistics = None,
+ unit: StandardUnit = None,
+ **kwargs,
+ ) -> GetMetricStatisticsOutput:
+ start_time_unix = int(start_time.timestamp())
+ end_time_unix = int(end_time.timestamp())
+
+ if not start_time_unix < end_time_unix:
+ raise InvalidParameterValueException(
+ "The parameter StartTime must be less than the parameter EndTime."
+ )
+
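+ # e.g. a 1 hour range (3600s) with Period=60 requests 60 datapoints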
+ expected_datapoints = (end_time_unix - start_time_unix) / period
+
+ if expected_datapoints > AWS_MAX_DATAPOINTS_ACCEPTED:
+ raise InvalidParameterCombination(
+ f"You have requested up to {int(expected_datapoints)} datapoints, which exceeds the limit of {AWS_MAX_DATAPOINTS_ACCEPTED}. "
+ f"You may reduce the datapoints requested by increasing Period, or decreasing the time range."
+ )
+
+ stat_datapoints = {}
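+ # maps unit -> timestamp -> {stat: value, "Unit": unit}, so each Datapoint can carry all requested statistics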
+
+ units = (
+ [unit]
+ if unit
+ else self.cloudwatch_database.get_units_for_metric_data_stat(
+ account_id=context.account_id,
+ region=context.region,
+ start_time=start_time,
+ end_time=end_time,
+ metric_name=metric_name,
+ namespace=namespace,
+ )
+ )
+
+ for stat in statistics:
+ for selected_unit in units:
+ query_result = self.cloudwatch_database.get_metric_data_stat(
+ account_id=context.account_id,
+ region=context.region,
+ start_time=start_time,
+ end_time=end_time,
+ scan_by="TimestampDescending",
+ query=MetricDataQuery(
+ MetricStat=MetricStat(
+ Metric={
+ "MetricName": metric_name,
+ "Namespace": namespace,
+ "Dimensions": dimensions or [],
+ },
+ Period=period,
+ Stat=stat,
+ Unit=selected_unit,
+ )
+ ),
+ )
+
+ timestamps = query_result.get("timestamps", [])
+ values = query_result.get("values", [])
+ for i, timestamp in enumerate(timestamps):
+ stat_datapoints.setdefault(selected_unit, {})
+ stat_datapoints[selected_unit].setdefault(timestamp, {})
+ stat_datapoints[selected_unit][timestamp][stat] = values[i]
+ stat_datapoints[selected_unit][timestamp]["Unit"] = selected_unit
+
+ datapoints: list[Datapoint] = []
+ for selected_unit, results in stat_datapoints.items():
+ for timestamp, stats in results.items():
+ datapoints.append(
+ Datapoint(
+ Timestamp=timestamp,
+ SampleCount=stats.get("SampleCount"),
+ Average=stats.get("Average"),
+ Sum=stats.get("Sum"),
+ Minimum=stats.get("Minimum"),
+ Maximum=stats.get("Maximum"),
+ Unit="None" if selected_unit == "NULL_VALUE" else selected_unit,
+ )
+ )
+
+ return GetMetricStatisticsOutput(Datapoints=datapoints, Label=metric_name)
+
+ def _update_state(
+ self,
+ context: RequestContext,
+ alarm: LocalStackAlarm,
+ state_value: str,
+ state_reason: str,
+ state_reason_data: dict = None,
+ ):
+ old_state = alarm.alarm["StateValue"]
+ old_state_reason = alarm.alarm["StateReason"]
+ store = self.get_store(context.account_id, context.region)
+ current_time = datetime.datetime.now()
+ # version is not present in state reason data for composite alarm, hence the check
+ if state_reason_data and isinstance(alarm, LocalStackMetricAlarm):
+ state_reason_data["version"] = HISTORY_VERSION
+ history_data = {
+ "version": HISTORY_VERSION,
+ "oldState": {"stateValue": old_state, "stateReason": old_state_reason},
+ "newState": {
+ "stateValue": state_value,
+ "stateReason": state_reason,
+ "stateReasonData": state_reason_data,
+ },
+ }
+ store.histories.append(
+ {
+ "Timestamp": timestamp_millis(alarm.alarm["StateUpdatedTimestamp"]),
+ "HistoryItemType": HistoryItemType.StateUpdate,
+ "AlarmName": alarm.alarm["AlarmName"],
+ "HistoryData": json.dumps(history_data),
+ "HistorySummary": f"Alarm updated from {old_state} to {state_value}",
+ "AlarmType": "MetricAlarm"
+ if isinstance(alarm, LocalStackMetricAlarm)
+ else "CompositeAlarm",
+ }
+ )
+ alarm.alarm["StateValue"] = state_value
+ alarm.alarm["StateReason"] = state_reason
+ if state_reason_data:
+ alarm.alarm["StateReasonData"] = json.dumps(state_reason_data)
+ alarm.alarm["StateUpdatedTimestamp"] = current_time
+
+ def disable_alarm_actions(
+ self, context: RequestContext, alarm_names: AlarmNames, **kwargs
+ ) -> None:
+ self._set_alarm_actions(context, alarm_names, enabled=False)
+
+ def enable_alarm_actions(
+ self, context: RequestContext, alarm_names: AlarmNames, **kwargs
+ ) -> None:
+ self._set_alarm_actions(context, alarm_names, enabled=True)
+
+ def _set_alarm_actions(self, context, alarm_names, enabled):
+ store = self.get_store(context.account_id, context.region)
+ for name in alarm_names:
+ alarm_arn = arns.cloudwatch_alarm_arn(
+ name, account_id=context.account_id, region_name=context.region
+ )
+ alarm = store.alarms.get(alarm_arn)
+ if alarm:
+ alarm.alarm["ActionsEnabled"] = enabled
+
+ def describe_alarm_history(
+ self,
+ context: RequestContext,
+ alarm_name: AlarmName = None,
+ alarm_types: AlarmTypes = None,
+ history_item_type: HistoryItemType = None,
+ start_date: Timestamp = None,
+ end_date: Timestamp = None,
+ max_records: MaxRecords = None,
+ next_token: NextToken = None,
+ scan_by: ScanBy = None,
+ **kwargs,
+ ) -> DescribeAlarmHistoryOutput:
+ store = self.get_store(context.account_id, context.region)
+ history = store.histories
+ if alarm_name:
+ history = [h for h in history if h["AlarmName"] == alarm_name]
+
+ def _get_timestamp(input: dict):
+ if timestamp_string := input.get("Timestamp"):
+ return datetime.datetime.fromisoformat(timestamp_string)
+ return None
+
+ if start_date:
+ history = [h for h in history if (date := _get_timestamp(h)) and date >= start_date]
+ if end_date:
+ history = [h for h in history if (date := _get_timestamp(h)) and date <= end_date]
+ return DescribeAlarmHistoryOutput(AlarmHistoryItems=history)
+
+ def _evaluate_composite_alarms(self, context: RequestContext, triggering_alarm):
+ # TODO either pass store as a parameter or acquire RLock (with _STORE_LOCK:)
+ # everything works fine for now, but the critical section should be protected against future changes
+ store = self.get_store(context.account_id, context.region)
+ alarms = list(store.alarms.values())
+ composite_alarms = [a for a in alarms if isinstance(a, LocalStackCompositeAlarm)]
+ for composite_alarm in composite_alarms:
+ self._evaluate_composite_alarm(context, composite_alarm, triggering_alarm)
+
+ def _evaluate_composite_alarm(self, context, composite_alarm, triggering_alarm):
+ store = self.get_store(context.account_id, context.region)
+ alarm_rule = composite_alarm.alarm["AlarmRule"]
+ rule_expression_validation = self._validate_alarm_rule_expression(alarm_rule)
+ if rule_expression_validation:
+ LOG.warning(
+ "Alarm rule contains unsupported expressions and will not be evaluated: %s",
+ rule_expression_validation,
+ )
+ return
+ new_state_value = StateValue.OK
+ # assuming that a rule consists only of ALARM evaluations of metric alarms, with OR logic applied
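+ # i.e. the composite goes to ALARM if any referenced metric alarm is in ALARM state, otherwise it stays OK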
+ for metric_alarm_arn in self._get_alarm_arns(alarm_rule):
+ metric_alarm = store.alarms.get(metric_alarm_arn)
+ if not metric_alarm:
+ LOG.warning(
+ "Alarm rule won't be evaluated as there is no alarm with ARN %s",
+ metric_alarm_arn,
+ )
+ return
+ if metric_alarm.alarm["StateValue"] == StateValue.ALARM:
+ triggering_alarm = metric_alarm
+ new_state_value = StateValue.ALARM
+ break
+ old_state_value = composite_alarm.alarm["StateValue"]
+ if old_state_value == new_state_value:
+ return
+ triggering_alarm_arn = triggering_alarm.alarm.get("AlarmArn")
+ triggering_alarm_state = triggering_alarm.alarm.get("StateValue")
+ triggering_alarm_state_change_timestamp = triggering_alarm.alarm.get(
+ "StateTransitionedTimestamp"
+ )
+ state_reason_formatted_timestamp = triggering_alarm_state_change_timestamp.strftime(
+ "%A %d %B, %Y %H:%M:%S %Z"
+ )
+ state_reason = (
+ f"{triggering_alarm_arn} "
+ f"transitioned to {triggering_alarm_state} "
+ f"at {state_reason_formatted_timestamp}"
+ )
+ state_reason_data = {
+ "triggeringAlarms": [
+ {
+ "arn": triggering_alarm_arn,
+ "state": {
+ "value": triggering_alarm_state,
+ "timestamp": timestamp_millis(triggering_alarm_state_change_timestamp),
+ },
+ }
+ ]
+ }
+ self._update_state(
+ context, composite_alarm, new_state_value, state_reason, state_reason_data
+ )
+ if composite_alarm.alarm["ActionsEnabled"]:
+ self._run_composite_alarm_actions(
+ context, composite_alarm, old_state_value, triggering_alarm
+ )
+
+ def _validate_alarm_rule_expression(self, alarm_rule):
+ validation_result = []
+ alarms_conditions = [alarm.strip() for alarm in alarm_rule.split("OR")]
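+ # e.g. 'ALARM("arn:a") OR ALARM("arn:b")' (illustrative) -> ['ALARM("arn:a")', 'ALARM("arn:b")']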
+ for alarm_condition in alarms_conditions:
+ if not alarm_condition.startswith("ALARM"):
+ validation_result.append(
+ f"Unsupported expression in alarm rule condition {alarm_condition}: Only ALARM expression is supported by Localstack as of now"
+ )
+ return validation_result
+
+ def _get_alarm_arns(self, composite_alarm_rule):
+ # regex capturing everything between (" and ") in the rule expression
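+ # e.g. 'ALARM("arn:aws:cloudwatch:us-east-1:000000000000:alarm:cpu-high")' (illustrative) -> ['arn:aws:cloudwatch:us-east-1:000000000000:alarm:cpu-high']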
+ return re.findall(r'\("([^"]*)"\)', composite_alarm_rule)
+
+ def _run_composite_alarm_actions(
+ self, context, composite_alarm, old_state_value, triggering_alarm
+ ):
+ new_state_value = composite_alarm.alarm["StateValue"]
+ if new_state_value == StateValue.OK:
+ actions = composite_alarm.alarm["OKActions"]
+ elif new_state_value == StateValue.ALARM:
+ actions = composite_alarm.alarm["AlarmActions"]
+ else:
+ actions = composite_alarm.alarm["InsufficientDataActions"]
+ for action in actions:
+ data = arns.parse_arn(action)
+ if data["service"] == "sns":
+ service = connect_to(
+ region_name=data["region"], aws_access_key_id=data["account"]
+ ).sns
+ subject = f"""{new_state_value}: "{composite_alarm.alarm["AlarmName"]}" in {context.region}"""
+ message = create_message_response_update_composite_alarm_state_sns(
+ composite_alarm, triggering_alarm, old_state_value
+ )
+ service.publish(TopicArn=action, Subject=subject, Message=message)
+ else:
+ # TODO: support other actions
+ LOG.warning(
+ "Action for service %s not implemented, action '%s' will not be triggered.",
+ data["service"],
+ action,
+ )
+
+
+def create_metric_data_query_from_alarm(alarm: LocalStackMetricAlarm):
+ # TODO may need to be adapted for other use cases
+ # verified return value with a snapshot test
+ return [
+ {
+ "id": str(uuid.uuid4()),
+ "metricStat": {
+ "metric": {
+ "namespace": alarm.alarm["Namespace"],
+ "name": alarm.alarm["MetricName"],
+ "dimensions": alarm.alarm.get("Dimensions") or {},
+ },
+ "period": int(alarm.alarm["Period"]),
+ "stat": alarm.alarm["Statistic"],
+ },
+ "returnData": True,
+ }
+ ]
+
+
+def create_message_response_update_state_lambda(
+ alarm: LocalStackMetricAlarm, old_state, old_state_reason, old_state_timestamp
+):
+ _alarm = alarm.alarm
+ response = {
+ "accountId": extract_account_id_from_arn(_alarm["AlarmArn"]),
+ "alarmArn": _alarm["AlarmArn"],
+ "alarmData": {
+ "alarmName": _alarm["AlarmName"],
+ "state": {
+ "value": _alarm["StateValue"],
+ "reason": _alarm["StateReason"],
+ "timestamp": _alarm["StateUpdatedTimestamp"],
+ },
+ "previousState": {
+ "value": old_state,
+ "reason": old_state_reason,
+ "timestamp": old_state_timestamp,
+ },
+ "configuration": {
+ "description": _alarm.get("AlarmDescription", ""),
+ "metrics": _alarm.get(
+ "Metrics", create_metric_data_query_from_alarm(alarm)
+ ), # TODO: add test with metric_data_queries
+ },
+ },
+ "time": _alarm["StateUpdatedTimestamp"],
+ "region": alarm.region,
+ "source": "aws.cloudwatch",
+ }
+ return json.dumps(response, cls=JSONEncoder)
+
+
+def create_message_response_update_state_sns(alarm: LocalStackMetricAlarm, old_state: StateValue):
+ _alarm = alarm.alarm
+ response = {
+ "AWSAccountId": alarm.account_id,
+ "OldStateValue": old_state,
+ "AlarmName": _alarm["AlarmName"],
+ "AlarmDescription": _alarm.get("AlarmDescription"),
+ "AlarmConfigurationUpdatedTimestamp": _alarm["AlarmConfigurationUpdatedTimestamp"],
+ "NewStateValue": _alarm["StateValue"],
+ "NewStateReason": _alarm["StateReason"],
+ "StateChangeTime": _alarm["StateUpdatedTimestamp"],
+ # the long name for 'region' should be used here - as we don't have it, we use the short name,
+ # slightly transformed so that snapshot tests pass
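+ # e.g. "us-east-1" -> "Us east 1" (real AWS returns the long name, e.g. "US East (N. Virginia)")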
+ "Region": alarm.region.replace("-", " ").capitalize(),
+ "AlarmArn": _alarm["AlarmArn"],
+ "OKActions": _alarm.get("OKActions", []),
+ "AlarmActions": _alarm.get("AlarmActions", []),
+ "InsufficientDataActions": _alarm.get("InsufficientDataActions", []),
+ }
+
+ # collect trigger details
+ details = {
+ "MetricName": _alarm.get("MetricName", ""),
+ "Namespace": _alarm.get("Namespace", ""),
+ "Unit": _alarm.get("Unit", None), # testing with AWS revealed this currently returns None
+ "Period": int(_alarm.get("Period", 0)),
+ "EvaluationPeriods": int(_alarm.get("EvaluationPeriods", 0)),
+ "ComparisonOperator": _alarm.get("ComparisonOperator", ""),
+ "Threshold": float(_alarm.get("Threshold", 0.0)),
+ "TreatMissingData": _alarm.get("TreatMissingData", ""),
+ "EvaluateLowSampleCountPercentile": _alarm.get("EvaluateLowSampleCountPercentile", ""),
+ }
+
+ # convert Dimensions to the lowercase-key format used in the SNS message
+ dimensions = []
+ alarm_dimensions = _alarm.get("Dimensions", [])
+ if alarm_dimensions:
+ for d in _alarm["Dimensions"]:
+ dimensions.append({"value": d["Value"], "name": d["Name"]})
+ details["Dimensions"] = dimensions or ""
+
+ alarm_statistic = _alarm.get("Statistic")
+ alarm_extended_statistic = _alarm.get("ExtendedStatistic")
+
+ if alarm_statistic:
+ details["StatisticType"] = "Statistic"
+ details["Statistic"] = camel_to_snake_case(alarm_statistic).upper() # AWS returns uppercase
+ elif alarm_extended_statistic:
+ details["StatisticType"] = "ExtendedStatistic"
+ details["ExtendedStatistic"] = alarm_extended_statistic
+
+ response["Trigger"] = details
+
+ return json.dumps(response, cls=JSONEncoder)
+
+
+def create_message_response_update_composite_alarm_state_sns(
+ composite_alarm: LocalStackCompositeAlarm,
+ triggering_alarm: LocalStackMetricAlarm,
+ old_state: StateValue,
+):
+ _alarm = composite_alarm.alarm
+ response = {
+ "AWSAccountId": composite_alarm.account_id,
+ "AlarmName": _alarm["AlarmName"],
+ "AlarmDescription": _alarm.get("AlarmDescription"),
+ "AlarmRule": _alarm.get("AlarmRule"),
+ "OldStateValue": old_state,
+ "NewStateValue": _alarm["StateValue"],
+ "NewStateReason": _alarm["StateReason"],
+ "StateChangeTime": _alarm["StateUpdatedTimestamp"],
+ # the long name for 'region' should be used here - as we don't have it, we use the short name,
+ # slightly transformed so that snapshot tests pass
+ "Region": composite_alarm.region.replace("-", " ").capitalize(),
+ "AlarmArn": _alarm["AlarmArn"],
+ "OKActions": _alarm.get("OKActions", []),
+ "AlarmActions": _alarm.get("AlarmActions", []),
+ "InsufficientDataActions": _alarm.get("InsufficientDataActions", []),
+ }
+
+ triggering_children = [
+ {
+ "Arn": triggering_alarm.alarm.get("AlarmArn"),
+ "State": {
+ "Value": triggering_alarm.alarm["StateValue"],
+ "Timestamp": triggering_alarm.alarm["StateUpdatedTimestamp"],
+ },
+ }
+ ]
+
+ response["TriggeringChildren"] = triggering_children
+
+ return json.dumps(response, cls=JSONEncoder)
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/__init__.py b/localstack-core/localstack/services/cloudwatch/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.py b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.py
new file mode 100644
index 0000000000000..56aa3292de1f4
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.py
@@ -0,0 +1,194 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class CloudWatchAlarmProperties(TypedDict):
+ ComparisonOperator: Optional[str]
+ EvaluationPeriods: Optional[int]
+ ActionsEnabled: Optional[bool]
+ AlarmActions: Optional[list[str]]
+ AlarmDescription: Optional[str]
+ AlarmName: Optional[str]
+ Arn: Optional[str]
+ DatapointsToAlarm: Optional[int]
+ Dimensions: Optional[list[Dimension]]
+ EvaluateLowSampleCountPercentile: Optional[str]
+ ExtendedStatistic: Optional[str]
+ Id: Optional[str]
+ InsufficientDataActions: Optional[list[str]]
+ MetricName: Optional[str]
+ Metrics: Optional[list[MetricDataQuery]]
+ Namespace: Optional[str]
+ OKActions: Optional[list[str]]
+ Period: Optional[int]
+ Statistic: Optional[str]
+ Threshold: Optional[float]
+ ThresholdMetricId: Optional[str]
+ TreatMissingData: Optional[str]
+ Unit: Optional[str]
+
+
+class Dimension(TypedDict):
+ Name: Optional[str]
+ Value: Optional[str]
+
+
+class Metric(TypedDict):
+ Dimensions: Optional[list[Dimension]]
+ MetricName: Optional[str]
+ Namespace: Optional[str]
+
+
+class MetricStat(TypedDict):
+ Metric: Optional[Metric]
+ Period: Optional[int]
+ Stat: Optional[str]
+ Unit: Optional[str]
+
+
+class MetricDataQuery(TypedDict):
+ Id: Optional[str]
+ AccountId: Optional[str]
+ Expression: Optional[str]
+ Label: Optional[str]
+ MetricStat: Optional[MetricStat]
+ Period: Optional[int]
+ ReturnData: Optional[bool]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudWatchAlarmProvider(ResourceProvider[CloudWatchAlarmProperties]):
+ TYPE = "AWS::CloudWatch::Alarm" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudWatchAlarmProperties],
+ ) -> ProgressEvent[CloudWatchAlarmProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - ComparisonOperator
+ - EvaluationPeriods
+
+ Create-only properties:
+ - /properties/AlarmName
+
+ Read-only properties:
+ - /properties/Id
+ - /properties/Arn
+
+
+
+ """
+ model = request.desired_state
+ cloudwatch = request.aws_client_factory.cloudwatch
+
+ if not model.get("AlarmName"):
+ model["AlarmName"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+
+ create_params = util.select_attributes(
+ model,
+ [
+ "AlarmName",
+ "ComparisonOperator",
+ "EvaluationPeriods",
+ "Period",
+ "MetricName",
+ "Namespace",
+ "Statistic",
+ "Threshold",
+ "ActionsEnabled",
+ "AlarmActions",
+ "AlarmDescription",
+ "DatapointsToAlarm",
+ "Dimensions",
+ "EvaluateLowSampleCountPercentile",
+ "ExtendedStatistic",
+ "InsufficientDataActions",
+ "Metrics",
+ "OKActions",
+ "ThresholdMetricId",
+ "TreatMissingData",
+ "Unit",
+ ],
+ )
+
+ cloudwatch.put_metric_alarm(**create_params)
+ alarms = cloudwatch.describe_alarms(AlarmNames=[model["AlarmName"]])["MetricAlarms"]
+ if not alarms:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ message="Alarm not found",
+ )
+
+ alarm = alarms[0]
+ model["Arn"] = alarm["AlarmArn"]
+ model["Id"] = alarm["AlarmName"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[CloudWatchAlarmProperties],
+ ) -> ProgressEvent[CloudWatchAlarmProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudWatchAlarmProperties],
+ ) -> ProgressEvent[CloudWatchAlarmProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ cloud_watch = request.aws_client_factory.cloudwatch
+ cloud_watch.delete_alarms(AlarmNames=[model["AlarmName"]])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[CloudWatchAlarmProperties],
+ ) -> ProgressEvent[CloudWatchAlarmProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.schema.json b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.schema.json
new file mode 100644
index 0000000000000..c30c227e6aff9
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm.schema.json
@@ -0,0 +1,200 @@
+{
+ "typeName": "AWS::CloudWatch::Alarm",
+ "description": "Resource Type definition for AWS::CloudWatch::Alarm",
+ "additionalProperties": false,
+ "properties": {
+ "ThresholdMetricId": {
+ "type": "string"
+ },
+ "EvaluateLowSampleCountPercentile": {
+ "type": "string"
+ },
+ "ExtendedStatistic": {
+ "type": "string"
+ },
+ "ComparisonOperator": {
+ "type": "string"
+ },
+ "TreatMissingData": {
+ "type": "string"
+ },
+ "Dimensions": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Dimension"
+ }
+ },
+ "Period": {
+ "type": "integer"
+ },
+ "EvaluationPeriods": {
+ "type": "integer"
+ },
+ "Unit": {
+ "type": "string"
+ },
+ "Namespace": {
+ "type": "string"
+ },
+ "OKActions": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "AlarmActions": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "MetricName": {
+ "type": "string"
+ },
+ "ActionsEnabled": {
+ "type": "boolean"
+ },
+ "Metrics": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/MetricDataQuery"
+ }
+ },
+ "AlarmDescription": {
+ "type": "string"
+ },
+ "AlarmName": {
+ "type": "string"
+ },
+ "Statistic": {
+ "type": "string"
+ },
+ "InsufficientDataActions": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "DatapointsToAlarm": {
+ "type": "integer"
+ },
+ "Threshold": {
+ "type": "number"
+ }
+ },
+ "definitions": {
+ "MetricStat": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Period": {
+ "type": "integer"
+ },
+ "Metric": {
+ "$ref": "#/definitions/Metric"
+ },
+ "Stat": {
+ "type": "string"
+ },
+ "Unit": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Stat",
+ "Period",
+ "Metric"
+ ]
+ },
+ "Metric": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "MetricName": {
+ "type": "string"
+ },
+ "Dimensions": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Dimension"
+ }
+ },
+ "Namespace": {
+ "type": "string"
+ }
+ }
+ },
+ "Dimension": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Name": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Name"
+ ]
+ },
+ "MetricDataQuery": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AccountId": {
+ "type": "string"
+ },
+ "ReturnData": {
+ "type": "boolean"
+ },
+ "Expression": {
+ "type": "string"
+ },
+ "Label": {
+ "type": "string"
+ },
+ "MetricStat": {
+ "$ref": "#/definitions/MetricStat"
+ },
+ "Period": {
+ "type": "integer"
+ },
+ "Id": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Id"
+ ]
+ }
+ },
+ "required": [
+ "ComparisonOperator",
+ "EvaluationPeriods"
+ ],
+ "createOnlyProperties": [
+ "/properties/AlarmName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id",
+ "/properties/Arn"
+ ]
+}
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm_plugin.py b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm_plugin.py
new file mode 100644
index 0000000000000..6dfffe39b52a4
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_alarm_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudWatchAlarmProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudWatch::Alarm"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudwatch.resource_providers.aws_cloudwatch_alarm import (
+ CloudWatchAlarmProvider,
+ )
+
+ self.factory = CloudWatchAlarmProvider
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.py b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.py
new file mode 100644
index 0000000000000..b6ca22b2e9f3f
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.py
@@ -0,0 +1,168 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.strings import str_to_bool
+
+
+class CloudWatchCompositeAlarmProperties(TypedDict):
+ AlarmRule: Optional[str]
+ ActionsEnabled: Optional[bool]
+ ActionsSuppressor: Optional[str]
+ ActionsSuppressorExtensionPeriod: Optional[int]
+ ActionsSuppressorWaitPeriod: Optional[int]
+ AlarmActions: Optional[list[str]]
+ AlarmDescription: Optional[str]
+ AlarmName: Optional[str]
+ Arn: Optional[str]
+ InsufficientDataActions: Optional[list[str]]
+ OKActions: Optional[list[str]]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class CloudWatchCompositeAlarmProvider(ResourceProvider[CloudWatchCompositeAlarmProperties]):
+ TYPE = "AWS::CloudWatch::CompositeAlarm" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[CloudWatchCompositeAlarmProperties],
+ ) -> ProgressEvent[CloudWatchCompositeAlarmProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/AlarmName
+
+ Required properties:
+ - AlarmRule
+
+ Create-only properties:
+ - /properties/AlarmName
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - cloudwatch:DescribeAlarms
+ - cloudwatch:PutCompositeAlarm
+
+ """
+ model = request.desired_state
+ cloud_watch = request.aws_client_factory.cloudwatch
+
+ params = util.select_attributes(
+ model,
+ [
+ "AlarmName",
+ "AlarmRule",
+ "ActionsEnabled",
+ "ActionsSuppressor",
+ "ActionsSuppressorWaitPeriod",
+ "ActionsSuppressorExtensionPeriod",
+ "AlarmActions",
+ "AlarmDescription",
+ "InsufficientDataActions",
+ "OKActions",
+ ],
+ )
+ if not params.get("AlarmName"):
+ model["AlarmName"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+ params["AlarmName"] = model["AlarmName"]
+
+ if "ActionsEnabled" in params:
+ params["ActionsEnabled"] = str_to_bool(params["ActionsEnabled"])
+
+ create_params = util.select_attributes(
+ model,
+ [
+ "AlarmName",
+ "AlarmRule",
+ "ActionsEnabled",
+ "ActionsSuppressor",
+ "ActionsSuppressorExtensionPeriod",
+ "ActionsSuppressorWaitPeriod",
+ "AlarmActions",
+ "AlarmDescription",
+ "InsufficientDataActions",
+ "OKActions",
+ ],
+ )
+
+ cloud_watch.put_composite_alarm(**create_params)
+ alarms = cloud_watch.describe_alarms(
+ AlarmNames=[model["AlarmName"]], AlarmTypes=["CompositeAlarm"]
+ )["CompositeAlarms"]
+
+ if not alarms:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ message="Composite Alarm not found",
+ )
+ model["Arn"] = alarms[0]["AlarmArn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[CloudWatchCompositeAlarmProperties],
+ ) -> ProgressEvent[CloudWatchCompositeAlarmProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - cloudwatch:DescribeAlarms
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[CloudWatchCompositeAlarmProperties],
+ ) -> ProgressEvent[CloudWatchCompositeAlarmProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - cloudwatch:DescribeAlarms
+ - cloudwatch:DeleteAlarms
+ """
+ model = request.desired_state
+ cloud_watch = request.aws_client_factory.cloudwatch
+ cloud_watch.delete_alarms(AlarmNames=[model["AlarmName"]])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[CloudWatchCompositeAlarmProperties],
+ ) -> ProgressEvent[CloudWatchCompositeAlarmProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - cloudwatch:DescribeAlarms
+ - cloudwatch:PutCompositeAlarm
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.schema.json b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.schema.json
new file mode 100644
index 0000000000000..36464ecf204be
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm.schema.json
@@ -0,0 +1,130 @@
+{
+ "typeName": "AWS::CloudWatch::CompositeAlarm",
+ "description": "The AWS::CloudWatch::CompositeAlarm type specifies an alarm which aggregates the states of other Alarms (Metric or Composite Alarms) as defined by the AlarmRule expression",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-cloudwatch.git",
+ "properties": {
+ "Arn": {
+ "type": "string",
+ "description": "Amazon Resource Name (ARN) of the alarm",
+ "minLength": 1,
+ "maxLength": 1600
+ },
+ "AlarmName": {
+ "description": "The name of the Composite Alarm",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255
+ },
+ "AlarmRule": {
+ "type": "string",
+ "description": "Expression which aggregates the state of other Alarms (Metric or Composite Alarms)",
+ "minLength": 1,
+ "maxLength": 10240
+ },
+ "AlarmDescription": {
+ "type": "string",
+ "description": "The description of the alarm",
+ "minLength": 0,
+ "maxLength": 1024
+ },
+ "ActionsEnabled": {
+ "description": "Indicates whether actions should be executed during any changes to the alarm state. The default is TRUE.",
+ "type": "boolean"
+ },
+ "OKActions": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "description": "Amazon Resource Name (ARN) of the action",
+ "minLength": 1,
+ "maxLength": 1024
+ },
+ "description": "The actions to execute when this alarm transitions to the OK state from any other state. Each action is specified as an Amazon Resource Name (ARN).",
+ "maxItems": 5
+ },
+ "AlarmActions": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "description": "Amazon Resource Name (ARN) of the action",
+ "minLength": 1,
+ "maxLength": 1024
+ },
+ "description": "The list of actions to execute when this alarm transitions into an ALARM state from any other state. Specify each action as an Amazon Resource Name (ARN).",
+ "maxItems": 5
+ },
+ "InsufficientDataActions": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "description": "Amazon Resource Name (ARN) of the action",
+ "minLength": 1,
+ "maxLength": 1024
+ },
+ "description": "The actions to execute when this alarm transitions to the INSUFFICIENT_DATA state from any other state. Each action is specified as an Amazon Resource Name (ARN).",
+ "maxItems": 5
+ },
+ "ActionsSuppressor": {
+ "description": "Actions will be suppressed if the suppressor alarm is in the ALARM state. ActionsSuppressor can be an AlarmName or an Amazon Resource Name (ARN) from an existing alarm. ",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 1600
+ },
+ "ActionsSuppressorWaitPeriod": {
+ "description": "Actions will be suppressed if ExtensionPeriod is active. The length of time that actions are suppressed is in seconds.",
+ "type": "integer",
+ "minimum": 0
+ },
+ "ActionsSuppressorExtensionPeriod": {
+ "description": "Actions will be suppressed if WaitPeriod is active. The length of time that actions are suppressed is in seconds.",
+ "type": "integer",
+ "minimum": 0
+ }
+ },
+ "required": [
+ "AlarmRule"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "createOnlyProperties": [
+ "/properties/AlarmName"
+ ],
+ "primaryIdentifier": [
+ "/properties/AlarmName"
+ ],
+ "additionalProperties": false,
+ "handlers": {
+ "create": {
+ "permissions": [
+ "cloudwatch:DescribeAlarms",
+ "cloudwatch:PutCompositeAlarm"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "cloudwatch:DescribeAlarms"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "cloudwatch:DescribeAlarms",
+ "cloudwatch:PutCompositeAlarm"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "cloudwatch:DescribeAlarms",
+ "cloudwatch:DeleteAlarms"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "cloudwatch:DescribeAlarms"
+ ]
+ }
+ },
+ "tagging": {
+ "taggable": false
+ }
+}
diff --git a/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm_plugin.py b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm_plugin.py
new file mode 100644
index 0000000000000..867cebdbfe31d
--- /dev/null
+++ b/localstack-core/localstack/services/cloudwatch/resource_providers/aws_cloudwatch_compositealarm_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class CloudWatchCompositeAlarmProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::CloudWatch::CompositeAlarm"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.cloudwatch.resource_providers.aws_cloudwatch_compositealarm import (
+ CloudWatchCompositeAlarmProvider,
+ )
+
+ self.factory = CloudWatchCompositeAlarmProvider
diff --git a/localstack-core/localstack/services/configservice/__init__.py b/localstack-core/localstack/services/configservice/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/configservice/provider.py b/localstack-core/localstack/services/configservice/provider.py
new file mode 100644
index 0000000000000..3087c6b23e270
--- /dev/null
+++ b/localstack-core/localstack/services/configservice/provider.py
@@ -0,0 +1,5 @@
+from localstack.aws.api.config import ConfigApi
+
+
+class ConfigProvider(ConfigApi):
+ pass
diff --git a/localstack-core/localstack/services/dynamodb/__init__.py b/localstack-core/localstack/services/dynamodb/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/dynamodb/models.py b/localstack-core/localstack/services/dynamodb/models.py
new file mode 100644
index 0000000000000..cc6d7ee2e4939
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/models.py
@@ -0,0 +1,122 @@
+import dataclasses
+from typing import TypedDict
+
+from localstack.aws.api.dynamodb import (
+ AttributeMap,
+ Key,
+ RegionName,
+ ReplicaDescription,
+ StreamViewType,
+ TableName,
+ TimeToLiveSpecification,
+)
+from localstack.services.stores import (
+ AccountRegionBundle,
+ BaseStore,
+ CrossRegionAttribute,
+ LocalAttribute,
+)
+
+
+@dataclasses.dataclass
+class TableStreamType:
+ """
+ When an item in the table is modified, StreamViewType determines what information is written to the stream for this table.
+ - KEYS_ONLY - Only the key attributes of the modified item are written to the stream.
+ - NEW_IMAGE - The entire item, as it appears after it was modified, is written to the stream.
+ - OLD_IMAGE - The entire item, as it appeared before it was modified, is written to the stream.
+ - NEW_AND_OLD_IMAGES - Both the new and the old item images of the item are written to the stream.
+ Special case:
+ is_kinesis: equivalent to NEW_AND_OLD_IMAGES, can be set at the same time as StreamViewType
+ """
+
+ stream_view_type: StreamViewType | None
+ is_kinesis: bool
+
+ @property
+ def needs_old_image(self):
+ return self.is_kinesis or self.stream_view_type in (
+ StreamViewType.OLD_IMAGE,
+ StreamViewType.NEW_AND_OLD_IMAGES,
+ )
+
+ @property
+ def needs_new_image(self):
+ return self.is_kinesis or self.stream_view_type in (
+ StreamViewType.NEW_IMAGE,
+ StreamViewType.NEW_AND_OLD_IMAGES,
+ )
+
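+# Illustrative sketch (not part of the original change) of how the two properties above resolve,
+# using the StreamViewType values imported at the top of this module:
+#   TableStreamType(StreamViewType.NEW_AND_OLD_IMAGES, is_kinesis=False)  -> needs both images
+#   TableStreamType(StreamViewType.KEYS_ONLY, is_kinesis=False)           -> needs neither image
+#   TableStreamType(None, is_kinesis=True)                                -> needs both (Kinesis destinations always do)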
+
+class DynamoDbStreamRecord(TypedDict, total=False):
+ ApproximateCreationDateTime: int
+ SizeBytes: int
+ Keys: Key
+ StreamViewType: StreamViewType | None
+ OldImage: AttributeMap | None
+ NewImage: AttributeMap | None
+ SequenceNumber: int | None
+
+
+class StreamRecord(TypedDict, total=False):
+ """
+ Related to DynamoDB Streams and Kinesis destinations.
+ This class contains the data necessary for both Kinesis records and DynamoDB Streams records.
+ """
+
+ eventName: str
+ eventID: str
+ eventVersion: str
+ dynamodb: DynamoDbStreamRecord
+ awsRegion: str
+ eventSource: str
+
+
+StreamRecords = list[StreamRecord]
+
+
+class TableRecords(TypedDict):
+ """
+ Container class used to forward events from DynamoDB to DDB Streams and Kinesis destinations.
+ It contains the records to be forwarded and data about the streams to be forwarded to.
+ """
+
+ table_stream_type: TableStreamType
+ records: StreamRecords
+
+
+# RecordsMap maps a TableName to its TableRecords, allowing events to be forwarded to the right destinations.
+# Some DynamoDB calls can modify several tables at once, so events need to be grouped per table, as each
+# table can have different destinations.
+RecordsMap = dict[TableName, TableRecords]
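+# Illustrative shape of a RecordsMap (hypothetical table name and values, not part of the original change):
+#   {"my-table": TableRecords(table_stream_type=TableStreamType(StreamViewType.NEW_IMAGE, is_kinesis=False),
+#                             records=[...])}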
+
+
+class DynamoDBStore(BaseStore):
+ # maps global table names to configurations (for the legacy v.2017 tables)
+ GLOBAL_TABLES: dict[str, dict] = CrossRegionAttribute(default=dict)
+
+ # Maps table names to the region they exist in on DDBLocal (for v.2019 global tables)
+ TABLE_REGION: dict[TableName, RegionName] = CrossRegionAttribute(default=dict)
+
+ # Maps table names to their replica descriptions per region (for v.2019 global tables)
+ REPLICAS: dict[TableName, dict[RegionName, ReplicaDescription]] = CrossRegionAttribute(
+ default=dict
+ )
+
+ # cached table tags - maps table ARN to a dict of tags
+ TABLE_TAGS: dict[str, dict] = CrossRegionAttribute(default=dict)
+
+ # maps table names to cached table definitions
+ table_definitions: dict[str, dict] = LocalAttribute(default=dict)
+
+ # maps table names to additional table properties that are not stored upstream (e.g., ReplicaUpdates)
+ table_properties: dict[str, dict] = LocalAttribute(default=dict)
+
+ # maps table names to TTL specifications
+ ttl_specifications: dict[str, TimeToLiveSpecification] = LocalAttribute(default=dict)
+
+ # maps backup identifiers to backup details
+ backups: dict[str, dict] = LocalAttribute(default=dict)
+
+
+dynamodb_stores = AccountRegionBundle("dynamodb", DynamoDBStore)
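+
+# Illustrative access pattern (hypothetical account/region/table, not part of the original change):
+# the bundle is indexed by account ID and then by region, e.g.
+#   store: DynamoDBStore = dynamodb_stores["000000000000"]["us-east-1"]
+#   store.table_definitions["my-table"] = {...}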
diff --git a/localstack-core/localstack/services/dynamodb/packages.py b/localstack-core/localstack/services/dynamodb/packages.py
new file mode 100644
index 0000000000000..db2ca14c49bf6
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/packages.py
@@ -0,0 +1,105 @@
+import os
+from typing import List
+
+from localstack import config
+from localstack.constants import ARTIFACTS_REPO, MAVEN_REPO_URL
+from localstack.packages import InstallTarget, Package, PackageInstaller
+from localstack.packages.java import java_package
+from localstack.utils.archives import (
+ download_and_extract_with_retry,
+ update_jar_manifest,
+ upgrade_jar_file,
+)
+from localstack.utils.files import rm_rf, save_file
+from localstack.utils.functions import run_safe
+from localstack.utils.http import download
+from localstack.utils.run import run
+
+DDB_AGENT_JAR_URL = f"{ARTIFACTS_REPO}/raw/388cd73f45bfd3bcf7ad40aa35499093061c7962/dynamodb-local-patch/target/ddb-local-loader-0.1.jar"
+JAVASSIST_JAR_URL = f"{MAVEN_REPO_URL}/org/javassist/javassist/3.30.2-GA/javassist-3.30.2-GA.jar"
+
+DDBLOCAL_URL = "https://d1ni2b6xgvw0s0.cloudfront.net/v2.x/dynamodb_local_latest.zip"
+
+
+class DynamoDBLocalPackage(Package):
+ def __init__(self):
+ super().__init__(name="DynamoDBLocal", default_version="2")
+
+ def _get_installer(self, _) -> PackageInstaller:
+ return DynamoDBLocalPackageInstaller()
+
+ def get_versions(self) -> List[str]:
+ return ["2"]
+
+
+class DynamoDBLocalPackageInstaller(PackageInstaller):
+ def __init__(self):
+ super().__init__("dynamodb-local", "2")
+
+ # DDBLocal v2 requires JRE 17+
+ # See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.DownloadingAndRunning.html
+ self.java_version = "21"
+
+ def _prepare_installation(self, target: InstallTarget) -> None:
+ java_package.get_installer(self.java_version).install(target)
+
+ def get_java_env_vars(self) -> dict[str, str]:
+ java_home = java_package.get_installer(self.java_version).get_java_home()
+ path = f"{java_home}/bin:{os.environ['PATH']}"
+
+ return {
+ "JAVA_HOME": java_home,
+ "PATH": path,
+ }
+
+ def _install(self, target: InstallTarget):
+ # download and extract archive
+ tmp_archive = os.path.join(config.dirs.cache, f"DynamoDBLocal-{self.version}.zip")
+ install_dir = self._get_install_dir(target)
+
+ download_and_extract_with_retry(DDBLOCAL_URL, tmp_archive, install_dir)
+ rm_rf(tmp_archive)
+
+ # Use custom log formatting
+ log4j2_config = """<Configuration status="WARN">
+ <Appenders>
+ <Console name="Console" target="SYSTEM_OUT">
+ <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n"/>
+ </Console>
+ </Appenders>
+ <Loggers>
+ <Root level="WARN">
+ <AppenderRef ref="Console"/>
+ </Root>
+ </Loggers>
+ </Configuration>"""
+ log4j2_file = os.path.join(install_dir, "log4j2.xml")
+ run_safe(lambda: save_file(log4j2_file, log4j2_config))
+ run_safe(lambda: run(["zip", "-u", "DynamoDBLocal.jar", "log4j2.xml"], cwd=install_dir))
+
+ # Add patch that enables 20+ GSIs
+ ddb_agent_jar_path = self.get_ddb_agent_jar_path()
+ if not os.path.exists(ddb_agent_jar_path):
+ download(DDB_AGENT_JAR_URL, ddb_agent_jar_path)
+
+ javassist_jar_path = os.path.join(install_dir, "javassist.jar")
+ if not os.path.exists(javassist_jar_path):
+ download(JAVASSIST_JAR_URL, javassist_jar_path)
+
+ # Add javassist to the manifest classpath
+ update_jar_manifest(
+ "DynamoDBLocal.jar", install_dir, "Class-Path: .", "Class-Path: javassist.jar ."
+ )
+
+ ddb_local_lib_dir = os.path.join(install_dir, "DynamoDBLocal_lib")
+ upgrade_jar_file(ddb_local_lib_dir, "slf4j-ext-*.jar", "org/slf4j/slf4j-ext:2.0.13")
+
+ def _get_install_marker_path(self, install_dir: str) -> str:
+ return os.path.join(install_dir, "DynamoDBLocal.jar")
+
+ def get_ddb_agent_jar_path(self):
+ return os.path.join(self.get_installed_dir(), "ddb-local-loader-0.1.jar")
+
+
+dynamodblocal_package = DynamoDBLocalPackage()
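+
+# Illustrative usage sketch (not part of the original change): installation is typically triggered lazily
+# before the DynamoDB server starts. Package.install() resolving the default version is an assumption
+# based on the generic package interface; get_installer()/get_installed_dir() are used above.
+#   dynamodblocal_package.install()
+#   install_dir = dynamodblocal_package.get_installer().get_installed_dir()
+#   jar_path = os.path.join(install_dir, "DynamoDBLocal.jar")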
diff --git a/localstack-core/localstack/services/dynamodb/plugins.py b/localstack-core/localstack/services/dynamodb/plugins.py
new file mode 100644
index 0000000000000..f5d60a15b914a
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/plugins.py
@@ -0,0 +1,8 @@
+from localstack.packages import Package, package
+
+
+@package(name="dynamodb-local")
+def dynamodb_local_package() -> Package:
+ from localstack.services.dynamodb.packages import dynamodblocal_package
+
+ return dynamodblocal_package
diff --git a/localstack-core/localstack/services/dynamodb/provider.py b/localstack-core/localstack/services/dynamodb/provider.py
new file mode 100644
index 0000000000000..407e6400414ca
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/provider.py
@@ -0,0 +1,2271 @@
+import copy
+import json
+import logging
+import os
+import random
+import re
+import threading
+import time
+import traceback
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from datetime import datetime
+from operator import itemgetter
+from typing import Dict, List, Optional
+
+import requests
+import werkzeug
+
+from localstack import config
+from localstack.aws import handlers
+from localstack.aws.api import (
+ CommonServiceException,
+ RequestContext,
+ ServiceRequest,
+ ServiceResponse,
+ handler,
+)
+from localstack.aws.api.dynamodb import (
+ AttributeMap,
+ BatchExecuteStatementOutput,
+ BatchGetItemOutput,
+ BatchGetRequestMap,
+ BatchGetResponseMap,
+ BatchWriteItemInput,
+ BatchWriteItemOutput,
+ BatchWriteItemRequestMap,
+ BillingMode,
+ ContinuousBackupsDescription,
+ ContinuousBackupsStatus,
+ CreateGlobalTableOutput,
+ CreateTableInput,
+ CreateTableOutput,
+ Delete,
+ DeleteItemInput,
+ DeleteItemOutput,
+ DeleteRequest,
+ DeleteTableOutput,
+ DescribeContinuousBackupsOutput,
+ DescribeGlobalTableOutput,
+ DescribeKinesisStreamingDestinationOutput,
+ DescribeTableOutput,
+ DescribeTimeToLiveOutput,
+ DestinationStatus,
+ DynamodbApi,
+ EnableKinesisStreamingConfiguration,
+ ExecuteStatementInput,
+ ExecuteStatementOutput,
+ ExecuteTransactionInput,
+ ExecuteTransactionOutput,
+ GetItemInput,
+ GetItemOutput,
+ GlobalTableAlreadyExistsException,
+ GlobalTableNotFoundException,
+ KinesisStreamingDestinationOutput,
+ ListGlobalTablesOutput,
+ ListTablesInputLimit,
+ ListTablesOutput,
+ ListTagsOfResourceOutput,
+ NextTokenString,
+ PartiQLBatchRequest,
+ PointInTimeRecoveryDescription,
+ PointInTimeRecoverySpecification,
+ PointInTimeRecoveryStatus,
+ PositiveIntegerObject,
+ ProvisionedThroughputExceededException,
+ Put,
+ PutItemInput,
+ PutItemOutput,
+ PutRequest,
+ QueryInput,
+ QueryOutput,
+ RegionName,
+ ReplicaDescription,
+ ReplicaList,
+ ReplicaStatus,
+ ReplicaUpdateList,
+ ResourceArnString,
+ ResourceInUseException,
+ ResourceNotFoundException,
+ ReturnConsumedCapacity,
+ ScanInput,
+ ScanOutput,
+ StreamArn,
+ TableDescription,
+ TableName,
+ TagKeyList,
+ TagList,
+ TimeToLiveSpecification,
+ TransactGetItemList,
+ TransactGetItemsOutput,
+ TransactWriteItem,
+ TransactWriteItemList,
+ TransactWriteItemsInput,
+ TransactWriteItemsOutput,
+ Update,
+ UpdateContinuousBackupsOutput,
+ UpdateGlobalTableOutput,
+ UpdateItemInput,
+ UpdateItemOutput,
+ UpdateTableInput,
+ UpdateTableOutput,
+ UpdateTimeToLiveOutput,
+ WriteRequest,
+)
+from localstack.aws.api.dynamodbstreams import StreamStatus
+from localstack.aws.connect import connect_to
+from localstack.constants import (
+ AUTH_CREDENTIAL_REGEX,
+ AWS_REGION_US_EAST_1,
+ INTERNAL_AWS_SECRET_ACCESS_KEY,
+)
+from localstack.http import Request, Response, route
+from localstack.services.dynamodb.models import (
+ DynamoDBStore,
+ RecordsMap,
+ StreamRecord,
+ StreamRecords,
+ TableRecords,
+ TableStreamType,
+ dynamodb_stores,
+)
+from localstack.services.dynamodb.server import DynamodbServer
+from localstack.services.dynamodb.utils import (
+ ItemFinder,
+ ItemSet,
+ SchemaExtractor,
+ de_dynamize_record,
+ extract_table_name_from_partiql_update,
+ get_ddb_access_key,
+ modify_ddblocal_arns,
+)
+from localstack.services.dynamodbstreams import dynamodbstreams_api
+from localstack.services.dynamodbstreams.models import dynamodbstreams_stores
+from localstack.services.edge import ROUTER
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.state import AssetDirectory, StateVisitor
+from localstack.utils.aws import arns
+from localstack.utils.aws.arns import (
+ extract_account_id_from_arn,
+ extract_region_from_arn,
+ get_partition,
+)
+from localstack.utils.aws.aws_stack import get_valid_regions_for_service
+from localstack.utils.aws.request_context import (
+ extract_account_id_from_headers,
+ extract_region_from_headers,
+)
+from localstack.utils.collections import select_attributes, select_from_typed_dict
+from localstack.utils.common import short_uid, to_bytes
+from localstack.utils.json import BytesEncoder, canonical_json
+from localstack.utils.scheduler import Scheduler
+from localstack.utils.strings import long_uid, md5, to_str
+from localstack.utils.threads import FuncThread, start_thread
+
+# set up logger
+LOG = logging.getLogger(__name__)
+
+# action header prefix
+ACTION_PREFIX = "DynamoDB_20120810."
+
+# list of actions subject to throughput limitations
+READ_THROTTLED_ACTIONS = [
+ "GetItem",
+ "Query",
+ "Scan",
+ "TransactGetItems",
+ "BatchGetItem",
+]
+WRITE_THROTTLED_ACTIONS = [
+ "PutItem",
+ "BatchWriteItem",
+ "UpdateItem",
+ "DeleteItem",
+ "TransactWriteItems",
+]
+THROTTLED_ACTIONS = READ_THROTTLED_ACTIONS + WRITE_THROTTLED_ACTIONS
+
+MANAGED_KMS_KEYS = {}
+
+
+def dynamodb_table_exists(table_name: str, client=None) -> bool:
+ client = client or connect_to().dynamodb
+ paginator = client.get_paginator("list_tables")
+ pages = paginator.paginate(PaginationConfig={"PageSize": 100})
+ table_name = to_str(table_name)
+ return any(table_name in page["TableNames"] for page in pages)
+
+
+class EventForwarder:
+ def __init__(self, num_thread: int = 10):
+ self.executor = ThreadPoolExecutor(num_thread, thread_name_prefix="ddb_stream_fwd")
+
+ def shutdown(self):
+ self.executor.shutdown(wait=False)
+
+ def forward_to_targets(
+ self, account_id: str, region_name: str, records_map: RecordsMap, background: bool = True
+ ) -> None:
+ if background:
+ self._submit_records(
+ account_id=account_id,
+ region_name=region_name,
+ records_map=records_map,
+ )
+ else:
+ self._forward(account_id, region_name, records_map)
+
+ def _submit_records(self, account_id: str, region_name: str, records_map: RecordsMap):
+ """Required for patching submit with local thread context for EventStudio"""
+ self.executor.submit(
+ self._forward,
+ account_id,
+ region_name,
+ records_map,
+ )
+
+ def _forward(self, account_id: str, region_name: str, records_map: RecordsMap) -> None:
+ try:
+ self.forward_to_kinesis_stream(account_id, region_name, records_map)
+ except Exception as e:
+ LOG.debug(
+ "Error while publishing to Kinesis streams: '%s'",
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ try:
+ self.forward_to_ddb_stream(account_id, region_name, records_map)
+ except Exception as e:
+ LOG.debug(
+ "Error while publishing to DynamoDB streams, '%s'",
+ e,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ @staticmethod
+ def forward_to_ddb_stream(account_id: str, region_name: str, records_map: RecordsMap) -> None:
+ dynamodbstreams_api.forward_events(account_id, region_name, records_map)
+
+ @staticmethod
+ def forward_to_kinesis_stream(
+ account_id: str, region_name: str, records_map: RecordsMap
+ ) -> None:
+ # You can only stream data from DynamoDB to Kinesis Data Streams in the same AWS account and AWS Region as your
+ # table.
+ # You can only stream data from a DynamoDB table to one Kinesis data stream.
+ store = get_store(account_id, region_name)
+
+ for table_name, table_records in records_map.items():
+ table_stream_type = table_records["table_stream_type"]
+ if not table_stream_type.is_kinesis:
+ continue
+
+ kinesis_records = []
+
+ table_arn = arns.dynamodb_table_arn(table_name, account_id, region_name)
+ records = table_records["records"]
+ table_def = store.table_definitions.get(table_name) or {}
+ stream_arn = table_def["KinesisDataStreamDestinations"][-1]["StreamArn"]
+ for record in records:
+ kinesis_record = dict(
+ tableName=table_name,
+ recordFormat="application/json",
+ userIdentity=None,
+ **record,
+ )
+ fields_to_remove = {"StreamViewType", "SequenceNumber"}
+ kinesis_record["dynamodb"] = {
+ k: v for k, v in record["dynamodb"].items() if k not in fields_to_remove
+ }
+ kinesis_record.pop("eventVersion", None)
+
+ hash_keys = list(
+ filter(lambda key: key["KeyType"] == "HASH", table_def["KeySchema"])
+ )
+ # TODO: properly reverse-engineer how AWS creates the partition key; it seems to be an MD5 hash
+ kinesis_partition_key = md5(f"{table_name}{hash_keys[0]['AttributeName']}")
+
+ kinesis_records.append(
+ {
+ "Data": json.dumps(kinesis_record, cls=BytesEncoder),
+ "PartitionKey": kinesis_partition_key,
+ }
+ )
+
+ kinesis = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).kinesis.request_metadata(service_principal="dynamodb", source_arn=table_arn)
+
+ kinesis.put_records(
+ StreamARN=stream_arn,
+ Records=kinesis_records,
+ )
+
+ @classmethod
+ def is_kinesis_stream_exists(cls, stream_arn):
+ account_id = extract_account_id_from_arn(stream_arn)
+ region_name = extract_region_from_arn(stream_arn)
+
+ kinesis = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).kinesis
+ stream_name_from_arn = stream_arn.split("/", 1)[1]
+ # check if the stream exists in kinesis for the user
+ filtered = list(
+ filter(
+ lambda stream_name: stream_name == stream_name_from_arn,
+ kinesis.list_streams()["StreamNames"],
+ )
+ )
+ return bool(filtered)
+
+
+class SSEUtils:
+ """Utils for server-side encryption (SSE)"""
+
+ @classmethod
+ def get_sse_kms_managed_key(cls, account_id: str, region_name: str):
+ from localstack.services.kms import provider
+
+ existing_key = MANAGED_KMS_KEYS.get(region_name)
+ if existing_key:
+ return existing_key
+ kms_client = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).kms
+ key_data = kms_client.create_key(
+ Description="Default key that protects my DynamoDB data when no other key is defined"
+ )
+ key_id = key_data["KeyMetadata"]["KeyId"]
+
+ provider.set_key_managed(key_id, account_id, region_name)
+ MANAGED_KMS_KEYS[region_name] = key_id
+ return key_id
+
+ @classmethod
+ def get_sse_description(cls, account_id: str, region_name: str, data):
+ if data.get("Enabled"):
+ kms_master_key_id = data.get("KMSMasterKeyId")
+ if not kms_master_key_id:
+ # this is not the actual key used by DynamoDB, but an existing KMS key makes for a better mock than a made-up one
+ kms_master_key_id = cls.get_sse_kms_managed_key(account_id, region_name)
+ kms_master_key_id = arns.kms_key_arn(kms_master_key_id, account_id, region_name)
+ return {
+ "Status": "ENABLED",
+ "SSEType": "KMS", # no other value is allowed here
+ "KMSMasterKeyArn": kms_master_key_id,
+ }
+ return {}
+
+
+class ValidationException(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__(code="ValidationException", status_code=400, message=message)
+
+
+def get_store(account_id: str, region_name: str) -> DynamoDBStore:
+ # special case: AWS NoSQL Workbench sends "localhost" as region - replace with proper region here
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+ return dynamodb_stores[account_id][region_name]
+
+
+@contextmanager
+def modify_context_region(context: RequestContext, region: str):
+ """
+ Context manager that modifies the region of a `RequestContext`. At the exit, the context is restored to its
+ original state.
+
+ :param context: the context to modify
+ :param region: the modified region
+ :return: a modified `RequestContext`
+ """
+ original_region = context.region
+ original_authorization = context.request.headers.get("Authorization")
+
+ key = get_ddb_access_key(context.account_id, region)
+
+ context.region = region
+ context.request.headers["Authorization"] = re.sub(
+ AUTH_CREDENTIAL_REGEX,
+ rf"Credential={key}/\2/{region}/\4/",
+ original_authorization or "",
+ flags=re.IGNORECASE,
+ )
+
+ try:
+ yield context
+ except Exception:
+ raise
+ finally:
+ # revert the original context
+ context.region = original_region
+ context.request.headers["Authorization"] = original_authorization
+
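+# Illustrative usage (hypothetical region, not part of the original change):
+#   with modify_context_region(context, "eu-west-1"):
+#       result = provider.forward_request(context)
+# afterwards, context.region and the Authorization header are restored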
+
+class DynamoDBDeveloperEndpoints:
+ """
+ Developer endpoints for DynamoDB
+ DELETE /_aws/dynamodb/expired - delete expired items from tables with TTL enabled; return the number of expired
+ items deleted
+ """
+
+ @route("/_aws/dynamodb/expired", methods=["DELETE"])
+ def delete_expired_messages(self, _: Request):
+ no_expired_items = delete_expired_items()
+ return {"ExpiredItems": no_expired_items}
+
+
+def delete_expired_items() -> int:
+ """
+ This utility function iterates over all stores, looks for tables with TTL enabled,
+ scans those tables, and deletes expired items.
+ """
+ no_expired_items = 0
+ for account_id, region_name, state in dynamodb_stores.iter_stores():
+ ttl_specs = state.ttl_specifications
+ client = connect_to(aws_access_key_id=account_id, region_name=region_name).dynamodb
+ for table_name, ttl_spec in ttl_specs.items():
+ if ttl_spec.get("Enabled", False):
+ attribute_name = ttl_spec.get("AttributeName")
+ current_time = int(datetime.now().timestamp())
+ try:
+ result = client.scan(
+ TableName=table_name,
+ FilterExpression="#ttl <= :threshold",
+ ExpressionAttributeValues={":threshold": {"N": str(current_time)}},
+ ExpressionAttributeNames={"#ttl": attribute_name},
+ )
+ items_to_delete = result.get("Items", [])
+ no_expired_items += len(items_to_delete)
+ table_description = client.describe_table(TableName=table_name)
+ partition_key, range_key = _get_hash_and_range_key(table_description)
+ keys_to_delete = [
+ {partition_key: item.get(partition_key)}
+ if range_key is None
+ else {
+ partition_key: item.get(partition_key),
+ range_key: item.get(range_key),
+ }
+ for item in items_to_delete
+ ]
+ delete_requests = [{"DeleteRequest": {"Key": key}} for key in keys_to_delete]
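+ # BatchWriteItem accepts at most 25 write requests per call, hence the chunks of 25 below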
+ for i in range(0, len(delete_requests), 25):
+ batch = delete_requests[i : i + 25]
+ client.batch_write_item(RequestItems={table_name: batch})
+ except Exception as e:
+ LOG.warning(
+ "An error occurred when deleting expired items from table %s: %s",
+ table_name,
+ e,
+ )
+ return no_expired_items
+
+
+def _get_hash_and_range_key(table_description: DescribeTableOutput) -> tuple[str | None, str | None]:
+ key_schema = table_description.get("Table", {}).get("KeySchema", [])
+ hash_key, range_key = None, None
+ for key in key_schema:
+ if key["KeyType"] == "HASH":
+ hash_key = key["AttributeName"]
+ if key["KeyType"] == "RANGE":
+ range_key = key["AttributeName"]
+ return hash_key, range_key
+
+
+class ExpiredItemsWorker:
+ """A worker that periodically computes and deletes expired items from DynamoDB tables"""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.scheduler = Scheduler()
+ self.thread: Optional[FuncThread] = None
+ self.mutex = threading.RLock()
+
+ def start(self):
+ with self.mutex:
+ if self.thread:
+ return
+
+ self.scheduler = Scheduler()
+ self.scheduler.schedule(
+ delete_expired_items, period=60 * 60
+ ) # the background process seems slow on AWS
+
+ def _run(*_args):
+ self.scheduler.run()
+
+ self.thread = start_thread(_run, name="ddb-remove-expired-items")
+
+ def stop(self):
+ with self.mutex:
+ if self.scheduler:
+ self.scheduler.close()
+
+ if self.thread:
+ self.thread.stop()
+
+ self.thread = None
+ self.scheduler = None
+
+
+class DynamoDBProvider(DynamodbApi, ServiceLifecycleHook):
+ server: DynamodbServer
+ """The instance of the server managing the instance of DynamoDB local"""
+
+ def __init__(self):
+ self.server = self._new_dynamodb_server()
+ self._expired_items_worker = ExpiredItemsWorker()
+ self._router_rules = []
+ self._event_forwarder = EventForwarder()
+
+ def on_before_start(self):
+ self.server.start_dynamodb()
+ if config.DYNAMODB_REMOVE_EXPIRED_ITEMS:
+ self._expired_items_worker.start()
+ self._router_rules = ROUTER.add(DynamoDBDeveloperEndpoints())
+
+ def on_before_stop(self):
+ self._expired_items_worker.stop()
+ ROUTER.remove(self._router_rules)
+ self._event_forwarder.shutdown()
+
+ def accept_state_visitor(self, visitor: StateVisitor):
+ visitor.visit(dynamodb_stores)
+ visitor.visit(dynamodbstreams_stores)
+ visitor.visit(AssetDirectory(self.service, os.path.join(config.dirs.data, self.service)))
+
+ def on_before_state_reset(self):
+ self.server.stop_dynamodb()
+
+ def on_before_state_load(self):
+ self.server.stop_dynamodb()
+
+ def on_after_state_reset(self):
+ self.server.start_dynamodb()
+
+ @staticmethod
+ def _new_dynamodb_server() -> DynamodbServer:
+ return DynamodbServer.get()
+
+ def on_after_state_load(self):
+ self.server.start_dynamodb()
+
+ def on_after_init(self):
+ # add response processor specific to ddblocal
+ handlers.modify_service_response.append(self.service, modify_ddblocal_arns)
+
+ # routes for the shell ui
+ ROUTER.add(
+ path="/shell",
+ endpoint=self.handle_shell_ui_redirect,
+ methods=["GET"],
+ )
+ ROUTER.add(
+ path="/shell/",
+ endpoint=self.handle_shell_ui_request,
+ )
+
+ def _forward_request(
+ self,
+ context: RequestContext,
+ region: str | None,
+ service_request: ServiceRequest | None = None,
+ ) -> ServiceResponse:
+ """
+ Modify the context region and then forward request to DynamoDB Local.
+
+ This is used for operations impacted by global tables. In LocalStack, a single copy of each global table
+ is kept, and any requests to replica tables are forwarded to this original table.
+ """
+ if region:
+ with modify_context_region(context, region):
+ return self.forward_request(context, service_request=service_request)
+ return self.forward_request(context, service_request=service_request)
+
+ def forward_request(
+ self, context: RequestContext, service_request: ServiceRequest = None
+ ) -> ServiceResponse:
+ """
+ Forward a request to DynamoDB Local.
+ """
+ self.check_provisioned_throughput(context.operation.name)
+ self.prepare_request_headers(
+ context.request.headers, account_id=context.account_id, region_name=context.region
+ )
+ return self.server.proxy(context, service_request)
+
+ def get_forward_url(self, account_id: str, region_name: str) -> str:
+ """Return the URL of the backend DynamoDBLocal server to forward requests to"""
+ return self.server.url
+
+ def handle_shell_ui_redirect(self, request: werkzeug.Request) -> Response:
+ headers = {"Refresh": f"0; url={config.external_service_url()}/shell/index.html"}
+ return Response("", headers=headers)
+
+ def handle_shell_ui_request(self, request: werkzeug.Request, req_path: str) -> Response:
+ # TODO: "DynamoDB Local Web Shell was deprecated with version 1.16.X and is not available any
+ # longer from 1.17.X to latest. There are no immediate plans for a new Web Shell to be introduced."
+ # -> keeping this for now, to allow configuring custom installs; should consider removing it in the future
+ # https://repost.aws/questions/QUHyIzoEDqQ3iOKlUEp1LPWQ#ANdBm9Nz9TRf6VqR3jZtcA1g
+ req_path = f"/{req_path}" if not req_path.startswith("/") else req_path
+ account_id = extract_account_id_from_headers(request.headers)
+ region_name = extract_region_from_headers(request.headers)
+ url = f"{self.get_forward_url(account_id, region_name)}/shell{req_path}"
+ result = requests.request(
+ method=request.method, url=url, headers=request.headers, data=request.data
+ )
+ return Response(result.content, headers=dict(result.headers), status=result.status_code)
+
+ #
+ # Table ops
+ #
+
+ @handler("CreateTable", expand=False)
+ def create_table(
+ self,
+ context: RequestContext,
+ create_table_input: CreateTableInput,
+ ) -> CreateTableOutput:
+ table_name = create_table_input["TableName"]
+
+ # Return this specific error message to keep parity with AWS
+ if self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceInUseException(f"Table already exists: {table_name}")
+
+ billing_mode = create_table_input.get("BillingMode")
+ provisioned_throughput = create_table_input.get("ProvisionedThroughput")
+ if billing_mode == BillingMode.PAY_PER_REQUEST and provisioned_throughput is not None:
+ raise ValidationException(
+ "One or more parameter values were invalid: Neither ReadCapacityUnits nor WriteCapacityUnits can be "
+ "specified when BillingMode is PAY_PER_REQUEST"
+ )
+
+ result = self.forward_request(context)
+
+ table_description = result["TableDescription"]
+ table_description["TableArn"] = table_arn = self.fix_table_arn(
+ context.account_id, context.region, table_description["TableArn"]
+ )
+
+ backend = get_store(context.account_id, context.region)
+ backend.table_definitions[table_name] = table_definitions = dict(create_table_input)
+ backend.TABLE_REGION[table_name] = context.region
+
+ if "TableId" not in table_definitions:
+ table_definitions["TableId"] = long_uid()
+
+ if "SSESpecification" in table_definitions:
+ sse_specification = table_definitions.pop("SSESpecification")
+ table_definitions["SSEDescription"] = SSEUtils.get_sse_description(
+ context.account_id, context.region, sse_specification
+ )
+
+ if table_definitions:
+ table_content = result.get("Table", {})
+ table_content.update(table_definitions)
+ table_description.update(table_content)
+
+ if "StreamSpecification" in table_definitions:
+ create_dynamodb_stream(
+ context.account_id,
+ context.region,
+ table_definitions,
+ table_description.get("LatestStreamLabel"),
+ )
+
+ if "TableClass" in table_definitions:
+ table_class = table_description.pop("TableClass", None) or table_definitions.pop(
+ "TableClass"
+ )
+ table_description["TableClassSummary"] = {"TableClass": table_class}
+
+ if "GlobalSecondaryIndexes" in table_description:
+ gsis = copy.deepcopy(table_description["GlobalSecondaryIndexes"])
+ # update the different values, as DynamoDB Local v2 has a regression around GSIs and no longer
+ # returns these fields
+ for gsi in gsis:
+ index_name = gsi.get("IndexName", "")
+ gsi.update(
+ {
+ "IndexArn": f"{table_arn}/index/{index_name}",
+ "IndexSizeBytes": 0,
+ "IndexStatus": "ACTIVE",
+ "ItemCount": 0,
+ }
+ )
+ gsi_provisioned_throughput = gsi.setdefault("ProvisionedThroughput", {})
+ gsi_provisioned_throughput["NumberOfDecreasesToday"] = 0
+
+ if billing_mode == BillingMode.PAY_PER_REQUEST:
+ gsi_provisioned_throughput["ReadCapacityUnits"] = 0
+ gsi_provisioned_throughput["WriteCapacityUnits"] = 0
+
+ table_description["GlobalSecondaryIndexes"] = gsis
+
+ if "ProvisionedThroughput" in table_description:
+ if "NumberOfDecreasesToday" not in table_description["ProvisionedThroughput"]:
+ table_description["ProvisionedThroughput"]["NumberOfDecreasesToday"] = 0
+
+ tags = table_definitions.pop("Tags", [])
+ if tags:
+ get_store(context.account_id, context.region).TABLE_TAGS[table_arn] = {
+ tag["Key"]: tag["Value"] for tag in tags
+ }
+
+ # remove invalid attributes from result
+ table_description.pop("Tags", None)
+ table_description.pop("BillingMode", None)
+
+ return result
+
+ def delete_table(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DeleteTableOutput:
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ # Limitation note: On AWS, for a replicated table, if the source table is deleted, the replicated tables continue to exist.
+ # This is not the case for LocalStack, where all replicated tables will also be removed if source is deleted.
+
+ result = self._forward_request(context=context, region=global_table_region)
+
+ table_arn = result.get("TableDescription", {}).get("TableArn")
+ table_arn = self.fix_table_arn(context.account_id, context.region, table_arn)
+ dynamodbstreams_api.delete_streams(context.account_id, context.region, table_arn)
+
+ store = get_store(context.account_id, context.region)
+ store.TABLE_TAGS.pop(table_arn, None)
+ store.REPLICAS.pop(table_name, None)
+
+ return result
+
+ def describe_table(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeTableOutput:
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ result = self._forward_request(context=context, region=global_table_region)
+ table_description: TableDescription = result["Table"]
+
+ # Update table properties from LocalStack stores
+ if table_props := get_store(context.account_id, context.region).table_properties.get(
+ table_name
+ ):
+ table_description.update(table_props)
+
+ store = get_store(context.account_id, context.region)
+
+ # Update replication details
+ replicas: Dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})
+
+ replica_description_list = []
+
+ if global_table_region != context.region:
+ replica_description_list.append(
+ ReplicaDescription(
+ RegionName=global_table_region, ReplicaStatus=ReplicaStatus.ACTIVE
+ )
+ )
+
+ for replica_region, replica_description in replicas.items():
+ # The replica in the region being queried must not be returned
+ if replica_region != context.region:
+ replica_description_list.append(replica_description)
+
+ if replica_description_list:
+ table_description.update({"Replicas": replica_description_list})
+
+ # update only TableId and SSEDescription if present
+ if table_definitions := store.table_definitions.get(table_name):
+ for key in ["TableId", "SSEDescription"]:
+ if table_definitions.get(key):
+ table_description[key] = table_definitions[key]
+ if "TableClass" in table_definitions:
+ table_description["TableClassSummary"] = {
+ "TableClass": table_definitions["TableClass"]
+ }
+
+ if "GlobalSecondaryIndexes" in table_description:
+ for gsi in table_description["GlobalSecondaryIndexes"]:
+ default_values = {
+ "NumberOfDecreasesToday": 0,
+ "ReadCapacityUnits": 0,
+ "WriteCapacityUnits": 0,
+ }
+ # even if the billing mode is PAY_PER_REQUEST, AWS returns the Read and Write Capacity Units
+ # Terraform depends on this parity for update operations
+ gsi["ProvisionedThroughput"] = default_values | gsi.get("ProvisionedThroughput", {})
+
+ return DescribeTableOutput(
+ Table=select_from_typed_dict(TableDescription, table_description)
+ )
+
+ @handler("UpdateTable", expand=False)
+ def update_table(
+ self, context: RequestContext, update_table_input: UpdateTableInput
+ ) -> UpdateTableOutput:
+ table_name = update_table_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ try:
+ result = self._forward_request(context=context, region=global_table_region)
+ except CommonServiceException as exc:
+ # DynamoDBLocal refuses to update certain table params and raises.
+ # But we still need to update this info in LocalStack stores
+ if not (exc.code == "ValidationException" and exc.message == "Nothing to update"):
+ raise
+
+ if table_class := update_table_input.get("TableClass"):
+ table_definitions = get_store(
+ context.account_id, context.region
+ ).table_definitions.setdefault(table_name, {})
+ table_definitions["TableClass"] = table_class
+
+ if replica_updates := update_table_input.get("ReplicaUpdates"):
+ store = get_store(context.account_id, global_table_region)
+
+ # Dict mapping each replica region to its ReplicaDescription
+ replicas: Dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})
+
+ for replica_update in replica_updates:
+ for key, details in replica_update.items():
+ # Replicated region
+ target_region = details.get("RegionName")
+
+ # Check if replicated region is valid
+ if target_region not in get_valid_regions_for_service("dynamodb"):
+ raise ValidationException(f"Region {target_region} is not supported")
+
+ match key:
+ case "Create":
+ if target_region in replicas:
+ raise ValidationException(
+ f"Failed to create a the new replica of table with name: '{table_name}' because one or more replicas already existed as tables."
+ )
+ replicas[target_region] = ReplicaDescription(
+ RegionName=target_region,
+ KMSMasterKeyId=details.get("KMSMasterKeyId"),
+ ProvisionedThroughputOverride=details.get(
+ "ProvisionedThroughputOverride"
+ ),
+ GlobalSecondaryIndexes=details.get("GlobalSecondaryIndexes"),
+ ReplicaStatus=ReplicaStatus.ACTIVE,
+ )
+ case "Delete":
+ try:
+ replicas.pop(target_region)
+ except KeyError:
+ raise ValidationException(
+ "Update global table operation failed because one or more replicas were not part of the global table."
+ )
+
+ store.REPLICAS[table_name] = replicas
+
+ # update response content
+ SchemaExtractor.invalidate_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ schema = SchemaExtractor.get_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ if sse_specification_input := update_table_input.get("SSESpecification"):
+ # If SSESpecification is changed, update store and return the 'UPDATING' status in the response
+ table_definition = get_store(
+ context.account_id, context.region
+ ).table_definitions.setdefault(table_name, {})
+ if not sse_specification_input["Enabled"]:
+ table_definition.pop("SSEDescription", None)
+ schema["Table"]["SSEDescription"]["Status"] = "UPDATING"
+
+ return UpdateTableOutput(TableDescription=schema["Table"])
+
+ SchemaExtractor.invalidate_table_schema(table_name, context.account_id, global_table_region)
+
+ schema = SchemaExtractor.get_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ # TODO: DDB streams must also be created for replicas
+ if update_table_input.get("StreamSpecification"):
+ create_dynamodb_stream(
+ context.account_id,
+ context.region,
+ update_table_input,
+ result["TableDescription"].get("LatestStreamLabel"),
+ )
+
+ return UpdateTableOutput(TableDescription=schema["Table"])
+
+ def list_tables(
+ self,
+ context: RequestContext,
+ exclusive_start_table_name: TableName = None,
+ limit: ListTablesInputLimit = None,
+ **kwargs,
+ ) -> ListTablesOutput:
+ response = self.forward_request(context)
+
+ # Add replicated tables
+ replicas = get_store(context.account_id, context.region).REPLICAS
+ for replicated_table, replications in replicas.items():
+ for replica_region, replica_description in replications.items():
+ if context.region == replica_region:
+ response["TableNames"].append(replicated_table)
+
+ return response
+
+ #
+ # Item ops
+ #
+
+ @handler("PutItem", expand=False)
+ def put_item(self, context: RequestContext, put_item_input: PutItemInput) -> PutItemOutput:
+ table_name = put_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ has_return_values = put_item_input.get("ReturnValues") == "ALL_OLD"
+ stream_type = get_table_stream_type(context.account_id, context.region, table_name)
+
+ # if the request doesn't ask for ReturnValues and streams are enabled, we need to modify the request to
+ # force DDBLocal to return those values
+ if stream_type and not has_return_values:
+ service_req = copy.copy(context.service_request)
+ service_req["ReturnValues"] = "ALL_OLD"
+ result = self._forward_request(
+ context=context, region=global_table_region, service_request=service_req
+ )
+ else:
+ result = self._forward_request(context=context, region=global_table_region)
+
+ # Since this operation makes use of the global table region, we need to use the same region for all
+ # calls made via the inter-service client. This is taken care of by passing the account ID and
+ # region, e.g. when getting the stream spec
+
+ # Get stream specifications details for the table
+ if stream_type:
+ item = put_item_input["Item"]
+ # prepare record keys
+ keys = SchemaExtractor.extract_keys(
+ item=item,
+ table_name=table_name,
+ account_id=context.account_id,
+ region_name=global_table_region,
+ )
+ # because we modified the request, we will always have the ReturnValues if we have streams enabled
+ if has_return_values:
+ existing_item = result.get("Attributes")
+ else:
+ # remove the ReturnValues if the client didn't ask for it
+ existing_item = result.pop("Attributes", None)
+
+ if existing_item == item:
+ return result
+
+ # create record
+ record = self.get_record_template(
+ context.region,
+ )
+ record["eventName"] = "INSERT" if not existing_item else "MODIFY"
+ record["dynamodb"]["Keys"] = keys
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(item)
+
+ if stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = item
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+ if existing_item and stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+
+ records_map = {
+ table_name: TableRecords(records=[record], table_stream_type=stream_type)
+ }
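+ # Illustrative shape (hypothetical values, not part of the original change) of a record in records_map,
+ # matching the StreamRecord / DynamoDbStreamRecord TypedDicts from models.py:
+ #   {"eventName": "INSERT", "awsRegion": "us-east-1", "eventSource": "aws:dynamodb",
+ #    "dynamodb": {"Keys": {"pk": {"S": "id-1"}}, "NewImage": {...}, "SizeBytes": 42,
+ #                 "StreamViewType": "NEW_AND_OLD_IMAGES"}}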
+ self.forward_stream_records(context.account_id, context.region, records_map)
+ return result
+
+ @handler("DeleteItem", expand=False)
+ def delete_item(
+ self,
+ context: RequestContext,
+ delete_item_input: DeleteItemInput,
+ ) -> DeleteItemOutput:
+ table_name = delete_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ has_return_values = delete_item_input.get("ReturnValues") == "ALL_OLD"
+ stream_type = get_table_stream_type(context.account_id, context.region, table_name)
+
+ # if the request doesn't ask for ReturnValues and streams are enabled, we need to modify the request to
+ # force DDBLocal to return those values
+ if stream_type and not has_return_values:
+ service_req = copy.copy(context.service_request)
+ service_req["ReturnValues"] = "ALL_OLD"
+ result = self._forward_request(
+ context=context, region=global_table_region, service_request=service_req
+ )
+ else:
+ result = self._forward_request(context=context, region=global_table_region)
+
+ # determine and forward stream record
+ if stream_type:
+ # because we modified the request, we will always have the ReturnValues if we have streams enabled
+ if has_return_values:
+ existing_item = result.get("Attributes")
+ else:
+ # remove the ReturnValues if the client didn't ask for it
+ existing_item = result.pop("Attributes", None)
+
+ if not existing_item:
+ return result
+
+ # create record
+ record = self.get_record_template(context.region)
+ record["eventName"] = "REMOVE"
+ record["dynamodb"]["Keys"] = delete_item_input["Key"]
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(existing_item)
+
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+ if stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+
+ records_map = {
+ table_name: TableRecords(records=[record], table_stream_type=stream_type)
+ }
+ self.forward_stream_records(context.account_id, context.region, records_map)
+
+ return result
+
+ @handler("UpdateItem", expand=False)
+ def update_item(
+ self,
+ context: RequestContext,
+ update_item_input: UpdateItemInput,
+ ) -> UpdateItemOutput:
+ # TODO: UpdateItem is harder to handle via ReturnValues for Streams, because it needs both the before and after images.
+ table_name = update_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ existing_item = None
+ stream_type = get_table_stream_type(context.account_id, context.region, table_name)
+
+ # even if we don't need the OldImage, we still need to fetch the existing item to know if the event is INSERT
+ # or MODIFY (UpdateItem creates the item if it doesn't exist and no ConditionExpression is used)
+ if stream_type:
+ existing_item = ItemFinder.find_existing_item(
+ put_item=update_item_input,
+ table_name=table_name,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+
+ result = self._forward_request(context=context, region=global_table_region)
+
+ # construct and forward stream record
+ if stream_type:
+ updated_item = ItemFinder.find_existing_item(
+ put_item=update_item_input,
+ table_name=table_name,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+ if not updated_item or updated_item == existing_item:
+ return result
+
+ record = self.get_record_template(context.region)
+ record["eventName"] = "INSERT" if not existing_item else "MODIFY"
+ record["dynamodb"]["Keys"] = update_item_input["Key"]
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(updated_item)
+
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+ if existing_item and stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+ if stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = updated_item
+
+ records_map = {
+ table_name: TableRecords(records=[record], table_stream_type=stream_type)
+ }
+ self.forward_stream_records(context.account_id, context.region, records_map)
+
+ return result
+
+ @handler("GetItem", expand=False)
+ def get_item(self, context: RequestContext, get_item_input: GetItemInput) -> GetItemOutput:
+ table_name = get_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ self.fix_consumed_capacity(get_item_input, result)
+ return result
+
+ #
+ # Queries
+ #
+
+ @handler("Query", expand=False)
+ def query(self, context: RequestContext, query_input: QueryInput) -> QueryOutput:
+ index_name = query_input.get("IndexName")
+ if index_name:
+ if not is_index_query_valid(context.account_id, context.region, query_input):
+ raise ValidationException(
+ "One or more parameter values were invalid: Select type ALL_ATTRIBUTES "
+ "is not supported for global secondary index id-index because its projection "
+ "type is not ALL",
+ )
+
+ table_name = query_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ self.fix_consumed_capacity(query_input, result)
+ return result
+
+ @handler("Scan", expand=False)
+ def scan(self, context: RequestContext, scan_input: ScanInput) -> ScanOutput:
+ table_name = scan_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ return result
+
+ #
+ # Batch ops
+ #
+
+ @handler("BatchWriteItem", expand=False)
+ def batch_write_item(
+ self,
+ context: RequestContext,
+ batch_write_item_input: BatchWriteItemInput,
+ ) -> BatchWriteItemOutput:
+ # TODO: add global table support
+ existing_items = {}
+ existing_items_to_fetch: BatchWriteItemRequestMap = {}
+ # UnprocessedItems should have the same format as RequestItems
+ unprocessed_items = {}
+ request_items = batch_write_item_input["RequestItems"]
+
+ tables_stream_type: dict[TableName, TableStreamType] = {}
+
+ for table_name, items in sorted(request_items.items(), key=itemgetter(0)):
+ if stream_type := get_table_stream_type(context.account_id, context.region, table_name):
+ tables_stream_type[table_name] = stream_type
+
+ for request in items:
+ request: WriteRequest
+ for key, inner_request in request.items():
+ inner_request: PutRequest | DeleteRequest
+ if self.should_throttle("BatchWriteItem"):
+ unprocessed_items_for_table = unprocessed_items.setdefault(table_name, [])
+ unprocessed_items_for_table.append(request)
+
+ elif stream_type:
+ existing_items_to_fetch_for_table = existing_items_to_fetch.setdefault(
+ table_name, []
+ )
+ existing_items_to_fetch_for_table.append(inner_request)
+
+ if existing_items_to_fetch:
+ existing_items = ItemFinder.find_existing_items(
+ put_items_per_table=existing_items_to_fetch,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+
+ try:
+ result = self.forward_request(context)
+ except CommonServiceException as e:
+ # TODO: validate if DynamoDB still raises `One of the required keys was not given a value`
+ # for now, replace it with the schema validation error
+ if e.message == "One of the required keys was not given a value":
+ raise ValidationException("The provided key element does not match the schema")
+ raise e
+
+ # determine and forward stream records
+ if tables_stream_type:
+ records_map = self.prepare_batch_write_item_records(
+ account_id=context.account_id,
+ region_name=context.region,
+ tables_stream_type=tables_stream_type,
+ request_items=request_items,
+ existing_items=existing_items,
+ )
+ self.forward_stream_records(context.account_id, context.region, records_map)
+
+ # TODO: should unprocessed items which have been mutated by `prepare_batch_write_item_records` be returned?
+ for table_name, unprocessed_items_in_table in unprocessed_items.items():
+ unprocessed: dict = result["UnprocessedItems"]
+ result_unprocessed_table = unprocessed.setdefault(table_name, [])
+
+ # add the Unprocessed items to the response
+ # TODO: check beforehand whether the same request was already returned as Unprocessed by DDB Local?
+ # those might actually have been processed; shouldn't we remove them from the proxied request?
+ for request in unprocessed_items_in_table:
+ result_unprocessed_table.append(request)
+
+ # remove any table entry if it's empty
+ result["UnprocessedItems"] = {k: v for k, v in unprocessed.items() if v}
+
+ return result
+
+ @handler("BatchGetItem")
+ def batch_get_item(
+ self,
+ context: RequestContext,
+ request_items: BatchGetRequestMap,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ **kwargs,
+ ) -> BatchGetItemOutput:
+ # TODO: add global table support
+ return self.forward_request(context)
+
+ #
+ # Transactions
+ #
+
+ @handler("TransactWriteItems", expand=False)
+ def transact_write_items(
+ self,
+ context: RequestContext,
+ transact_write_items_input: TransactWriteItemsInput,
+ ) -> TransactWriteItemsOutput:
+ # TODO: add global table support
+ existing_items = {}
+ existing_items_to_fetch: dict[str, list[Put | Update | Delete]] = {}
+ updated_items_to_fetch: dict[str, list[Update]] = {}
+ transact_items = transact_write_items_input["TransactItems"]
+ tables_stream_type: dict[TableName, TableStreamType] = {}
+ no_stream_tables = set()
+
+ for item in transact_items:
+ item: TransactWriteItem
+ for key in ["Put", "Update", "Delete"]:
+ inner_item: Put | Delete | Update = item.get(key)
+ if inner_item:
+ table_name = inner_item["TableName"]
+ # if we've seen the table already and it does not have streams, skip
+ if table_name in no_stream_tables:
+ continue
+
+ # if we have not seen the table, fetch its streaming status
+ if table_name not in tables_stream_type:
+ if stream_type := get_table_stream_type(
+ context.account_id, context.region, table_name
+ ):
+ tables_stream_type[table_name] = stream_type
+ else:
+ # no stream for this table; remember that and skip it next time
+ no_stream_tables.add(table_name)
+ continue
+
+ existing_items_to_fetch_for_table = existing_items_to_fetch.setdefault(
+ table_name, []
+ )
+ existing_items_to_fetch_for_table.append(inner_item)
+ if key == "Update":
+ updated_items_to_fetch_for_table = updated_items_to_fetch.setdefault(
+ table_name, []
+ )
+ updated_items_to_fetch_for_table.append(inner_item)
+
+ continue
+
+ if existing_items_to_fetch:
+ existing_items = ItemFinder.find_existing_items(
+ put_items_per_table=existing_items_to_fetch,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+
+ client_token: str | None = transact_write_items_input.get("ClientRequestToken")
+
+ if client_token:
+ # we canonicalize (sort) the payload, because identical payloads with a different key order would
+ # otherwise cause an IdempotentParameterMismatchException when a client token is provided
+ context.request.data = to_bytes(canonical_json(json.loads(context.request.data)))
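+ # e.g. (illustrative) the bodies '{"ClientRequestToken": "t", "TransactItems": [...]}' and
+ # '{"TransactItems": [...], "ClientRequestToken": "t"}' canonicalize to the same byte string,
+ # so DynamoDB Local sees identical requests for idempotency purposes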
+
+ result = self.forward_request(context)
+
+ # determine and forward stream records
+ if tables_stream_type:
+ updated_items = (
+ ItemFinder.find_existing_items(
+ put_items_per_table=existing_items_to_fetch,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+ if updated_items_to_fetch
+ else {}
+ )
+
+ records_map = self.prepare_transact_write_item_records(
+ account_id=context.account_id,
+ region_name=context.region,
+ transact_items=transact_items,
+ existing_items=existing_items,
+ updated_items=updated_items,
+ tables_stream_type=tables_stream_type,
+ )
+ self.forward_stream_records(context.account_id, context.region, records_map)
+
+ return result
+
+ @handler("TransactGetItems", expand=False)
+ def transact_get_items(
+ self,
+ context: RequestContext,
+ transact_items: TransactGetItemList,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ ) -> TransactGetItemsOutput:
+ return self.forward_request(context)
+
+ @handler("ExecuteTransaction", expand=False)
+ def execute_transaction(
+ self, context: RequestContext, execute_transaction_input: ExecuteTransactionInput
+ ) -> ExecuteTransactionOutput:
+ result = self.forward_request(context)
+ return result
+
+ @handler("ExecuteStatement", expand=False)
+ def execute_statement(
+ self,
+ context: RequestContext,
+ execute_statement_input: ExecuteStatementInput,
+ ) -> ExecuteStatementOutput:
+ # TODO: this operation is still really slow with streams enabled
+ # find a way to make it faster, as with the other operations, by using ReturnValues
+ # see https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ql-reference.update.html
+ statement = execute_statement_input["Statement"]
+ # We found out that 'Parameters' can be an empty list when the request comes from the AWS JS client.
+ if execute_statement_input.get("Parameters", None) == []: # noqa
+ raise ValidationException(
+ "1 validation error detected: Value '[]' at 'parameters' failed to satisfy constraint: Member must have length greater than or equal to 1"
+ )
+ table_name = extract_table_name_from_partiql_update(statement)
+ existing_items = None
+ stream_type = table_name and get_table_stream_type(
+ context.account_id, context.region, table_name
+ )
+ if stream_type:
+ # Note: fetching the entire list of items is hugely inefficient, especially for larger tables
+ # TODO: find a mechanism to hook into the PartiQL update mechanism of DynamoDB Local directly!
+ existing_items = ItemFinder.list_existing_items_for_statement(
+ partiql_statement=statement,
+ account_id=context.account_id,
+ region_name=context.region,
+ endpoint_url=self.server.url,
+ )
+
+ result = self.forward_request(context)
+
+ # construct and forward stream record
+ if stream_type:
+ records = get_updated_records(
+ account_id=context.account_id,
+ region_name=context.region,
+ table_name=table_name,
+ existing_items=existing_items,
+ server_url=self.server.url,
+ table_stream_type=stream_type,
+ )
+ self.forward_stream_records(context.account_id, context.region, records)
+
+ return result
+
+ #
+ # Tags
+ #
+
+ def tag_resource(
+ self, context: RequestContext, resource_arn: ResourceArnString, tags: TagList, **kwargs
+ ) -> None:
+ table_tags = get_store(context.account_id, context.region).TABLE_TAGS
+ if resource_arn not in table_tags:
+ table_tags[resource_arn] = {}
+ table_tags[resource_arn].update({tag["Key"]: tag["Value"] for tag in tags})
+
+ def untag_resource(
+ self,
+ context: RequestContext,
+ resource_arn: ResourceArnString,
+ tag_keys: TagKeyList,
+ **kwargs,
+ ) -> None:
+ for tag_key in tag_keys or []:
+ get_store(context.account_id, context.region).TABLE_TAGS.get(resource_arn, {}).pop(
+ tag_key, None
+ )
+
+ def list_tags_of_resource(
+ self,
+ context: RequestContext,
+ resource_arn: ResourceArnString,
+ next_token: NextTokenString = None,
+ **kwargs,
+ ) -> ListTagsOfResourceOutput:
+ result = [
+ {"Key": k, "Value": v}
+ for k, v in get_store(context.account_id, context.region)
+ .TABLE_TAGS.get(resource_arn, {})
+ .items()
+ ]
+ return ListTagsOfResourceOutput(Tags=result)
+
+ #
+ # TTLs
+ #
+
+ def describe_time_to_live(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeTimeToLiveOutput:
+ if not self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceNotFoundException(
+ f"Requested resource not found: Table: {table_name} not found"
+ )
+
+ backend = get_store(context.account_id, context.region)
+ ttl_spec = backend.ttl_specifications.get(table_name)
+
+ result = {"TimeToLiveStatus": "DISABLED"}
+ if ttl_spec:
+ if ttl_spec.get("Enabled"):
+ ttl_status = "ENABLED"
+ else:
+ ttl_status = "DISABLED"
+ result = {
+ "AttributeName": ttl_spec.get("AttributeName"),
+ "TimeToLiveStatus": ttl_status,
+ }
+
+ return DescribeTimeToLiveOutput(TimeToLiveDescription=result)
+
+ def update_time_to_live(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ time_to_live_specification: TimeToLiveSpecification,
+ **kwargs,
+ ) -> UpdateTimeToLiveOutput:
+ if not self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceNotFoundException(
+ f"Requested resource not found: Table: {table_name} not found"
+ )
+
+ # TODO: TTL status is maintained/mocked but no real expiry is happening for items
+ backend = get_store(context.account_id, context.region)
+ backend.ttl_specifications[table_name] = time_to_live_specification
+ return UpdateTimeToLiveOutput(TimeToLiveSpecification=time_to_live_specification)
+
+ #
+ # Global tables
+ #
+
+ def create_global_table(
+ self,
+ context: RequestContext,
+ global_table_name: TableName,
+ replication_group: ReplicaList,
+ **kwargs,
+ ) -> CreateGlobalTableOutput:
+ global_tables: Dict = get_store(context.account_id, context.region).GLOBAL_TABLES
+ if global_table_name in global_tables:
+ raise GlobalTableAlreadyExistsException("Global table with this name already exists")
+ replication_group = [grp.copy() for grp in replication_group or []]
+ data = {"GlobalTableName": global_table_name, "ReplicationGroup": replication_group}
+ global_tables[global_table_name] = data
+ for group in replication_group:
+ group["ReplicaStatus"] = "ACTIVE"
+ group["ReplicaStatusDescription"] = "Replica active"
+ return CreateGlobalTableOutput(GlobalTableDescription=data)
+
+ def describe_global_table(
+ self, context: RequestContext, global_table_name: TableName, **kwargs
+ ) -> DescribeGlobalTableOutput:
+ details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
+ if not details:
+ raise GlobalTableNotFoundException("Global table with this name does not exist")
+ return DescribeGlobalTableOutput(GlobalTableDescription=details)
+
+ def list_global_tables(
+ self,
+ context: RequestContext,
+ exclusive_start_global_table_name: TableName = None,
+ limit: PositiveIntegerObject = None,
+ region_name: RegionName = None,
+ **kwargs,
+ ) -> ListGlobalTablesOutput:
+ # TODO: add paging support
+ result = [
+ select_attributes(tab, ["GlobalTableName", "ReplicationGroup"])
+ for tab in get_store(context.account_id, context.region).GLOBAL_TABLES.values()
+ ]
+ return ListGlobalTablesOutput(GlobalTables=result)
+
+ def update_global_table(
+ self,
+ context: RequestContext,
+ global_table_name: TableName,
+ replica_updates: ReplicaUpdateList,
+ **kwargs,
+ ) -> UpdateGlobalTableOutput:
+ details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
+ if not details:
+ raise GlobalTableNotFoundException("Global table with this name does not exist")
+ for update in replica_updates or []:
+ repl_group = details["ReplicationGroup"]
+ # delete existing
+ delete = update.get("Delete")
+ if delete:
+ details["ReplicationGroup"] = [
+ g for g in repl_group if g["RegionName"] != delete["RegionName"]
+ ]
+ # create new
+ create = update.get("Create")
+ if create:
+ exists = [g for g in repl_group if g["RegionName"] == create["RegionName"]]
+ if exists:
+ continue
+ new_group = {
+ "RegionName": create["RegionName"],
+ "ReplicaStatus": "ACTIVE",
+ "ReplicaStatusDescription": "Replica active",
+ }
+ details["ReplicationGroup"].append(new_group)
+ return UpdateGlobalTableOutput(GlobalTableDescription=details)
+
+ #
+ # Kinesis Streaming
+ #
+
+ def enable_kinesis_streaming_destination(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ stream_arn: StreamArn,
+ enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
+ **kwargs,
+ ) -> KinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+
+ stream = self._event_forwarder.is_kinesis_stream_exists(stream_arn=stream_arn)
+ if not stream:
+ raise ValidationException("User does not have a permission to use kinesis stream")
+
+ table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
+ table_name, {}
+ )
+
+ dest_status = table_def.get("KinesisDataStreamDestinationStatus")
+ if dest_status not in ["DISABLED", "ENABLE_FAILED", None]:
+ raise ValidationException(
+ "Table is not in a valid state to enable Kinesis Streaming "
+ "Destination:EnableKinesisStreamingDestination must be DISABLED or ENABLE_FAILED "
+ "to perform ENABLE operation."
+ )
+
+ table_def["KinesisDataStreamDestinations"] = (
+ table_def.get("KinesisDataStreamDestinations") or []
+ )
+ # remove the stream destination if already present
+ table_def["KinesisDataStreamDestinations"] = [
+ t for t in table_def["KinesisDataStreamDestinations"] if t["StreamArn"] != stream_arn
+ ]
+ # append the active stream destination at the end of the list
+ table_def["KinesisDataStreamDestinations"].append(
+ {
+ "DestinationStatus": DestinationStatus.ACTIVE,
+ "DestinationStatusDescription": "Stream is active",
+ "StreamArn": stream_arn,
+ }
+ )
+ table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.ACTIVE
+ return KinesisStreamingDestinationOutput(
+ DestinationStatus=DestinationStatus.ACTIVE, StreamArn=stream_arn, TableName=table_name
+ )
+
+ def disable_kinesis_streaming_destination(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ stream_arn: StreamArn,
+ enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
+ **kwargs,
+ ) -> KinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+
+ stream = self._event_forwarder.is_kinesis_stream_exists(stream_arn=stream_arn)
+ if not stream:
+ raise ValidationException(
+ "User does not have a permission to use kinesis stream",
+ )
+
+ table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
+ table_name, {}
+ )
+
+ stream_destinations = table_def.get("KinesisDataStreamDestinations")
+ if stream_destinations:
+ if table_def["KinesisDataStreamDestinationStatus"] == DestinationStatus.ACTIVE:
+ for dest in stream_destinations:
+ if (
+ dest["StreamArn"] == stream_arn
+ and dest["DestinationStatus"] == DestinationStatus.ACTIVE
+ ):
+ dest["DestinationStatus"] = DestinationStatus.DISABLED
+ dest["DestinationStatusDescription"] = ("Stream is disabled",)
+ table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.DISABLED
+ return KinesisStreamingDestinationOutput(
+ DestinationStatus=DestinationStatus.DISABLED,
+ StreamArn=stream_arn,
+ TableName=table_name,
+ )
+ raise ValidationException(
+ "Table is not in a valid state to disable Kinesis Streaming Destination:"
+ "DisableKinesisStreamingDestination must be ACTIVE to perform DISABLE operation."
+ )
+
+ def describe_kinesis_streaming_destination(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeKinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+
+ table_def = (
+ get_store(context.account_id, context.region).table_definitions.get(table_name) or {}
+ )
+
+ stream_destinations = table_def.get("KinesisDataStreamDestinations") or []
+ return DescribeKinesisStreamingDestinationOutput(
+ KinesisDataStreamDestinations=stream_destinations, TableName=table_name
+ )
+
+ #
+ # Continuous Backups
+ #
+
+ def describe_continuous_backups(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeContinuousBackupsOutput:
+ self.get_global_table_region(context, table_name)
+ store = get_store(context.account_id, context.region)
+ continuous_backup_description = (
+ store.table_properties.get(table_name, {}).get("ContinuousBackupsDescription")
+ ) or ContinuousBackupsDescription(
+ ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
+ PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
+ PointInTimeRecoveryStatus=PointInTimeRecoveryStatus.DISABLED
+ ),
+ )
+
+ return DescribeContinuousBackupsOutput(
+ ContinuousBackupsDescription=continuous_backup_description
+ )
+
+ def update_continuous_backups(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ point_in_time_recovery_specification: PointInTimeRecoverySpecification,
+ **kwargs,
+ ) -> UpdateContinuousBackupsOutput:
+ self.get_global_table_region(context, table_name)
+
+ store = get_store(context.account_id, context.region)
+ pit_recovery_status = (
+ PointInTimeRecoveryStatus.ENABLED
+ if point_in_time_recovery_specification["PointInTimeRecoveryEnabled"]
+ else PointInTimeRecoveryStatus.DISABLED
+ )
+ continuous_backup_description = ContinuousBackupsDescription(
+ ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
+ PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
+ PointInTimeRecoveryStatus=pit_recovery_status
+ ),
+ )
+ table_props = store.table_properties.setdefault(table_name, {})
+ table_props["ContinuousBackupsDescription"] = continuous_backup_description
+
+ return UpdateContinuousBackupsOutput(
+ ContinuousBackupsDescription=continuous_backup_description
+ )
+
+ #
+ # Helpers
+ #
+
+ @staticmethod
+ def ddb_region_name(region_name: str) -> str:
+ """Map `local` or `localhost` region to the us-east-1 region. These values are used by NoSQL Workbench."""
+ # TODO: could this be somehow moved into the request handler chain?
+ if region_name in ("local", "localhost"):
+ region_name = AWS_REGION_US_EAST_1
+
+ return region_name
+
+ @staticmethod
+ def table_exists(account_id: str, region_name: str, table_name: str) -> bool:
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+
+ client = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).dynamodb
+ return dynamodb_table_exists(table_name, client)
+
+ @staticmethod
+ def ensure_table_exists(account_id: str, region_name: str, table_name: str):
+ """
+ Raise ResourceNotFoundException if the given table does not exist.
+
+ :param account_id: account id
+ :param region_name: region name
+ :param table_name: table name
+ :raise: ResourceNotFoundException if table does not exist in DynamoDB Local
+ """
+ if not DynamoDBProvider.table_exists(account_id, region_name, table_name):
+ raise ResourceNotFoundException("Cannot do operations on a non-existent table")
+
+ @staticmethod
+ def get_global_table_region(context: RequestContext, table_name: str) -> str:
+ """
+ Return the table region considering that it might be a replicated table.
+
+ Replication in LocalStack works by keeping a single copy of a table and forwarding
+ requests to the region where this table exists.
+
+ This method does not check whether the table actually exists in DDBLocal.
+
+ :param context: request context
+ :param table_name: table name
+ :return: region
+ """
+ store = get_store(context.account_id, context.region)
+
+ table_region = store.TABLE_REGION.get(table_name)
+ replicated_at = store.REPLICAS.get(table_name, {}).keys()
+
+ if context.region == table_region or context.region in replicated_at:
+ return table_region
+
+ return context.region
+
+ @staticmethod
+ def prepare_request_headers(headers: Dict, account_id: str, region_name: str):
+ """
+ Modify the Credentials field of Authorization header to achieve namespacing in DynamoDBLocal.
+ """
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+ key = get_ddb_access_key(account_id, region_name)
+
+ # DynamoDBLocal namespaces based on the value of Credentials
+ # Since we want to namespace by both account ID and region, use an aggregate key
+ # We also replace the region to keep compatibility with NoSQL Workbench
+ headers["Authorization"] = re.sub(
+ AUTH_CREDENTIAL_REGEX,
+ rf"Credential={key}/\2/{region_name}/\4/",
+ headers.get("Authorization") or "",
+ flags=re.IGNORECASE,
+ )
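+ # Illustrative example (hypothetical values): in a header like
+ # "AWS4-HMAC-SHA256 Credential=AKIAEXAMPLE/20240101/eu-west-2/dynamodb/aws4_request, ...",
+ # the access key is replaced with the account/region-scoped DDB key and the region segment with
+ # the (possibly re-mapped) region, so DynamoDBLocal namespaces the request per account and region.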
+
+ def fix_consumed_capacity(self, request: Dict, result: Dict):
+ # make sure we append 'ConsumedCapacity', which is properly
+ # returned by dynalite, but not by AWS's DynamoDBLocal
+ table_name = request.get("TableName")
+ return_cap = request.get("ReturnConsumedCapacity")
+ if "ConsumedCapacity" not in result and return_cap in ["TOTAL", "INDEXES"]:
+ request["ConsumedCapacity"] = {
+ "TableName": table_name,
+ "CapacityUnits": 5, # TODO hardcoded
+ "ReadCapacityUnits": 2,
+ "WriteCapacityUnits": 3,
+ }
+
+ def fix_table_arn(self, account_id: str, region_name: str, arn: str) -> str:
+ """
+ Set the correct account ID and region in ARNs returned by DynamoDB Local.
+ """
+ partition = get_partition(region_name)
+ return (
+ arn.replace("arn:aws:", f"arn:{partition}:")
+ .replace(":ddblocal:", f":{region_name}:")
+ .replace(":000000000000:", f":{account_id}:")
+ )
+
+ def prepare_transact_write_item_records(
+ self,
+ account_id: str,
+ region_name: str,
+ transact_items: TransactWriteItemList,
+ existing_items: BatchGetResponseMap,
+ updated_items: BatchGetResponseMap,
+ tables_stream_type: dict[TableName, TableStreamType],
+ ) -> RecordsMap:
+ records_only_map: dict[TableName, StreamRecords] = defaultdict(list)
+
+ for request in transact_items:
+ record = self.get_record_template(region_name)
+ match request:
+ case {"Put": {"TableName": table_name, "Item": new_item}}:
+ if not (stream_type := tables_stream_type.get(table_name)):
+ continue
+ keys = SchemaExtractor.extract_keys(
+ item=new_item,
+ table_name=table_name,
+ account_id=account_id,
+ region_name=region_name,
+ )
+ existing_item = find_item_for_keys_values_in_batch(
+ table_name, keys, existing_items
+ )
+ if existing_item == new_item:
+ continue
+
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+
+ record["eventID"] = short_uid()
+ record["eventName"] = "INSERT" if not existing_item else "MODIFY"
+ record["dynamodb"]["Keys"] = keys
+ if stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = new_item
+ if existing_item and stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+
+ record_item = de_dynamize_record(new_item)
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(record_item)
+ records_only_map[table_name].append(record)
+ continue
+
+ case {"Update": {"TableName": table_name, "Key": keys}}:
+ if not (stream_type := tables_stream_type.get(table_name)):
+ continue
+ updated_item = find_item_for_keys_values_in_batch(
+ table_name, keys, updated_items
+ )
+ if not updated_item:
+ continue
+
+ existing_item = find_item_for_keys_values_in_batch(
+ table_name, keys, existing_items
+ )
+ if existing_item == updated_item:
+ # if the item is the same as the previous version, AWS does not send an event
+ continue
+
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+
+ record["eventID"] = short_uid()
+ record["eventName"] = "MODIFY" if existing_item else "INSERT"
+ record["dynamodb"]["Keys"] = keys
+
+ if existing_item and stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+ if stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = updated_item
+
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(updated_item)
+ records_only_map[table_name].append(record)
+ continue
+
+ case {"Delete": {"TableName": table_name, "Key": keys}}:
+ if not (stream_type := tables_stream_type.get(table_name)):
+ continue
+
+ existing_item = find_item_for_keys_values_in_batch(
+ table_name, keys, existing_items
+ )
+ if not existing_item:
+ continue
+
+ if stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
+
+ record["eventID"] = short_uid()
+ record["eventName"] = "REMOVE"
+ record["dynamodb"]["Keys"] = keys
+ if stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+ record_item = de_dynamize_record(existing_item)
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(record_item)
+
+ records_only_map[table_name].append(record)
+ continue
+
+ records_map = {
+ table_name: TableRecords(
+ records=records, table_stream_type=tables_stream_type[table_name]
+ )
+ for table_name, records in records_only_map.items()
+ }
+
+ return records_map
+
+ def batch_execute_statement(
+ self,
+ context: RequestContext,
+ statements: PartiQLBatchRequest,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ **kwargs,
+ ) -> BatchExecuteStatementOutput:
+ result = self.forward_request(context)
+ return result
+
+ def prepare_batch_write_item_records(
+ self,
+ account_id: str,
+ region_name: str,
+ tables_stream_type: dict[TableName, TableStreamType],
+ request_items: BatchWriteItemRequestMap,
+ existing_items: BatchGetResponseMap,
+ ) -> RecordsMap:
+ records_map: RecordsMap = {}
+
+ # only iterate over tables with streams
+ for table_name, stream_type in tables_stream_type.items():
+ existing_items_for_table_unordered = existing_items.get(table_name, [])
+ table_records: StreamRecords = []
+
+ def find_existing_item_for_keys_values(item_keys: dict) -> AttributeMap | None:
+ """
+ Look up the provided item key subset in the existing items for this table and, if present,
+ return the full item.
+ :param item_keys: the request item keys
+ :return: the matching DynamoDB item (AttributeMap), or None if not found
+ """
+ keys_items = item_keys.items()
+ for item in existing_items_for_table_unordered:
+ if keys_items <= item.items():
+ return item
+
+ for write_request in request_items[table_name]:
+ record = self.get_record_template(
+ region_name,
+ stream_view_type=stream_type.stream_view_type,
+ )
+ match write_request:
+ case {"PutRequest": request}:
+ keys = SchemaExtractor.extract_keys(
+ item=request["Item"],
+ table_name=table_name,
+ account_id=account_id,
+ region_name=region_name,
+ )
+ # we need to know whether an existing item was present even if we don't need it for `OldImage`,
+ # because it determines the `eventName`
+ existing_item = find_existing_item_for_keys_values(keys)
+ if existing_item == request["Item"]:
+ # if the item is the same as the previous version, AWS does not send an event
+ continue
+ record["eventID"] = short_uid()
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(request["Item"])
+ record["eventName"] = "INSERT" if not existing_item else "MODIFY"
+ record["dynamodb"]["Keys"] = keys
+
+ if stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = request["Item"]
+ if existing_item and stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+
+ table_records.append(record)
+ continue
+
+ case {"DeleteRequest": request}:
+ keys = request["Key"]
+ if not (existing_item := find_existing_item_for_keys_values(keys)):
+ continue
+
+ record["eventID"] = short_uid()
+ record["eventName"] = "REMOVE"
+ record["dynamodb"]["Keys"] = keys
+ if stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = existing_item
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(existing_item)
+ table_records.append(record)
+ continue
+
+ records_map[table_name] = TableRecords(
+ records=table_records, table_stream_type=stream_type
+ )
+
+ return records_map
+
+ def forward_stream_records(
+ self,
+ account_id: str,
+ region_name: str,
+ records_map: RecordsMap,
+ ) -> None:
+ if not records_map:
+ return
+
+ self._event_forwarder.forward_to_targets(
+ account_id, region_name, records_map, background=True
+ )
+
+ @staticmethod
+ def get_record_template(region_name: str, stream_view_type: str | None = None) -> StreamRecord:
+ record = {
+ "eventID": short_uid(),
+ "eventVersion": "1.1",
+ "dynamodb": {
+ # expects nearest second rounded down
+ "ApproximateCreationDateTime": int(time.time()),
+ "SizeBytes": -1,
+ },
+ "awsRegion": region_name,
+ "eventSource": "aws:dynamodb",
+ }
+ if stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_view_type
+
+ return record
+
+ def check_provisioned_throughput(self, action):
+ """
+ Check rate limiting for an API operation and raise an error if provisioned throughput is exceeded.
+ """
+ if self.should_throttle(action):
+ message = (
+ "The level of configured provisioned throughput for the table was exceeded. "
+ + "Consider increasing your provisioning level with the UpdateTable API"
+ )
+ raise ProvisionedThroughputExceededException(message)
+
+ def action_should_throttle(self, action, actions):
+ throttled = [f"{ACTION_PREFIX}{a}" for a in actions]
+ return (action in throttled) or (action in actions)
+
+ def should_throttle(self, action):
+ if (
+ not config.DYNAMODB_READ_ERROR_PROBABILITY
+ and not config.DYNAMODB_ERROR_PROBABILITY
+ and not config.DYNAMODB_WRITE_ERROR_PROBABILITY
+ ):
+ # early exit so we don't need to call random()
+ return False
+
+ rand = random.random()
+ if rand < config.DYNAMODB_READ_ERROR_PROBABILITY and self.action_should_throttle(
+ action, READ_THROTTLED_ACTIONS
+ ):
+ return True
+ elif rand < config.DYNAMODB_WRITE_ERROR_PROBABILITY and self.action_should_throttle(
+ action, WRITE_THROTTLED_ACTIONS
+ ):
+ return True
+ elif rand < config.DYNAMODB_ERROR_PROBABILITY and self.action_should_throttle(
+ action, THROTTLED_ACTIONS
+ ):
+ return True
+ return False
+
+
+# ---
+# Misc. util functions
+# ---
+
+
+def _get_size_bytes(item: dict) -> int:
+ try:
+ size_bytes = len(json.dumps(item, separators=(",", ":")))
+ except TypeError:
+ size_bytes = len(str(item))
+ return size_bytes
+
+
+def get_global_secondary_index(account_id: str, region_name: str, table_name: str, index_name: str):
+ schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
+ for index in schema["Table"].get("GlobalSecondaryIndexes", []):
+ if index["IndexName"] == index_name:
+ return index
+ raise ResourceNotFoundException("Index not found")
+
+
+def is_local_secondary_index(
+ account_id: str, region_name: str, table_name: str, index_name: str
+) -> bool:
+ schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
+ for index in schema["Table"].get("LocalSecondaryIndexes", []):
+ if index["IndexName"] == index_name:
+ return True
+ return False
+
+
+def is_index_query_valid(account_id: str, region_name: str, query_data: dict) -> bool:
+ table_name = to_str(query_data["TableName"])
+ index_name = to_str(query_data["IndexName"])
+ if is_local_secondary_index(account_id, region_name, table_name, index_name):
+ return True
+ index_query_type = query_data.get("Select")
+ index = get_global_secondary_index(account_id, region_name, table_name, index_name)
+ index_projection_type = index.get("Projection").get("ProjectionType")
+ if index_query_type == "ALL_ATTRIBUTES" and index_projection_type != "ALL":
+ return False
+ return True
+
+
+def get_table_stream_type(
+ account_id: str, region_name: str, table_name_or_arn: str
+) -> TableStreamType | None:
+ """
+ :param account_id: the account id of the table
+ :param region_name: the region of the table
+ :param table_name_or_arn: the table name or ARN
+ :return: a TableStreamType object if the table has streams enabled, or None otherwise
+ """
+ if not table_name_or_arn:
+ return
+
+ table_name = table_name_or_arn.split(":table/")[-1]
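+ # e.g. "arn:aws:dynamodb:us-east-1:000000000000:table/my-table" -> "my-table"; a plain table name is unchanged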
+
+ is_kinesis = False
+ stream_view_type = None
+
+ if table_definition := get_store(account_id, region_name).table_definitions.get(table_name):
+ if table_definition.get("KinesisDataStreamDestinationStatus") == "ACTIVE":
+ is_kinesis = True
+
+ table_arn = arns.dynamodb_table_arn(table_name, account_id=account_id, region_name=region_name)
+
+ if (
+ stream := dynamodbstreams_api.get_stream_for_table(account_id, region_name, table_arn)
+ ) and stream["StreamStatus"] in (StreamStatus.ENABLING, StreamStatus.ENABLED):
+ stream_view_type = stream["StreamViewType"]
+
+ if is_kinesis or stream_view_type:
+ return TableStreamType(stream_view_type, is_kinesis=is_kinesis)
+
+
+def get_updated_records(
+ account_id: str,
+ region_name: str,
+ table_name: str,
+ existing_items: List,
+ server_url: str,
+ table_stream_type: TableStreamType,
+) -> RecordsMap:
+ """
+ Determine the list of record updates, to be sent to a DDB stream after a PartiQL update operation.
+
+ Note: This is currently a fairly expensive operation, as we need to retrieve the list of all items
+ from the table and compare them to the previously available items. This is a limitation of
+ using the DynamoDB Local backend as a black box. In the future, we should consider hooking
+ into the PartiQL query execution inside DynamoDB Local and directly extracting the list of updated items.
+ """
+ result = []
+
+ key_schema = SchemaExtractor.get_key_schema(table_name, account_id, region_name)
+ before = ItemSet(existing_items, key_schema=key_schema)
+ all_table_items = ItemFinder.get_all_table_items(
+ account_id=account_id,
+ region_name=region_name,
+ table_name=table_name,
+ endpoint_url=server_url,
+ )
+ after = ItemSet(all_table_items, key_schema=key_schema)
+
+ def _add_record(item, comparison_set: ItemSet):
+ matching_item = comparison_set.find_item(item)
+ if matching_item == item:
+ return
+
+ # determine event type
+ if comparison_set == after:
+ if matching_item:
+ return
+ event_name = "REMOVE"
+ else:
+ event_name = "INSERT" if not matching_item else "MODIFY"
+
+ old_image = item if event_name == "REMOVE" else matching_item
+ new_image = matching_item if event_name == "REMOVE" else item
+
+ # prepare record
+ keys = SchemaExtractor.extract_keys_for_schema(item=item, key_schema=key_schema)
+
+ record = DynamoDBProvider.get_record_template(region_name)
+ record["eventName"] = event_name
+ record["dynamodb"]["Keys"] = keys
+ record["dynamodb"]["SizeBytes"] = _get_size_bytes(item)
+
+ if table_stream_type.stream_view_type:
+ record["dynamodb"]["StreamViewType"] = table_stream_type.stream_view_type
+ if table_stream_type.needs_new_image:
+ record["dynamodb"]["NewImage"] = new_image
+ if old_image and table_stream_type.needs_old_image:
+ record["dynamodb"]["OldImage"] = old_image
+
+ result.append(record)
+
+ # loop over items in new item list (find INSERT/MODIFY events)
+ for item in after.items_list:
+ _add_record(item, before)
+ # loop over items in old item list (find REMOVE events)
+ for item in before.items_list:
+ _add_record(item, after)
+
+ return {table_name: TableRecords(records=result, table_stream_type=table_stream_type)}
+
+
+def create_dynamodb_stream(account_id: str, region_name: str, data, latest_stream_label):
+ stream = data["StreamSpecification"]
+ enabled = stream.get("StreamEnabled")
+
+ if enabled not in [False, "False"]:
+ table_name = data["TableName"]
+ view_type = stream["StreamViewType"]
+
+ dynamodbstreams_api.add_dynamodb_stream(
+ account_id=account_id,
+ region_name=region_name,
+ table_name=table_name,
+ latest_stream_label=latest_stream_label,
+ view_type=view_type,
+ enabled=enabled,
+ )
+
+
+def dynamodb_get_table_stream_specification(account_id: str, region_name: str, table_name: str):
+ try:
+ table_schema = SchemaExtractor.get_table_schema(
+ table_name, account_id=account_id, region_name=region_name
+ )
+ return table_schema["Table"].get("StreamSpecification")
+ except Exception as e:
+ LOG.info(
+ "Unable to get stream specification for table %s: %s %s",
+ table_name,
+ e,
+ traceback.format_exc(),
+ )
+ raise e
+
+
+def find_item_for_keys_values_in_batch(
+ table_name: str, item_keys: dict, batch: BatchGetResponseMap
+) -> AttributeMap | None:
+ """
+ Look up the provided item key subset in the given batch and, if present, return the full item.
+ :param table_name: the table name for the item
+ :param item_keys: the request item keys
+ :param batch: the values in which to look for the item
+ :return: a DynamoDB Item (AttributeMap)
+ """
+ keys = item_keys.items()
+ for item in batch.get(table_name, []):
+ if keys <= item.items():
+ return item
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/__init__.py b/localstack-core/localstack/services/dynamodb/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.py b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.py
new file mode 100644
index 0000000000000..af199a479576c
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.py
@@ -0,0 +1,423 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class DynamoDBGlobalTableProperties(TypedDict):
+ AttributeDefinitions: Optional[list[AttributeDefinition]]
+ KeySchema: Optional[list[KeySchema]]
+ Replicas: Optional[list[ReplicaSpecification]]
+ Arn: Optional[str]
+ BillingMode: Optional[str]
+ GlobalSecondaryIndexes: Optional[list[GlobalSecondaryIndex]]
+ LocalSecondaryIndexes: Optional[list[LocalSecondaryIndex]]
+ SSESpecification: Optional[SSESpecification]
+ StreamArn: Optional[str]
+ StreamSpecification: Optional[StreamSpecification]
+ TableId: Optional[str]
+ TableName: Optional[str]
+ TimeToLiveSpecification: Optional[TimeToLiveSpecification]
+ WriteProvisionedThroughputSettings: Optional[WriteProvisionedThroughputSettings]
+
+
+class AttributeDefinition(TypedDict):
+ AttributeName: Optional[str]
+ AttributeType: Optional[str]
+
+
+class KeySchema(TypedDict):
+ AttributeName: Optional[str]
+ KeyType: Optional[str]
+
+
+class Projection(TypedDict):
+ NonKeyAttributes: Optional[list[str]]
+ ProjectionType: Optional[str]
+
+
+class TargetTrackingScalingPolicyConfiguration(TypedDict):
+ TargetValue: Optional[float]
+ DisableScaleIn: Optional[bool]
+ ScaleInCooldown: Optional[int]
+ ScaleOutCooldown: Optional[int]
+
+
+class CapacityAutoScalingSettings(TypedDict):
+ MaxCapacity: Optional[int]
+ MinCapacity: Optional[int]
+ TargetTrackingScalingPolicyConfiguration: Optional[TargetTrackingScalingPolicyConfiguration]
+ SeedCapacity: Optional[int]
+
+
+class WriteProvisionedThroughputSettings(TypedDict):
+ WriteCapacityAutoScalingSettings: Optional[CapacityAutoScalingSettings]
+
+
+class GlobalSecondaryIndex(TypedDict):
+ IndexName: Optional[str]
+ KeySchema: Optional[list[KeySchema]]
+ Projection: Optional[Projection]
+ WriteProvisionedThroughputSettings: Optional[WriteProvisionedThroughputSettings]
+
+
+class LocalSecondaryIndex(TypedDict):
+ IndexName: Optional[str]
+ KeySchema: Optional[list[KeySchema]]
+ Projection: Optional[Projection]
+
+
+class ContributorInsightsSpecification(TypedDict):
+ Enabled: Optional[bool]
+
+
+class ReadProvisionedThroughputSettings(TypedDict):
+ ReadCapacityAutoScalingSettings: Optional[CapacityAutoScalingSettings]
+ ReadCapacityUnits: Optional[int]
+
+
+class ReplicaGlobalSecondaryIndexSpecification(TypedDict):
+ IndexName: Optional[str]
+ ContributorInsightsSpecification: Optional[ContributorInsightsSpecification]
+ ReadProvisionedThroughputSettings: Optional[ReadProvisionedThroughputSettings]
+
+
+class PointInTimeRecoverySpecification(TypedDict):
+ PointInTimeRecoveryEnabled: Optional[bool]
+
+
+class ReplicaSSESpecification(TypedDict):
+ KMSMasterKeyId: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class KinesisStreamSpecification(TypedDict):
+ StreamArn: Optional[str]
+
+
+class ReplicaSpecification(TypedDict):
+ Region: Optional[str]
+ ContributorInsightsSpecification: Optional[ContributorInsightsSpecification]
+ DeletionProtectionEnabled: Optional[bool]
+ GlobalSecondaryIndexes: Optional[list[ReplicaGlobalSecondaryIndexSpecification]]
+ KinesisStreamSpecification: Optional[KinesisStreamSpecification]
+ PointInTimeRecoverySpecification: Optional[PointInTimeRecoverySpecification]
+ ReadProvisionedThroughputSettings: Optional[ReadProvisionedThroughputSettings]
+ SSESpecification: Optional[ReplicaSSESpecification]
+ TableClass: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class SSESpecification(TypedDict):
+ SSEEnabled: Optional[bool]
+ SSEType: Optional[str]
+
+
+class StreamSpecification(TypedDict):
+ StreamViewType: Optional[str]
+
+
+class TimeToLiveSpecification(TypedDict):
+ Enabled: Optional[bool]
+ AttributeName: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class DynamoDBGlobalTableProvider(ResourceProvider[DynamoDBGlobalTableProperties]):
+ TYPE = "AWS::DynamoDB::GlobalTable" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[DynamoDBGlobalTableProperties],
+ ) -> ProgressEvent[DynamoDBGlobalTableProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/TableName
+
+ Required properties:
+ - KeySchema
+ - AttributeDefinitions
+ - Replicas
+
+ Create-only properties:
+ - /properties/LocalSecondaryIndexes
+ - /properties/TableName
+ - /properties/KeySchema
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/StreamArn
+ - /properties/TableId
+
+ IAM permissions required:
+ - dynamodb:CreateTable
+ - dynamodb:CreateTableReplica
+ - dynamodb:Describe*
+ - dynamodb:UpdateTimeToLive
+ - dynamodb:UpdateContributorInsights
+ - dynamodb:UpdateContinuousBackups
+ - dynamodb:ListTagsOfResource
+ - dynamodb:Query
+ - dynamodb:Scan
+ - dynamodb:UpdateItem
+ - dynamodb:PutItem
+ - dynamodb:GetItem
+ - dynamodb:DeleteItem
+ - dynamodb:BatchWriteItem
+ - dynamodb:TagResource
+ - dynamodb:EnableKinesisStreamingDestination
+ - dynamodb:DisableKinesisStreamingDestination
+ - dynamodb:DescribeKinesisStreamingDestination
+ - dynamodb:DescribeTableReplicaAutoScaling
+ - dynamodb:UpdateTableReplicaAutoScaling
+ - dynamodb:TagResource
+ - application-autoscaling:DeleteScalingPolicy
+ - application-autoscaling:DeleteScheduledAction
+ - application-autoscaling:DeregisterScalableTarget
+ - application-autoscaling:Describe*
+ - application-autoscaling:PutScalingPolicy
+ - application-autoscaling:PutScheduledAction
+ - application-autoscaling:RegisterScalableTarget
+ - kinesis:ListStreams
+ - kinesis:DescribeStream
+ - kinesis:PutRecords
+ - kms:CreateGrant
+ - kms:Describe*
+ - kms:Get*
+ - kms:List*
+ - kms:RevokeGrant
+ - cloudwatch:PutMetricData
+
+ """
+ model = request.desired_state
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.custom_context[REPEATED_INVOCATION] = True
+
+ if not model.get("TableName"):
+ model["TableName"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+
+ create_params = util.select_attributes(
+ model,
+ [
+ "AttributeDefinitions",
+ "BillingMode",
+ "GlobalSecondaryIndexes",
+ "KeySchema",
+ "LocalSecondaryIndexes",
+ "Replicas",
+ "SSESpecification",
+ "StreamSpecification",
+ "TableName",
+ "WriteProvisionedThroughputSettings",
+ ],
+ )
+
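+ # `Replicas` is not a CreateTable parameter; replicas are applied after creation via UpdateTable/ReplicaUpdates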
+ replicas = create_params.pop("Replicas", [])
+
+ if sse_specification := create_params.get("SSESpecification"):
+ # rename the boolean attribute to match the parameter name expected by the CreateTable call
+ sse_specification["Enabled"] = sse_specification.pop("SSEEnabled")
+
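+ # the GlobalTable StreamSpecification only carries StreamViewType, while the CreateTable API also
+ # expects StreamEnabled, so it is set explicitly here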
+ if stream_spec := model.get("StreamSpecification"):
+ create_params["StreamSpecification"] = {
+ "StreamEnabled": True,
+ **stream_spec,
+ }
+
+ creation_response = request.aws_client_factory.dynamodb.create_table(**create_params)
+ model["Arn"] = creation_response["TableDescription"]["TableArn"]
+ model["TableId"] = creation_response["TableDescription"]["TableId"]
+
+ if creation_response["TableDescription"].get("LatestStreamArn"):
+ model["StreamArn"] = creation_response["TableDescription"]["LatestStreamArn"]
+
+ replicas_to_create = []
+ for replica in replicas:
+ create = {
+ "RegionName": replica.get("Region"),
+ "KMSMasterKeyId": replica.get("KMSMasterKeyId"),
+ "ProvisionedThroughputOverride": replica.get("ProvisionedThroughputOverride"),
+ "GlobalSecondaryIndexes": replica.get("GlobalSecondaryIndexes"),
+ "TableClassOverride": replica.get("TableClassOverride"),
+ }
+
+ create = {k: v for k, v in create.items() if v is not None}
+
+ replicas_to_create.append({"Create": create})
+
+ request.aws_client_factory.dynamodb.update_table(
+ ReplicaUpdates=replicas_to_create, TableName=model["TableName"]
+ )
+
+ # add TTL config
+ if ttl_config := model.get("TimeToLiveSpecification"):
+ request.aws_client_factory.dynamodb.update_time_to_live(
+ TableName=model["TableName"], TimeToLiveSpecification=ttl_config
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ status = request.aws_client_factory.dynamodb.describe_table(TableName=model["TableName"])[
+ "Table"
+ ]["TableStatus"]
+ if status == "ACTIVE":
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ elif status == "CREATING":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ else:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ message=f"Table creation failed with status {status}",
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[DynamoDBGlobalTableProperties],
+ ) -> ProgressEvent[DynamoDBGlobalTableProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - dynamodb:Describe*
+ - application-autoscaling:Describe*
+ - cloudwatch:PutMetricData
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[DynamoDBGlobalTableProperties],
+ ) -> ProgressEvent[DynamoDBGlobalTableProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - dynamodb:Describe*
+ - application-autoscaling:DeleteScalingPolicy
+ - application-autoscaling:DeleteScheduledAction
+ - application-autoscaling:DeregisterScalableTarget
+ - application-autoscaling:Describe*
+ - application-autoscaling:PutScalingPolicy
+ - application-autoscaling:PutScheduledAction
+ - application-autoscaling:RegisterScalableTarget
+ """
+
+ model = request.desired_state
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.custom_context[REPEATED_INVOCATION] = True
+ request.aws_client_factory.dynamodb.delete_table(TableName=model["TableName"])
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ try:
+ request.aws_client_factory.dynamodb.describe_table(TableName=model["TableName"])
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ except Exception as ex:
+ if "ResourceNotFoundException" in str(ex):
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ message=str(ex),
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[DynamoDBGlobalTableProperties],
+ ) -> ProgressEvent[DynamoDBGlobalTableProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - dynamodb:Describe*
+ - dynamodb:CreateTableReplica
+ - dynamodb:UpdateTable
+ - dynamodb:UpdateTimeToLive
+ - dynamodb:UpdateContinuousBackups
+ - dynamodb:UpdateContributorInsights
+ - dynamodb:ListTagsOfResource
+ - dynamodb:Query
+ - dynamodb:Scan
+ - dynamodb:UpdateItem
+ - dynamodb:PutItem
+ - dynamodb:GetItem
+ - dynamodb:DeleteItem
+ - dynamodb:BatchWriteItem
+ - dynamodb:DeleteTable
+ - dynamodb:DeleteTableReplica
+ - dynamodb:UpdateItem
+ - dynamodb:TagResource
+ - dynamodb:UntagResource
+ - dynamodb:EnableKinesisStreamingDestination
+ - dynamodb:DisableKinesisStreamingDestination
+ - dynamodb:DescribeKinesisStreamingDestination
+ - dynamodb:DescribeTableReplicaAutoScaling
+ - dynamodb:UpdateTableReplicaAutoScaling
+ - application-autoscaling:DeleteScalingPolicy
+ - application-autoscaling:DeleteScheduledAction
+ - application-autoscaling:DeregisterScalableTarget
+ - application-autoscaling:Describe*
+ - application-autoscaling:PutScalingPolicy
+ - application-autoscaling:PutScheduledAction
+ - application-autoscaling:RegisterScalableTarget
+ - kinesis:ListStreams
+ - kinesis:DescribeStream
+ - kinesis:PutRecords
+ - kms:CreateGrant
+ - kms:Describe*
+ - kms:Get*
+ - kms:List*
+ - kms:RevokeGrant
+ - cloudwatch:PutMetricData
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.schema.json b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.schema.json
new file mode 100644
index 0000000000000..3caa6a203393a
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable.schema.json
@@ -0,0 +1,574 @@
+{
+ "typeName": "AWS::DynamoDB::GlobalTable",
+ "description": "Version: None. Resource Type definition for AWS::DynamoDB::GlobalTable",
+ "additionalProperties": false,
+ "properties": {
+ "Arn": {
+ "type": "string"
+ },
+ "StreamArn": {
+ "type": "string"
+ },
+ "AttributeDefinitions": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/AttributeDefinition"
+ },
+ "minItems": 1
+ },
+ "BillingMode": {
+ "type": "string"
+ },
+ "GlobalSecondaryIndexes": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/GlobalSecondaryIndex"
+ }
+ },
+ "KeySchema": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ },
+ "minItems": 1,
+ "maxItems": 2
+ },
+ "LocalSecondaryIndexes": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/LocalSecondaryIndex"
+ }
+ },
+ "WriteProvisionedThroughputSettings": {
+ "$ref": "#/definitions/WriteProvisionedThroughputSettings"
+ },
+ "Replicas": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/ReplicaSpecification"
+ },
+ "minItems": 1
+ },
+ "SSESpecification": {
+ "$ref": "#/definitions/SSESpecification"
+ },
+ "StreamSpecification": {
+ "$ref": "#/definitions/StreamSpecification"
+ },
+ "TableName": {
+ "type": "string"
+ },
+ "TableId": {
+ "type": "string"
+ },
+ "TimeToLiveSpecification": {
+ "$ref": "#/definitions/TimeToLiveSpecification"
+ }
+ },
+ "definitions": {
+ "StreamSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "StreamViewType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "StreamViewType"
+ ]
+ },
+ "KinesisStreamSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "StreamArn": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "StreamArn"
+ ]
+ },
+ "KeySchema": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255
+ },
+ "KeyType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "KeyType",
+ "AttributeName"
+ ]
+ },
+ "PointInTimeRecoverySpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PointInTimeRecoveryEnabled": {
+ "type": "boolean"
+ }
+ }
+ },
+ "ReplicaSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Region": {
+ "type": "string"
+ },
+ "GlobalSecondaryIndexes": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/ReplicaGlobalSecondaryIndexSpecification"
+ }
+ },
+ "ContributorInsightsSpecification": {
+ "$ref": "#/definitions/ContributorInsightsSpecification"
+ },
+ "PointInTimeRecoverySpecification": {
+ "$ref": "#/definitions/PointInTimeRecoverySpecification"
+ },
+ "TableClass": {
+ "type": "string"
+ },
+ "DeletionProtectionEnabled": {
+ "type": "boolean"
+ },
+ "SSESpecification": {
+ "$ref": "#/definitions/ReplicaSSESpecification"
+ },
+ "Tags": {
+ "type": "array",
+ "insertionOrder": false,
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "ReadProvisionedThroughputSettings": {
+ "$ref": "#/definitions/ReadProvisionedThroughputSettings"
+ },
+ "KinesisStreamSpecification": {
+ "$ref": "#/definitions/KinesisStreamSpecification"
+ }
+ },
+ "required": [
+ "Region"
+ ]
+ },
+ "TimeToLiveSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string"
+ },
+ "Enabled": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "Enabled"
+ ]
+ },
+ "LocalSecondaryIndex": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IndexName": {
+ "type": "string",
+ "minLength": 3,
+ "maxLength": 255
+ },
+ "KeySchema": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ },
+ "maxItems": 2
+ },
+ "Projection": {
+ "$ref": "#/definitions/Projection"
+ }
+ },
+ "required": [
+ "IndexName",
+ "Projection",
+ "KeySchema"
+ ]
+ },
+ "GlobalSecondaryIndex": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IndexName": {
+ "type": "string",
+ "minLength": 3,
+ "maxLength": 255
+ },
+ "KeySchema": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ },
+ "minItems": 1,
+ "maxItems": 2
+ },
+ "Projection": {
+ "$ref": "#/definitions/Projection"
+ },
+ "WriteProvisionedThroughputSettings": {
+ "$ref": "#/definitions/WriteProvisionedThroughputSettings"
+ }
+ },
+ "required": [
+ "IndexName",
+ "Projection",
+ "KeySchema"
+ ]
+ },
+ "SSESpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "SSEEnabled": {
+ "type": "boolean"
+ },
+ "SSEType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "SSEEnabled"
+ ]
+ },
+ "ReplicaSSESpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "KMSMasterKeyId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "KMSMasterKeyId"
+ ]
+ },
+ "AttributeDefinition": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255
+ },
+ "AttributeType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "AttributeName",
+ "AttributeType"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "Projection": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "NonKeyAttributes": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ },
+ "maxItems": 20
+ },
+ "ProjectionType": {
+ "type": "string"
+ }
+ }
+ },
+ "ReplicaGlobalSecondaryIndexSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IndexName": {
+ "type": "string",
+ "minLength": 3,
+ "maxLength": 255
+ },
+ "ContributorInsightsSpecification": {
+ "$ref": "#/definitions/ContributorInsightsSpecification"
+ },
+ "ReadProvisionedThroughputSettings": {
+ "$ref": "#/definitions/ReadProvisionedThroughputSettings"
+ }
+ },
+ "required": [
+ "IndexName"
+ ]
+ },
+ "ContributorInsightsSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "Enabled"
+ ]
+ },
+ "ReadProvisionedThroughputSettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ReadCapacityUnits": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "ReadCapacityAutoScalingSettings": {
+ "$ref": "#/definitions/CapacityAutoScalingSettings"
+ }
+ }
+ },
+ "WriteProvisionedThroughputSettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "WriteCapacityAutoScalingSettings": {
+ "$ref": "#/definitions/CapacityAutoScalingSettings"
+ }
+ }
+ },
+ "CapacityAutoScalingSettings": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "MinCapacity": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "MaxCapacity": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "SeedCapacity": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "TargetTrackingScalingPolicyConfiguration": {
+ "$ref": "#/definitions/TargetTrackingScalingPolicyConfiguration"
+ }
+ },
+ "required": [
+ "MinCapacity",
+ "MaxCapacity",
+ "TargetTrackingScalingPolicyConfiguration"
+ ]
+ },
+ "TargetTrackingScalingPolicyConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DisableScaleIn": {
+ "type": "boolean"
+ },
+ "ScaleInCooldown": {
+ "type": "integer",
+ "minimum": 0
+ },
+ "ScaleOutCooldown": {
+ "type": "integer",
+ "minimum": 0
+ },
+ "TargetValue": {
+ "type": "number",
+ "format": "double"
+ }
+ },
+ "required": [
+ "TargetValue"
+ ]
+ }
+ },
+ "required": [
+ "KeySchema",
+ "AttributeDefinitions",
+ "Replicas"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/StreamArn",
+ "/properties/TableId"
+ ],
+ "createOnlyProperties": [
+ "/properties/LocalSecondaryIndexes",
+ "/properties/TableName",
+ "/properties/KeySchema"
+ ],
+ "primaryIdentifier": [
+ "/properties/TableName"
+ ],
+ "additionalIdentifiers": [
+ [
+ "/properties/Arn"
+ ],
+ [
+ "/properties/StreamArn"
+ ]
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "dynamodb:CreateTable",
+ "dynamodb:CreateTableReplica",
+ "dynamodb:Describe*",
+ "dynamodb:UpdateTimeToLive",
+ "dynamodb:UpdateContributorInsights",
+ "dynamodb:UpdateContinuousBackups",
+ "dynamodb:ListTagsOfResource",
+ "dynamodb:Query",
+ "dynamodb:Scan",
+ "dynamodb:UpdateItem",
+ "dynamodb:PutItem",
+ "dynamodb:GetItem",
+ "dynamodb:DeleteItem",
+ "dynamodb:BatchWriteItem",
+ "dynamodb:TagResource",
+ "dynamodb:EnableKinesisStreamingDestination",
+ "dynamodb:DisableKinesisStreamingDestination",
+ "dynamodb:DescribeKinesisStreamingDestination",
+ "dynamodb:DescribeTableReplicaAutoScaling",
+ "dynamodb:UpdateTableReplicaAutoScaling",
+ "dynamodb:TagResource",
+ "application-autoscaling:DeleteScalingPolicy",
+ "application-autoscaling:DeleteScheduledAction",
+ "application-autoscaling:DeregisterScalableTarget",
+ "application-autoscaling:Describe*",
+ "application-autoscaling:PutScalingPolicy",
+ "application-autoscaling:PutScheduledAction",
+ "application-autoscaling:RegisterScalableTarget",
+ "kinesis:ListStreams",
+ "kinesis:DescribeStream",
+ "kinesis:PutRecords",
+ "kms:CreateGrant",
+ "kms:Describe*",
+ "kms:Get*",
+ "kms:List*",
+ "kms:RevokeGrant",
+ "cloudwatch:PutMetricData"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "dynamodb:Describe*",
+ "application-autoscaling:Describe*",
+ "cloudwatch:PutMetricData"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "dynamodb:Describe*",
+ "dynamodb:CreateTableReplica",
+ "dynamodb:UpdateTable",
+ "dynamodb:UpdateTimeToLive",
+ "dynamodb:UpdateContinuousBackups",
+ "dynamodb:UpdateContributorInsights",
+ "dynamodb:ListTagsOfResource",
+ "dynamodb:Query",
+ "dynamodb:Scan",
+ "dynamodb:UpdateItem",
+ "dynamodb:PutItem",
+ "dynamodb:GetItem",
+ "dynamodb:DeleteItem",
+ "dynamodb:BatchWriteItem",
+ "dynamodb:DeleteTable",
+ "dynamodb:DeleteTableReplica",
+ "dynamodb:UpdateItem",
+ "dynamodb:TagResource",
+ "dynamodb:UntagResource",
+ "dynamodb:EnableKinesisStreamingDestination",
+ "dynamodb:DisableKinesisStreamingDestination",
+ "dynamodb:DescribeKinesisStreamingDestination",
+ "dynamodb:DescribeTableReplicaAutoScaling",
+ "dynamodb:UpdateTableReplicaAutoScaling",
+ "application-autoscaling:DeleteScalingPolicy",
+ "application-autoscaling:DeleteScheduledAction",
+ "application-autoscaling:DeregisterScalableTarget",
+ "application-autoscaling:Describe*",
+ "application-autoscaling:PutScalingPolicy",
+ "application-autoscaling:PutScheduledAction",
+ "application-autoscaling:RegisterScalableTarget",
+ "kinesis:ListStreams",
+ "kinesis:DescribeStream",
+ "kinesis:PutRecords",
+ "kms:CreateGrant",
+ "kms:Describe*",
+ "kms:Get*",
+ "kms:List*",
+ "kms:RevokeGrant",
+ "cloudwatch:PutMetricData"
+ ],
+ "timeoutInMinutes": 1200
+ },
+ "delete": {
+ "permissions": [
+ "dynamodb:Describe*",
+ "application-autoscaling:DeleteScalingPolicy",
+ "application-autoscaling:DeleteScheduledAction",
+ "application-autoscaling:DeregisterScalableTarget",
+ "application-autoscaling:Describe*",
+ "application-autoscaling:PutScalingPolicy",
+ "application-autoscaling:PutScheduledAction",
+ "application-autoscaling:RegisterScalableTarget"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "dynamodb:ListTables",
+ "cloudwatch:PutMetricData"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable_plugin.py b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable_plugin.py
new file mode 100644
index 0000000000000..8de0265d3d5f1
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_globaltable_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class DynamoDBGlobalTableProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::DynamoDB::GlobalTable"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.dynamodb.resource_providers.aws_dynamodb_globaltable import (
+ DynamoDBGlobalTableProvider,
+ )
+
+ self.factory = DynamoDBGlobalTableProvider
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.py b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.py
new file mode 100644
index 0000000000000..469c944cca898
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.py
@@ -0,0 +1,442 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class DynamoDBTableProperties(TypedDict):
+ KeySchema: Optional[list[KeySchema] | dict]
+ Arn: Optional[str]
+ AttributeDefinitions: Optional[list[AttributeDefinition]]
+ BillingMode: Optional[str]
+ ContributorInsightsSpecification: Optional[ContributorInsightsSpecification]
+ DeletionProtectionEnabled: Optional[bool]
+ GlobalSecondaryIndexes: Optional[list[GlobalSecondaryIndex]]
+ ImportSourceSpecification: Optional[ImportSourceSpecification]
+ KinesisStreamSpecification: Optional[KinesisStreamSpecification]
+ LocalSecondaryIndexes: Optional[list[LocalSecondaryIndex]]
+ PointInTimeRecoverySpecification: Optional[PointInTimeRecoverySpecification]
+ ProvisionedThroughput: Optional[ProvisionedThroughput]
+ SSESpecification: Optional[SSESpecification]
+ StreamArn: Optional[str]
+ StreamSpecification: Optional[StreamSpecification]
+ TableClass: Optional[str]
+ TableName: Optional[str]
+ Tags: Optional[list[Tag]]
+ TimeToLiveSpecification: Optional[TimeToLiveSpecification]
+
+
+class AttributeDefinition(TypedDict):
+ AttributeName: Optional[str]
+ AttributeType: Optional[str]
+
+
+class KeySchema(TypedDict):
+ AttributeName: Optional[str]
+ KeyType: Optional[str]
+
+
+class Projection(TypedDict):
+ NonKeyAttributes: Optional[list[str]]
+ ProjectionType: Optional[str]
+
+
+class ProvisionedThroughput(TypedDict):
+ ReadCapacityUnits: Optional[int]
+ WriteCapacityUnits: Optional[int]
+
+
+class ContributorInsightsSpecification(TypedDict):
+ Enabled: Optional[bool]
+
+
+class GlobalSecondaryIndex(TypedDict):
+ IndexName: Optional[str]
+ KeySchema: Optional[list[KeySchema]]
+ Projection: Optional[Projection]
+ ContributorInsightsSpecification: Optional[ContributorInsightsSpecification]
+ ProvisionedThroughput: Optional[ProvisionedThroughput]
+
+
+class LocalSecondaryIndex(TypedDict):
+ IndexName: Optional[str]
+ KeySchema: Optional[list[KeySchema]]
+ Projection: Optional[Projection]
+
+
+class PointInTimeRecoverySpecification(TypedDict):
+ PointInTimeRecoveryEnabled: Optional[bool]
+
+
+class SSESpecification(TypedDict):
+ SSEEnabled: Optional[bool]
+ KMSMasterKeyId: Optional[str]
+ SSEType: Optional[str]
+
+
+class StreamSpecification(TypedDict):
+ StreamViewType: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class TimeToLiveSpecification(TypedDict):
+ AttributeName: Optional[str]
+ Enabled: Optional[bool]
+
+
+class KinesisStreamSpecification(TypedDict):
+ StreamArn: Optional[str]
+
+
+class S3BucketSource(TypedDict):
+ S3Bucket: Optional[str]
+ S3BucketOwner: Optional[str]
+ S3KeyPrefix: Optional[str]
+
+
+class Csv(TypedDict):
+ Delimiter: Optional[str]
+ HeaderList: Optional[list[str]]
+
+
+class InputFormatOptions(TypedDict):
+ Csv: Optional[Csv]
+
+
+class ImportSourceSpecification(TypedDict):
+ InputFormat: Optional[str]
+ S3BucketSource: Optional[S3BucketSource]
+ InputCompressionType: Optional[str]
+ InputFormatOptions: Optional[InputFormatOptions]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class DynamoDBTableProvider(ResourceProvider[DynamoDBTableProperties]):
+ TYPE = "AWS::DynamoDB::Table" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[DynamoDBTableProperties],
+ ) -> ProgressEvent[DynamoDBTableProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/TableName
+
+ Required properties:
+ - KeySchema
+
+ Create-only properties:
+ - /properties/TableName
+ - /properties/ImportSourceSpecification
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/StreamArn
+
+ IAM permissions required:
+ - dynamodb:CreateTable
+ - dynamodb:DescribeImport
+ - dynamodb:DescribeTable
+ - dynamodb:DescribeTimeToLive
+ - dynamodb:UpdateTimeToLive
+ - dynamodb:UpdateContributorInsights
+ - dynamodb:UpdateContinuousBackups
+ - dynamodb:DescribeContinuousBackups
+ - dynamodb:DescribeContributorInsights
+ - dynamodb:EnableKinesisStreamingDestination
+ - dynamodb:DisableKinesisStreamingDestination
+ - dynamodb:DescribeKinesisStreamingDestination
+ - dynamodb:ImportTable
+ - dynamodb:ListTagsOfResource
+ - dynamodb:TagResource
+ - dynamodb:UpdateTable
+ - kinesis:DescribeStream
+ - kinesis:PutRecords
+ - iam:CreateServiceLinkedRole
+ - kms:CreateGrant
+ - kms:Decrypt
+ - kms:Describe*
+ - kms:Encrypt
+ - kms:Get*
+ - kms:List*
+ - kms:RevokeGrant
+ - logs:CreateLogGroup
+ - logs:CreateLogStream
+ - logs:DescribeLogGroups
+ - logs:DescribeLogStreams
+ - logs:PutLogEvents
+ - logs:PutRetentionPolicy
+ - s3:GetObject
+ - s3:GetObjectMetadata
+ - s3:ListBucket
+
+ """
+ model = request.desired_state
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.custom_context[REPEATED_INVOCATION] = True
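+ # first invocation: issue the CreateTable call and return IN_PROGRESS so that this handler is re-invoked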
+
+ if not model.get("TableName"):
+ model["TableName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+
+ if model.get("ProvisionedThroughput"):
+ model["ProvisionedThroughput"] = self.get_ddb_provisioned_throughput(model)
+
+ if model.get("GlobalSecondaryIndexes"):
+ model["GlobalSecondaryIndexes"] = self.get_ddb_global_sec_indexes(model)
+
+ properties = [
+ "TableName",
+ "AttributeDefinitions",
+ "KeySchema",
+ "BillingMode",
+ "ProvisionedThroughput",
+ "LocalSecondaryIndexes",
+ "GlobalSecondaryIndexes",
+ "Tags",
+ "SSESpecification",
+ ]
+ create_params = util.select_attributes(model, properties)
+
+ if sse_specification := create_params.get("SSESpecification"):
+ # rename bool attribute to fit boto call
+ sse_specification["Enabled"] = sse_specification.pop("SSEEnabled")
+
+ if stream_spec := model.get("StreamSpecification"):
+ create_params["StreamSpecification"] = {
+ "StreamEnabled": True,
+ **(stream_spec or {}),
+ }
+
+ response = request.aws_client_factory.dynamodb.create_table(**create_params)
+ model["Arn"] = response["TableDescription"]["TableArn"]
+
+ if model.get("KinesisStreamSpecification"):
+ request.aws_client_factory.dynamodb.enable_kinesis_streaming_destination(
+ **self.get_ddb_kinesis_stream_specification(model)
+ )
+
+ # add TTL config
+ if ttl_config := model.get("TimeToLiveSpecification"):
+ request.aws_client_factory.dynamodb.update_time_to_live(
+ TableName=model["TableName"], TimeToLiveSpecification=ttl_config
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
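+ # subsequent invocations: poll the table status and only report SUCCESS once it is ACTIVE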
+ description = request.aws_client_factory.dynamodb.describe_table(
+ TableName=model["TableName"]
+ )
+
+ if description["Table"]["TableStatus"] != "ACTIVE":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ if model.get("TimeToLiveSpecification"):
+ request.aws_client_factory.dynamodb.update_time_to_live(
+ TableName=model["TableName"],
+ TimeToLiveSpecification=model["TimeToLiveSpecification"],
+ )
+
+ if description["Table"].get("LatestStreamArn"):
+ model["StreamArn"] = description["Table"]["LatestStreamArn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[DynamoDBTableProperties],
+ ) -> ProgressEvent[DynamoDBTableProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - dynamodb:DescribeTable
+ - dynamodb:DescribeContinuousBackups
+ - dynamodb:DescribeContributorInsights
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[DynamoDBTableProperties],
+ ) -> ProgressEvent[DynamoDBTableProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - dynamodb:DeleteTable
+ - dynamodb:DescribeTable
+ """
+ model = request.desired_state
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.custom_context[REPEATED_INVOCATION] = True
+ request.aws_client_factory.dynamodb.delete_table(TableName=model["TableName"])
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
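+ # subsequent invocations: poll until DescribeTable raises TableNotFoundException, which signals successful deletion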
+ try:
+ table_state = request.aws_client_factory.dynamodb.describe_table(
+ TableName=model["TableName"]
+ )
+
+ match table_state["Table"]["TableStatus"]:
+ case "DELETING":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ case invalid_state:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ message=f"Table deletion failed. Table {model['TableName']} found in state {invalid_state}", # TODO: not validated yet
+ resource_model={},
+ )
+ except request.aws_client_factory.dynamodb.exceptions.TableNotFoundException:
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[DynamoDBTableProperties],
+ ) -> ProgressEvent[DynamoDBTableProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - dynamodb:UpdateTable
+ - dynamodb:DescribeTable
+ - dynamodb:DescribeTimeToLive
+ - dynamodb:UpdateTimeToLive
+ - dynamodb:UpdateContinuousBackups
+ - dynamodb:UpdateContributorInsights
+ - dynamodb:DescribeContinuousBackups
+ - dynamodb:DescribeKinesisStreamingDestination
+ - dynamodb:ListTagsOfResource
+ - dynamodb:TagResource
+ - dynamodb:UntagResource
+ - dynamodb:DescribeContributorInsights
+ - dynamodb:EnableKinesisStreamingDestination
+ - dynamodb:DisableKinesisStreamingDestination
+ - kinesis:DescribeStream
+ - kinesis:PutRecords
+ - iam:CreateServiceLinkedRole
+ - kms:CreateGrant
+ - kms:Describe*
+ - kms:Get*
+ - kms:List*
+ - kms:RevokeGrant
+ """
+ raise NotImplementedError
+
+ def get_ddb_provisioned_throughput(
+ self,
+ properties: dict,
+ ) -> dict | None:
+ # see https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html#cfn-dynamodb-table-provisionedthroughput
+ args = properties.get("ProvisionedThroughput")
+ if args == "AWS::NoValue":
+ return None
+ is_ondemand = properties.get("BillingMode") == "PAY_PER_REQUEST"
+ # if the BillingMode is set to PAY_PER_REQUEST, you cannot specify ProvisionedThroughput
+ # if the BillingMode is set to PROVISIONED (default), you have to specify ProvisionedThroughput
+
+ if args is None:
+ if is_ondemand:
+ # do not return default value if it's on demand
+ return
+
+ # return default values if it's not on demand
+ return {
+ "ReadCapacityUnits": 5,
+ "WriteCapacityUnits": 5,
+ }
+
+ if isinstance(args["ReadCapacityUnits"], str):
+ args["ReadCapacityUnits"] = int(args["ReadCapacityUnits"])
+ if isinstance(args["WriteCapacityUnits"], str):
+ args["WriteCapacityUnits"] = int(args["WriteCapacityUnits"])
+
+ return args
+
+ def get_ddb_global_sec_indexes(
+ self,
+ properties: dict,
+ ) -> list | None:
+ args: list = properties.get("GlobalSecondaryIndexes")
+ is_ondemand = properties.get("BillingMode") == "PAY_PER_REQUEST"
+ if not args:
+ return
+
+ for index in args:
+ # we ignore ContributorInsightsSpecification as it is not yet supported in DynamoDB and CloudWatch
+ index.pop("ContributorInsightsSpecification", None)
+ provisioned_throughput = index.get("ProvisionedThroughput")
+ if is_ondemand and provisioned_throughput is None:
+ pass # optional for API calls
+ elif provisioned_throughput is not None:
+ # convert types
+ if isinstance((read_units := provisioned_throughput["ReadCapacityUnits"]), str):
+ provisioned_throughput["ReadCapacityUnits"] = int(read_units)
+ if isinstance((write_units := provisioned_throughput["WriteCapacityUnits"]), str):
+ provisioned_throughput["WriteCapacityUnits"] = int(write_units)
+ else:
+ raise Exception("Can't specify ProvisionedThroughput with PAY_PER_REQUEST")
+ return args
+
+ def get_ddb_kinesis_stream_specification(
+ self,
+ properties: dict,
+ ) -> dict:
+ args = properties.get("KinesisStreamSpecification")
+ if args:
+ args["TableName"] = properties["TableName"]
+ return args
+
+ def list(
+ self,
+ request: ResourceRequest[DynamoDBTableProperties],
+ ) -> ProgressEvent[DynamoDBTableProperties]:
+ resources = request.aws_client_factory.dynamodb.list_tables()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ DynamoDBTableProperties(TableName=resource) for resource in resources["TableNames"]
+ ],
+ )
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.schema.json b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.schema.json
new file mode 100644
index 0000000000000..c4dd5ef70eb3d
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table.schema.json
@@ -0,0 +1,514 @@
+{
+ "typeName": "AWS::DynamoDB::Table",
+ "description": "Version: None. Resource Type definition for AWS::DynamoDB::Table",
+ "additionalProperties": false,
+ "properties": {
+ "Arn": {
+ "type": "string"
+ },
+ "StreamArn": {
+ "type": "string"
+ },
+ "AttributeDefinitions": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/AttributeDefinition"
+ }
+ },
+ "BillingMode": {
+ "type": "string"
+ },
+ "DeletionProtectionEnabled": {
+ "type": "boolean"
+ },
+ "GlobalSecondaryIndexes": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/GlobalSecondaryIndex"
+ }
+ },
+ "KeySchema": {
+ "oneOf": [
+ {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ }
+ },
+ {
+ "type": "object"
+ }
+ ]
+ },
+ "LocalSecondaryIndexes": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/LocalSecondaryIndex"
+ }
+ },
+ "PointInTimeRecoverySpecification": {
+ "$ref": "#/definitions/PointInTimeRecoverySpecification"
+ },
+ "TableClass": {
+ "type": "string"
+ },
+ "ProvisionedThroughput": {
+ "$ref": "#/definitions/ProvisionedThroughput"
+ },
+ "SSESpecification": {
+ "$ref": "#/definitions/SSESpecification"
+ },
+ "StreamSpecification": {
+ "$ref": "#/definitions/StreamSpecification"
+ },
+ "TableName": {
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "TimeToLiveSpecification": {
+ "$ref": "#/definitions/TimeToLiveSpecification"
+ },
+ "ContributorInsightsSpecification": {
+ "$ref": "#/definitions/ContributorInsightsSpecification"
+ },
+ "KinesisStreamSpecification": {
+ "$ref": "#/definitions/KinesisStreamSpecification"
+ },
+ "ImportSourceSpecification": {
+ "$ref": "#/definitions/ImportSourceSpecification"
+ }
+ },
+ "propertyTransform": {
+ "/properties/SSESpecification/KMSMasterKeyId": "$join([\"arn:(aws)[-]{0,1}[a-z]{0,2}[-]{0,1}[a-z]{0,3}:kms:[a-z]{2}[-]{1}[a-z]{3,10}[-]{0,1}[a-z]{0,4}[-]{1}[1-4]{1}:[0-9]{12}[:]{1}key\\/\", SSESpecification.KMSMasterKeyId])"
+ },
+ "definitions": {
+ "StreamSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "StreamViewType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "StreamViewType"
+ ]
+ },
+ "DeprecatedKeySchema": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "HashKeyElement": {
+ "$ref": "#/definitions/DeprecatedHashKeyElement"
+ }
+ },
+ "required": [
+ "HashKeyElement"
+ ]
+ },
+ "DeprecatedHashKeyElement": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeType": {
+ "type": "string"
+ },
+ "AttributeName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "AttributeType",
+ "AttributeName"
+ ]
+ },
+ "KeySchema": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string"
+ },
+ "KeyType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "KeyType",
+ "AttributeName"
+ ]
+ },
+ "PointInTimeRecoverySpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PointInTimeRecoveryEnabled": {
+ "type": "boolean"
+ }
+ }
+ },
+ "KinesisStreamSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "StreamArn": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "StreamArn"
+ ]
+ },
+ "TimeToLiveSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string"
+ },
+ "Enabled": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "Enabled",
+ "AttributeName"
+ ]
+ },
+ "LocalSecondaryIndex": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IndexName": {
+ "type": "string"
+ },
+ "KeySchema": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ }
+ },
+ "Projection": {
+ "$ref": "#/definitions/Projection"
+ }
+ },
+ "required": [
+ "IndexName",
+ "Projection",
+ "KeySchema"
+ ]
+ },
+ "GlobalSecondaryIndex": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IndexName": {
+ "type": "string"
+ },
+ "KeySchema": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/KeySchema"
+ }
+ },
+ "Projection": {
+ "$ref": "#/definitions/Projection"
+ },
+ "ProvisionedThroughput": {
+ "$ref": "#/definitions/ProvisionedThroughput"
+ },
+ "ContributorInsightsSpecification": {
+ "$ref": "#/definitions/ContributorInsightsSpecification"
+ }
+ },
+ "required": [
+ "IndexName",
+ "Projection",
+ "KeySchema"
+ ]
+ },
+ "SSESpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "KMSMasterKeyId": {
+ "type": "string"
+ },
+ "SSEEnabled": {
+ "type": "boolean"
+ },
+ "SSEType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "SSEEnabled"
+ ]
+ },
+ "AttributeDefinition": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string"
+ },
+ "AttributeType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "AttributeName",
+ "AttributeType"
+ ]
+ },
+ "ProvisionedThroughput": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ReadCapacityUnits": {
+ "type": "integer"
+ },
+ "WriteCapacityUnits": {
+ "type": "integer"
+ }
+ },
+ "required": [
+ "WriteCapacityUnits",
+ "ReadCapacityUnits"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "Projection": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "NonKeyAttributes": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "ProjectionType": {
+ "type": "string"
+ }
+ }
+ },
+ "ContributorInsightsSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "Enabled"
+ ]
+ },
+ "ImportSourceSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "S3BucketSource": {
+ "$ref": "#/definitions/S3BucketSource"
+ },
+ "InputFormat": {
+ "type": "string"
+ },
+ "InputFormatOptions": {
+ "$ref": "#/definitions/InputFormatOptions"
+ },
+ "InputCompressionType": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "S3BucketSource",
+ "InputFormat"
+ ]
+ },
+ "S3BucketSource": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "S3BucketOwner": {
+ "type": "string"
+ },
+ "S3Bucket": {
+ "type": "string"
+ },
+ "S3KeyPrefix": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "S3Bucket"
+ ]
+ },
+ "InputFormatOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Csv": {
+ "$ref": "#/definitions/Csv"
+ }
+ }
+ },
+ "Csv": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "HeaderList": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Delimiter": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": false,
+ "tagProperty": "/properties/Tags"
+ },
+ "required": [
+ "KeySchema"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/StreamArn"
+ ],
+ "createOnlyProperties": [
+ "/properties/TableName",
+ "/properties/ImportSourceSpecification"
+ ],
+ "primaryIdentifier": [
+ "/properties/TableName"
+ ],
+ "writeOnlyProperties": [
+ "/properties/ImportSourceSpecification"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "dynamodb:CreateTable",
+ "dynamodb:DescribeImport",
+ "dynamodb:DescribeTable",
+ "dynamodb:DescribeTimeToLive",
+ "dynamodb:UpdateTimeToLive",
+ "dynamodb:UpdateContributorInsights",
+ "dynamodb:UpdateContinuousBackups",
+ "dynamodb:DescribeContinuousBackups",
+ "dynamodb:DescribeContributorInsights",
+ "dynamodb:EnableKinesisStreamingDestination",
+ "dynamodb:DisableKinesisStreamingDestination",
+ "dynamodb:DescribeKinesisStreamingDestination",
+ "dynamodb:ImportTable",
+ "dynamodb:ListTagsOfResource",
+ "dynamodb:TagResource",
+ "dynamodb:UpdateTable",
+ "kinesis:DescribeStream",
+ "kinesis:PutRecords",
+ "iam:CreateServiceLinkedRole",
+ "kms:CreateGrant",
+ "kms:Decrypt",
+ "kms:Describe*",
+ "kms:Encrypt",
+ "kms:Get*",
+ "kms:List*",
+ "kms:RevokeGrant",
+ "logs:CreateLogGroup",
+ "logs:CreateLogStream",
+ "logs:DescribeLogGroups",
+ "logs:DescribeLogStreams",
+ "logs:PutLogEvents",
+ "logs:PutRetentionPolicy",
+ "s3:GetObject",
+ "s3:GetObjectMetadata",
+ "s3:ListBucket"
+ ],
+ "timeoutInMinutes": 720
+ },
+ "read": {
+ "permissions": [
+ "dynamodb:DescribeTable",
+ "dynamodb:DescribeContinuousBackups",
+ "dynamodb:DescribeContributorInsights"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "dynamodb:UpdateTable",
+ "dynamodb:DescribeTable",
+ "dynamodb:DescribeTimeToLive",
+ "dynamodb:UpdateTimeToLive",
+ "dynamodb:UpdateContinuousBackups",
+ "dynamodb:UpdateContributorInsights",
+ "dynamodb:DescribeContinuousBackups",
+ "dynamodb:DescribeKinesisStreamingDestination",
+ "dynamodb:ListTagsOfResource",
+ "dynamodb:TagResource",
+ "dynamodb:UntagResource",
+ "dynamodb:DescribeContributorInsights",
+ "dynamodb:EnableKinesisStreamingDestination",
+ "dynamodb:DisableKinesisStreamingDestination",
+ "kinesis:DescribeStream",
+ "kinesis:PutRecords",
+ "iam:CreateServiceLinkedRole",
+ "kms:CreateGrant",
+ "kms:Describe*",
+ "kms:Get*",
+ "kms:List*",
+ "kms:RevokeGrant"
+ ],
+ "timeoutInMinutes": 720
+ },
+ "delete": {
+ "permissions": [
+ "dynamodb:DeleteTable",
+ "dynamodb:DescribeTable"
+ ],
+ "timeoutInMinutes": 720
+ },
+ "list": {
+ "permissions": [
+ "dynamodb:ListTables"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table_plugin.py b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table_plugin.py
new file mode 100644
index 0000000000000..5f263b9e9d068
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/resource_providers/aws_dynamodb_table_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class DynamoDBTableProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::DynamoDB::Table"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.dynamodb.resource_providers.aws_dynamodb_table import (
+ DynamoDBTableProvider,
+ )
+
+ self.factory = DynamoDBTableProvider
diff --git a/localstack-core/localstack/services/dynamodb/server.py b/localstack-core/localstack/services/dynamodb/server.py
new file mode 100644
index 0000000000000..66921057cc627
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/server.py
@@ -0,0 +1,234 @@
+import contextlib
+import logging
+import os
+import subprocess
+import threading
+
+from localstack import config
+from localstack.aws.connect import connect_externally_to
+from localstack.aws.forwarder import AwsRequestProxy
+from localstack.config import is_env_true
+from localstack.constants import DEFAULT_AWS_ACCOUNT_ID
+from localstack.services.dynamodb.packages import dynamodblocal_package
+from localstack.utils.common import TMP_THREADS, ShellCommandThread, get_free_tcp_port, mkdir
+from localstack.utils.functions import run_safe
+from localstack.utils.net import wait_for_port_closed
+from localstack.utils.objects import singleton_factory
+from localstack.utils.run import FuncThread, run
+from localstack.utils.serving import Server
+from localstack.utils.sync import retry, synchronized
+
+LOG = logging.getLogger(__name__)
+RESTART_LOCK = threading.RLock()
+
+
+def _log_listener(line, **_kwargs):
+ LOG.debug(line.rstrip())
+
+
+class DynamodbServer(Server):
+ db_path: str | None
+ heap_size: str
+
+ delay_transient_statuses: bool
+ optimize_db_before_startup: bool
+ share_db: bool
+ cors: str | None
+
+ proxy: AwsRequestProxy
+
+ def __init__(
+ self,
+ port: int | None = None,
+ host: str = "localhost",
+ db_path: str | None = None,
+ ) -> None:
+ """
+ Creates a DynamoDB server from the local configuration.
+
+ :param port: optional, the port to start the server on (defaults to a random port)
+ :param host: localhost by default
+ :param db_path: path to the persistence state files used by the DynamoDB Local process
+ """
+
+ port = port or get_free_tcp_port()
+ super().__init__(port, host)
+
+ self.db_path = (
+ f"{config.dirs.data}/dynamodb" if not db_path and config.dirs.data else db_path
+ )
+
+ # the DYNAMODB_IN_MEMORY variable takes precedence and will set the DB path to None which forces inMemory=true
+ if is_env_true("DYNAMODB_IN_MEMORY"):
+ # note: with DYNAMODB_IN_MEMORY we do not support persistence
+ self.db_path = None
+
+ if self.db_path:
+ self.db_path = os.path.abspath(self.db_path)
+
+ self.heap_size = config.DYNAMODB_HEAP_SIZE
+ self.delay_transient_statuses = is_env_true("DYNAMODB_DELAY_TRANSIENT_STATUSES")
+ self.optimize_db_before_startup = is_env_true("DYNAMODB_OPTIMIZE_DB_BEFORE_STARTUP")
+ self.share_db = is_env_true("DYNAMODB_SHARE_DB")
+ self.cors = os.getenv("DYNAMODB_CORS", None)
+ self.proxy = AwsRequestProxy(self.url)
+
+ @staticmethod
+ @singleton_factory
+ def get() -> "DynamodbServer":
+ return DynamodbServer(config.DYNAMODB_LOCAL_PORT)
+
+ @synchronized(lock=RESTART_LOCK)
+ def start_dynamodb(self) -> bool:
+ """Start the DynamoDB server."""
+
+ # We want this method to be idempotent.
+ if self.is_running() and self.is_up():
+ return True
+
+ # For the v2 provider, the DynamodbServer has been made a singleton. Yet, the Server abstraction is modelled
+ # after threading.Thread, where Start -> Stop -> Start is not allowed. This flow happens during state resets.
+ # The following is a workaround that permits this flow
+ self._started.clear()
+ self._stopped.clear()
+
+ # Note: when starting the server, we had a flag for wiping the assets directory before the actual start.
+ # This behavior was needed in some particular cases:
+ # - pod load with some assets already lying in the asset folder
+ # - ...
+ # The cleaning is now done via the reset endpoint
+ if self.db_path:
+ mkdir(self.db_path)
+
+ started = self.start()
+ self.wait_for_dynamodb()
+ return started
+
+ @synchronized(lock=RESTART_LOCK)
+ def stop_dynamodb(self) -> None:
+ """Stop the DynamoDB server."""
+ import psutil
+
+ if self._thread is None:
+ return
+ self._thread.auto_restart = False
+ self.shutdown()
+ self.join(timeout=10)
+ try:
+ wait_for_port_closed(self.port, sleep_time=0.8, retries=10)
+ except Exception:
+ LOG.warning(
+ "DynamoDB server port %s (%s) unexpectedly still open; running processes: %s",
+ self.port,
+ self._thread,
+ run(["ps", "aux"]),
+ )
+
+ # attempt to terminate/kill the process manually
+ server_pid = self._thread.process.pid # noqa
+ LOG.info("Attempting to kill DynamoDB process %s", server_pid)
+ process = psutil.Process(server_pid)
+ run_safe(process.terminate)
+ run_safe(process.kill)
+ wait_for_port_closed(self.port, sleep_time=0.5, retries=8)
+
+ @property
+ def in_memory(self) -> bool:
+ return self.db_path is None
+
+ @property
+ def jar_path(self) -> str:
+ return f"{dynamodblocal_package.get_installed_dir()}/DynamoDBLocal.jar"
+
+ @property
+ def library_path(self) -> str:
+ return f"{dynamodblocal_package.get_installed_dir()}/DynamoDBLocal_lib"
+
+ def _get_java_vm_options(self) -> list[str]:
+ dynamodblocal_installer = dynamodblocal_package.get_installer()
+
+ # Workaround for JVM SIGILL crash on Apple Silicon M4
+ # See https://bugs.openjdk.org/browse/JDK-8345296
+ # To be removed after Java is bumped to 17.0.15+ and 21.0.7+
+
+ # This command returns all supported JVM options
+ with contextlib.suppress(subprocess.CalledProcessError):
+ stdout = run(
+ cmd=["java", "-XX:+UnlockDiagnosticVMOptions", "-XX:+PrintFlagsFinal", "-version"],
+ env_vars=dynamodblocal_installer.get_java_env_vars(),
+ print_error=True,
+ )
+ # Check if Scalable Vector Extensions are supported on this JVM and CPU. If so, disable them
+ if "UseSVE" in stdout:
+ return ["-XX:UseSVE=0"]
+ return []
+
+ def _create_shell_command(self) -> list[str]:
+ cmd = [
+ "java",
+ *self._get_java_vm_options(),
+ "-Xmx%s" % self.heap_size,
+ f"-javaagent:{dynamodblocal_package.get_installer().get_ddb_agent_jar_path()}",
+ f"-Djava.library.path={self.library_path}",
+ "-jar",
+ self.jar_path,
+ ]
+ parameters = []
+
+ parameters.extend(["-port", str(self.port)])
+ if self.in_memory:
+ parameters.append("-inMemory")
+ if self.db_path:
+ parameters.extend(["-dbPath", self.db_path])
+ if self.delay_transient_statuses:
+ parameters.extend(["-delayTransientStatuses"])
+ if self.optimize_db_before_startup:
+ parameters.extend(["-optimizeDbBeforeStartup"])
+ if self.share_db:
+ parameters.extend(["-sharedDb"])
+
+ return cmd + parameters
+
+ def do_start_thread(self) -> FuncThread:
+ dynamodblocal_installer = dynamodblocal_package.get_installer()
+ dynamodblocal_installer.install()
+
+ cmd = self._create_shell_command()
+ env_vars = {
+ **dynamodblocal_installer.get_java_env_vars(),
+ "DDB_LOCAL_TELEMETRY": "0",
+ }
+
+ LOG.debug("Starting DynamoDB Local: %s", cmd)
+ t = ShellCommandThread(
+ cmd,
+ strip_color=True,
+ log_listener=_log_listener,
+ auto_restart=True,
+ name="dynamodb-local",
+ env_vars=env_vars,
+ )
+ TMP_THREADS.append(t)
+ t.start()
+ return t
+
+ def check_dynamodb(self, expect_shutdown: bool = False) -> None:
+ """Checks if DynamoDB server is up"""
+ out = None
+
+ try:
+ self.wait_is_up()
+ out = connect_externally_to(
+ endpoint_url=self.url,
+ aws_access_key_id=DEFAULT_AWS_ACCOUNT_ID,
+ aws_secret_access_key=DEFAULT_AWS_ACCOUNT_ID,
+ ).dynamodb.list_tables()
+ except Exception:
+ LOG.exception("DynamoDB health check failed")
+ if expect_shutdown:
+ assert out is None
+ else:
+ assert isinstance(out["TableNames"], list)
+
+ def wait_for_dynamodb(self) -> None:
+ retry(self.check_dynamodb, sleep=0.4, retries=10)
diff --git a/localstack-core/localstack/services/dynamodb/utils.py b/localstack-core/localstack/services/dynamodb/utils.py
new file mode 100644
index 0000000000000..995458b2deed7
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/utils.py
@@ -0,0 +1,350 @@
+import logging
+import re
+from binascii import crc32
+from typing import Dict, List, Optional
+
+from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
+from cachetools import TTLCache
+from moto.core.exceptions import JsonRESTError
+
+from localstack.aws.api import RequestContext
+from localstack.aws.api.dynamodb import (
+ AttributeMap,
+ BatchGetRequestMap,
+ BatchGetResponseMap,
+ Delete,
+ DeleteRequest,
+ Put,
+ PutRequest,
+ ResourceNotFoundException,
+ TableName,
+ Update,
+)
+from localstack.aws.connect import connect_to
+from localstack.constants import INTERNAL_AWS_SECRET_ACCESS_KEY
+from localstack.http import Response
+from localstack.utils.aws.arns import dynamodb_table_arn, get_partition
+from localstack.utils.json import canonical_json
+from localstack.utils.testutil import list_all_resources
+
+LOG = logging.getLogger(__name__)
+
+# cache schema definitions
+SCHEMA_CACHE = TTLCache(maxsize=50, ttl=20)
+
+_ddb_local_arn_pattern = re.compile(
+ r'("TableArn"|"LatestStreamArn"|"StreamArn"|"ShardIterator"|"IndexArn")\s*:\s*"arn:[a-z-]+:dynamodb:ddblocal:000000000000:([^"]+)"'
+)
+_ddb_local_region_pattern = re.compile(r'"awsRegion"\s*:\s*"([^"]+)"')
+_ddb_local_exception_arn_pattern = re.compile(r'arn:[a-z-]+:dynamodb:ddblocal:000000000000:([^"]+)')
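+# illustrative example: DynamoDB Local returns ARNs of the form
+#   "TableArn": "arn:aws:dynamodb:ddblocal:000000000000:table/my-table"
+# which modify_ddblocal_arns() below rewrites to the caller's real partition, region, and account ID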
+
+
+def get_ddb_access_key(account_id: str, region_name: str) -> str:
+ """
+ Get the access key to be used while communicating with DynamoDB Local.
+
+ DDBLocal supports namespacing as an undocumented feature. It works based on the value of the `Credentials`
+ field of the `Authorization` header. We use a concatenated value of account ID and region to achieve
+ namespacing.
+ """
+ return f"{account_id}{region_name}".replace("-", "")
+
+
+class ItemSet:
+ """Represents a set of items and provides utils to find individual items in the set"""
+
+ def __init__(self, items: List[Dict], key_schema: List[Dict]):
+ self.items_list = items
+ self.key_schema = key_schema
+ self._build_dict()
+
+ def _build_dict(self):
+ self.items_dict = {}
+ for item in self.items_list:
+ self.items_dict[self._hashable_key(item)] = item
+
+ def _hashable_key(self, item: Dict):
+ keys = SchemaExtractor.extract_keys_for_schema(item=item, key_schema=self.key_schema)
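+ # the canonical (sorted-key) JSON of the key attributes serves as a stable lookup key, e.g. '{"id": {"S": "1"}}'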
+ return canonical_json(keys)
+
+ def find_item(self, item: Dict) -> Optional[Dict]:
+ key = self._hashable_key(item)
+ return self.items_dict.get(key)
+
+
+class SchemaExtractor:
+ @classmethod
+ def extract_keys(
+ cls, item: Dict, table_name: str, account_id: str, region_name: str
+ ) -> Optional[Dict]:
+ key_schema = cls.get_key_schema(table_name, account_id, region_name)
+ return cls.extract_keys_for_schema(item, key_schema)
+
+ @classmethod
+ def extract_keys_for_schema(cls, item: Dict, key_schema: List[Dict]):
+ result = {}
+ for key in key_schema:
+ attr_name = key["AttributeName"]
+ if attr_name not in item:
+ raise JsonRESTError(
+ error_type="ValidationException",
+ message="One of the required keys was not given a value",
+ )
+ result[attr_name] = item[attr_name]
+ return result
+
+ @classmethod
+ def get_key_schema(
+ cls, table_name: str, account_id: str, region_name: str
+ ) -> Optional[List[Dict]]:
+ from localstack.services.dynamodb.provider import get_store
+
+ table_definitions: Dict = get_store(
+ account_id=account_id,
+ region_name=region_name,
+ ).table_definitions
+ table_def = table_definitions.get(table_name)
+ if not table_def:
+ # Try fetching from the backend in case table_definitions has been reset
+ schema = cls.get_table_schema(
+ table_name=table_name, account_id=account_id, region_name=region_name
+ )
+ if not schema:
+ raise ResourceNotFoundException(f"Unknown table: {table_name} not found")
+ # Save the schema in the cache
+ table_definitions[table_name] = schema["Table"]
+ table_def = table_definitions[table_name]
+ return table_def["KeySchema"]
+
+ @classmethod
+ def get_table_schema(cls, table_name: str, account_id: str, region_name: str):
+ key = dynamodb_table_arn(
+ table_name=table_name, account_id=account_id, region_name=region_name
+ )
+ schema = SCHEMA_CACHE.get(key)
+ if not schema:
+ # TODO: consider making in-memory lookup instead of API call
+ ddb_client = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).dynamodb
+ try:
+ schema = ddb_client.describe_table(TableName=table_name)
+ SCHEMA_CACHE[key] = schema
+ except Exception as e:
+ if "ResourceNotFoundException" in str(e):
+ raise ResourceNotFoundException(f"Unknown table: {table_name}") from e
+ raise
+ return schema
+
+ @classmethod
+ def invalidate_table_schema(cls, table_name: str, account_id: str, region_name: str):
+ """
+ Allow cached table schemas to be invalidated without waiting for the TTL to expire
+ """
+ key = dynamodb_table_arn(
+ table_name=table_name, account_id=account_id, region_name=region_name
+ )
+ SCHEMA_CACHE.pop(key, None)
+
+
+class ItemFinder:
+ @staticmethod
+ def get_ddb_local_client(account_id: str, region_name: str, endpoint_url: str):
+ ddb_client = connect_to(
+ aws_access_key_id=get_ddb_access_key(account_id, region_name),
+ region_name=region_name,
+ endpoint_url=endpoint_url,
+ ).dynamodb
+ return ddb_client
+
+ @staticmethod
+ def find_existing_item(
+ put_item: Dict,
+ table_name: str,
+ account_id: str,
+ region_name: str,
+ endpoint_url: str,
+ ) -> Optional[AttributeMap]:
+ from localstack.services.dynamodb.provider import ValidationException
+
+ ddb_client = ItemFinder.get_ddb_local_client(account_id, region_name, endpoint_url)
+
+ search_key = {}
+ if "Key" in put_item:
+ search_key = put_item["Key"]
+ else:
+ schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
+ schemas = [schema["Table"]["KeySchema"]]
+ for index in schema["Table"].get("GlobalSecondaryIndexes", []):
+ # TODO
+ # schemas.append(index['KeySchema'])
+ pass
+ for schema in schemas:
+ for key in schema:
+ key_name = key["AttributeName"]
+ key_value = put_item["Item"].get(key_name)
+ if not key_value:
+ raise ValidationException(
+ "The provided key element does not match the schema"
+ )
+ search_key[key_name] = key_value
+ if not search_key:
+ return
+
+ try:
+ existing_item = ddb_client.get_item(TableName=table_name, Key=search_key)
+ except ddb_client.exceptions.ClientError as e:
+ LOG.warning(
+ "Unable to get item from DynamoDB table '%s': %s",
+ table_name,
+ e,
+ )
+ return
+
+ return existing_item.get("Item")
+
+ @staticmethod
+ def find_existing_items(
+ put_items_per_table: dict[
+ TableName, list[PutRequest | DeleteRequest | Put | Update | Delete]
+ ],
+ account_id: str,
+ region_name: str,
+ endpoint_url: str,
+ ) -> BatchGetResponseMap:
+ from localstack.services.dynamodb.provider import ValidationException
+
+ ddb_client = ItemFinder.get_ddb_local_client(account_id, region_name, endpoint_url)
+
+ get_items_request: BatchGetRequestMap = {}
+ for table_name, put_item_reqs in put_items_per_table.items():
+ table_schema = None
+ for put_item in put_item_reqs:
+ search_key = {}
+ if "Key" in put_item:
+ search_key = put_item["Key"]
+ else:
+ if not table_schema:
+ table_schema = SchemaExtractor.get_table_schema(
+ table_name, account_id, region_name
+ )
+
+ schemas = [table_schema["Table"]["KeySchema"]]
+ for index in table_schema["Table"].get("GlobalSecondaryIndexes", []):
+ # TODO
+ # schemas.append(index['KeySchema'])
+ pass
+ for schema in schemas:
+ for key in schema:
+ key_name = key["AttributeName"]
+ key_value = put_item["Item"].get(key_name)
+ if not key_value:
+ raise ValidationException(
+ "The provided key element does not match the schema"
+ )
+ search_key[key_name] = key_value
+ if not search_key:
+ continue
+ table_keys = get_items_request.setdefault(table_name, {"Keys": []})
+ table_keys["Keys"].append(search_key)
+
+ try:
+ existing_items = ddb_client.batch_get_item(RequestItems=get_items_request)
+ except ddb_client.exceptions.ClientError as e:
+ LOG.warning(
+ "Unable to get items from DynamoDB tables '%s': %s",
+ list(put_items_per_table.keys()),
+ e,
+ )
+ return {}
+
+ return existing_items.get("Responses", {})
+
+ @classmethod
+ def list_existing_items_for_statement(
+ cls, partiql_statement: str, account_id: str, region_name: str, endpoint_url: str
+ ) -> List:
+ table_name = extract_table_name_from_partiql_update(partiql_statement)
+ if not table_name:
+ return []
+ all_items = cls.get_all_table_items(
+ account_id=account_id,
+ region_name=region_name,
+ table_name=table_name,
+ endpoint_url=endpoint_url,
+ )
+ return all_items
+
+ @staticmethod
+ def get_all_table_items(
+ account_id: str, region_name: str, table_name: str, endpoint_url: str
+ ) -> List:
+ ddb_client = ItemFinder.get_ddb_local_client(account_id, region_name, endpoint_url)
+ dynamodb_kwargs = {"TableName": table_name}
+ all_items = list_all_resources(
+ lambda kwargs: ddb_client.scan(**{**kwargs, **dynamodb_kwargs}),
+ last_token_attr_name="LastEvaluatedKey",
+ next_token_attr_name="ExclusiveStartKey",
+ list_attr_name="Items",
+ )
+ return all_items
+
+
+def extract_table_name_from_partiql_update(statement: str) -> Optional[str]:
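+ # e.g. 'DELETE FROM Users WHERE id=1' yields 'Users'; returns None if no table name can be extracted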
+ regex = r"^\s*(UPDATE|INSERT\s+INTO|DELETE\s+FROM)\s+([^\s]+).*"
+ match = re.match(regex, statement, flags=re.IGNORECASE | re.MULTILINE)
+ return match and match.group(2)
+
+
+def dynamize_value(value) -> dict:
+ """
+ Take a scalar Python value or dict/list and return a dict consisting of the Amazon DynamoDB type specification and
+ the value that needs to be sent to Amazon DynamoDB. If the type of the value is not supported, raise a TypeError
+ """
+ return TypeSerializer().serialize(value)
+
+
+def de_dynamize_record(item: dict) -> dict:
+ """
+ Return the given item in DynamoDB format parsed as regular dict object, i.e., convert
+ something like `{'foo': {'S': 'test'}, 'bar': {'N': '123'}}` to `{'foo': 'test', 'bar': Decimal('123')}`.
+ Note: This is the reverse operation of `dynamize_value(...)` above.
+ """
+ deserializer = TypeDeserializer()
+ return {k: deserializer.deserialize(v) for k, v in item.items()}
+
+
+def modify_ddblocal_arns(chain, context: RequestContext, response: Response):
+ """A service response handler that modifies the dynamodb backend response."""
+ if response_content := response.get_data(as_text=True):
+ partition = get_partition(context.region)
+
+ def _convert_arn(matchobj):
+ key = matchobj.group(1)
+ table_name = matchobj.group(2)
+ return f'{key}: "arn:{partition}:dynamodb:{context.region}:{context.account_id}:{table_name}"'
+
+ # fix the table and latest stream ARNs (DynamoDBLocal hardcodes "ddblocal" as the region)
+ content_replaced = _ddb_local_arn_pattern.sub(
+ _convert_arn,
+ response_content,
+ )
+ if context.service.service_name == "dynamodbstreams":
+ content_replaced = _ddb_local_region_pattern.sub(
+ f'"awsRegion": "{context.region}"', content_replaced
+ )
+ if context.service_exception:
+ content_replaced = _ddb_local_exception_arn_pattern.sub(
+ rf"arn:{partition}:dynamodb:{context.region}:{context.account_id}:\g<1>",
+ content_replaced,
+ )
+
+ if content_replaced != response_content:
+ response.data = content_replaced
+ # make sure the service response is parsed again later
+ context.service_response = None
+
+ # update x-amz-crc32 header required by some clients
+ response.headers["x-amz-crc32"] = crc32(response.data) & 0xFFFFFFFF
diff --git a/localstack-core/localstack/services/dynamodb/v2/__init__.py b/localstack-core/localstack/services/dynamodb/v2/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/dynamodb/v2/provider.py b/localstack-core/localstack/services/dynamodb/v2/provider.py
new file mode 100644
index 0000000000000..f6dee3a68e854
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodb/v2/provider.py
@@ -0,0 +1,1477 @@
+import copy
+import json
+import logging
+import os
+import random
+import re
+import threading
+import time
+from contextlib import contextmanager
+from datetime import datetime
+from operator import itemgetter
+from typing import Dict, Optional
+
+import requests
+import werkzeug
+
+from localstack import config
+from localstack.aws import handlers
+from localstack.aws.api import (
+ CommonServiceException,
+ RequestContext,
+ ServiceRequest,
+ ServiceResponse,
+ handler,
+)
+from localstack.aws.api.dynamodb import (
+ BatchExecuteStatementOutput,
+ BatchGetItemOutput,
+ BatchGetRequestMap,
+ BatchWriteItemInput,
+ BatchWriteItemOutput,
+ BillingMode,
+ ContinuousBackupsDescription,
+ ContinuousBackupsStatus,
+ CreateGlobalTableOutput,
+ CreateTableInput,
+ CreateTableOutput,
+ DeleteItemInput,
+ DeleteItemOutput,
+ DeleteRequest,
+ DeleteTableOutput,
+ DescribeContinuousBackupsOutput,
+ DescribeGlobalTableOutput,
+ DescribeKinesisStreamingDestinationOutput,
+ DescribeTableOutput,
+ DescribeTimeToLiveOutput,
+ DestinationStatus,
+ DynamodbApi,
+ EnableKinesisStreamingConfiguration,
+ ExecuteStatementInput,
+ ExecuteStatementOutput,
+ ExecuteTransactionInput,
+ ExecuteTransactionOutput,
+ GetItemInput,
+ GetItemOutput,
+ GlobalTableAlreadyExistsException,
+ GlobalTableNotFoundException,
+ KinesisStreamingDestinationOutput,
+ ListGlobalTablesOutput,
+ ListTablesInputLimit,
+ ListTablesOutput,
+ ListTagsOfResourceOutput,
+ NextTokenString,
+ PartiQLBatchRequest,
+ PointInTimeRecoveryDescription,
+ PointInTimeRecoverySpecification,
+ PointInTimeRecoveryStatus,
+ PositiveIntegerObject,
+ ProvisionedThroughputExceededException,
+ PutItemInput,
+ PutItemOutput,
+ PutRequest,
+ QueryInput,
+ QueryOutput,
+ RegionName,
+ ReplicaDescription,
+ ReplicaList,
+ ReplicaStatus,
+ ReplicaUpdateList,
+ ResourceArnString,
+ ResourceInUseException,
+ ResourceNotFoundException,
+ ReturnConsumedCapacity,
+ ScanInput,
+ ScanOutput,
+ StreamArn,
+ TableDescription,
+ TableName,
+ TagKeyList,
+ TagList,
+ TimeToLiveSpecification,
+ TransactGetItemList,
+ TransactGetItemsOutput,
+ TransactWriteItemsInput,
+ TransactWriteItemsOutput,
+ UpdateContinuousBackupsOutput,
+ UpdateGlobalTableOutput,
+ UpdateItemInput,
+ UpdateItemOutput,
+ UpdateTableInput,
+ UpdateTableOutput,
+ UpdateTimeToLiveOutput,
+ WriteRequest,
+)
+from localstack.aws.connect import connect_to
+from localstack.constants import (
+ AUTH_CREDENTIAL_REGEX,
+ AWS_REGION_US_EAST_1,
+ INTERNAL_AWS_SECRET_ACCESS_KEY,
+)
+from localstack.http import Request, Response, route
+from localstack.services.dynamodb.models import (
+ DynamoDBStore,
+ StreamRecord,
+ dynamodb_stores,
+)
+from localstack.services.dynamodb.server import DynamodbServer
+from localstack.services.dynamodb.utils import (
+ SchemaExtractor,
+ get_ddb_access_key,
+ modify_ddblocal_arns,
+)
+from localstack.services.dynamodbstreams.models import dynamodbstreams_stores
+from localstack.services.edge import ROUTER
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.state import AssetDirectory, StateVisitor
+from localstack.utils.aws import arns
+from localstack.utils.aws.arns import (
+ extract_account_id_from_arn,
+ extract_region_from_arn,
+ get_partition,
+)
+from localstack.utils.aws.aws_stack import get_valid_regions_for_service
+from localstack.utils.aws.request_context import (
+ extract_account_id_from_headers,
+ extract_region_from_headers,
+)
+from localstack.utils.collections import select_attributes, select_from_typed_dict
+from localstack.utils.common import short_uid, to_bytes
+from localstack.utils.json import canonical_json
+from localstack.utils.scheduler import Scheduler
+from localstack.utils.strings import long_uid, to_str
+from localstack.utils.threads import FuncThread, start_thread
+
+# set up logger
+LOG = logging.getLogger(__name__)
+
+# action header prefix
+ACTION_PREFIX = "DynamoDB_20120810."
+
+# list of actions subject to throughput limitations
+READ_THROTTLED_ACTIONS = [
+ "GetItem",
+ "Query",
+ "Scan",
+ "TransactGetItems",
+ "BatchGetItem",
+]
+WRITE_THROTTLED_ACTIONS = [
+ "PutItem",
+ "BatchWriteItem",
+ "UpdateItem",
+ "DeleteItem",
+ "TransactWriteItems",
+]
+THROTTLED_ACTIONS = READ_THROTTLED_ACTIONS + WRITE_THROTTLED_ACTIONS
+
+MANAGED_KMS_KEYS = {}
+
+
+def dynamodb_table_exists(table_name: str, client=None) -> bool:
+ client = client or connect_to().dynamodb
+ paginator = client.get_paginator("list_tables")
+ pages = paginator.paginate(PaginationConfig={"PageSize": 100})
+ table_name = to_str(table_name)
+ return any(table_name in page["TableNames"] for page in pages)
+
+
+class SSEUtils:
+ """Utils for server-side encryption (SSE)"""
+
+ @classmethod
+ def get_sse_kms_managed_key(cls, account_id: str, region_name: str):
+ from localstack.services.kms import provider
+
+ existing_key = MANAGED_KMS_KEYS.get(region_name)
+ if existing_key:
+ return existing_key
+ kms_client = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).kms
+ key_data = kms_client.create_key(
+ Description="Default key that protects my DynamoDB data when no other key is defined"
+ )
+ key_id = key_data["KeyMetadata"]["KeyId"]
+
+ provider.set_key_managed(key_id, account_id, region_name)
+ MANAGED_KMS_KEYS[region_name] = key_id
+ return key_id
+
+ @classmethod
+ def get_sse_description(cls, account_id: str, region_name: str, data):
+ if data.get("Enabled"):
+ kms_master_key_id = data.get("KMSMasterKeyId")
+ if not kms_master_key_id:
+ # this is of course not the actual key used by DynamoDB, but a better mock since it actually exists
+ kms_master_key_id = cls.get_sse_kms_managed_key(account_id, region_name)
+ kms_master_key_id = arns.kms_key_arn(kms_master_key_id, account_id, region_name)
+ return {
+ "Status": "ENABLED",
+ "SSEType": "KMS", # no other value is allowed here
+ "KMSMasterKeyArn": kms_master_key_id,
+ }
+ return {}
+
+
+class ValidationException(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__(code="ValidationException", status_code=400, message=message)
+
+
+def get_store(account_id: str, region_name: str) -> DynamoDBStore:
+ # special case: AWS NoSQL Workbench sends "localhost" as region - replace with proper region here
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+ return dynamodb_stores[account_id][region_name]
+
+
+@contextmanager
+def modify_context_region(context: RequestContext, region: str):
+ """
+ Context manager that modifies the region of a `RequestContext`. At the exit, the context is restored to its
+ original state.
+
+ :param context: the context to modify
+ :param region: the modified region
+ :return: a modified `RequestContext`
+ """
+ original_region = context.region
+ original_authorization = context.request.headers.get("Authorization")
+
+ key = get_ddb_access_key(context.account_id, region)
+
+ context.region = region
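+ # rewrite the SigV4 credential scope so that DynamoDB Local namespaces the data under the target account/region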
+ context.request.headers["Authorization"] = re.sub(
+ AUTH_CREDENTIAL_REGEX,
+ rf"Credential={key}/\2/{region}/\4/",
+ original_authorization or "",
+ flags=re.IGNORECASE,
+ )
+
+ try:
+ yield context
+ except Exception:
+ raise
+ finally:
+ # revert the original context
+ context.region = original_region
+ context.request.headers["Authorization"] = original_authorization
+
+
+class DynamoDBDeveloperEndpoints:
+ """
+ Developer endpoints for DynamoDB
+ DELETE /_aws/dynamodb/expired - delete expired items from tables with TTL enabled; return the number of expired
+ items deleted
+ """
+
+ @route("/_aws/dynamodb/expired", methods=["DELETE"])
+ def delete_expired_messages(self, _: Request):
+ no_expired_items = delete_expired_items()
+ return {"ExpiredItems": no_expired_items}
+
+
+def delete_expired_items() -> int:
+ """
+ This utility function iterates over all stores, looks for tables with TTL enabled,
+ scans such tables, and deletes expired items.
+ """
+ no_expired_items = 0
+ for account_id, region_name, state in dynamodb_stores.iter_stores():
+ ttl_specs = state.ttl_specifications
+ client = connect_to(aws_access_key_id=account_id, region_name=region_name).dynamodb
+ for table_name, ttl_spec in ttl_specs.items():
+ if ttl_spec.get("Enabled", False):
+ attribute_name = ttl_spec.get("AttributeName")
+ current_time = int(datetime.now().timestamp())
+ try:
+ result = client.scan(
+ TableName=table_name,
+ FilterExpression="#ttl <= :threshold",
+ ExpressionAttributeValues={":threshold": {"N": str(current_time)}},
+ ExpressionAttributeNames={"#ttl": attribute_name},
+ )
+ items_to_delete = result.get("Items", [])
+ no_expired_items += len(items_to_delete)
+ table_description = client.describe_table(TableName=table_name)
+ partition_key, range_key = _get_hash_and_range_key(table_description)
+ keys_to_delete = [
+ {partition_key: item.get(partition_key)}
+ if range_key is None
+ else {
+ partition_key: item.get(partition_key),
+ range_key: item.get(range_key),
+ }
+ for item in items_to_delete
+ ]
+ delete_requests = [{"DeleteRequest": {"Key": key}} for key in keys_to_delete]
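+ # BatchWriteItem accepts at most 25 write requests per call, hence the chunking below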
+ for i in range(0, len(delete_requests), 25):
+ batch = delete_requests[i : i + 25]
+ client.batch_write_item(RequestItems={table_name: batch})
+ except Exception as e:
+ LOG.warning(
+ "An error occurred when deleting expired items from table %s: %s",
+ table_name,
+ e,
+ )
+ return no_expired_items
+
+
+def _get_hash_and_range_key(table_description: DescribeTableOutput) -> tuple[str | None, str | None]:
+ key_schema = table_description.get("Table", {}).get("KeySchema", [])
+ hash_key, range_key = None, None
+ for key in key_schema:
+ if key["KeyType"] == "HASH":
+ hash_key = key["AttributeName"]
+ if key["KeyType"] == "RANGE":
+ range_key = key["AttributeName"]
+ return hash_key, range_key
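+# Sketch of the helper above applied to a hypothetical DescribeTable response:
+#
+#   desc = {"Table": {"KeySchema": [
+#       {"AttributeName": "pk", "KeyType": "HASH"},
+#       {"AttributeName": "sk", "KeyType": "RANGE"},
+#   ]}}
+#   _get_hash_and_range_key(desc)  # -> ("pk", "sk")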
+
+
+class ExpiredItemsWorker:
+ """A worker that periodically computes and deletes expired items from DynamoDB tables"""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.scheduler = Scheduler()
+ self.thread: Optional[FuncThread] = None
+ self.mutex = threading.RLock()
+
+ def start(self):
+ with self.mutex:
+ if self.thread:
+ return
+
+ self.scheduler = Scheduler()
+ self.scheduler.schedule(
+ delete_expired_items, period=60 * 60
+ ) # the background process seems slow on AWS
+
+ def _run(*_args):
+ self.scheduler.run()
+
+ self.thread = start_thread(_run, name="ddb-remove-expired-items")
+
+ def stop(self):
+ with self.mutex:
+ if self.scheduler:
+ self.scheduler.close()
+
+ if self.thread:
+ self.thread.stop()
+
+ self.thread = None
+ self.scheduler = None
+
+
+class DynamoDBProvider(DynamodbApi, ServiceLifecycleHook):
+ server: DynamodbServer
+ """The instance of the server managing the instance of DynamoDB local"""
+
+ def __init__(self):
+ self.server = self._new_dynamodb_server()
+ self._expired_items_worker = ExpiredItemsWorker()
+ self._router_rules = []
+
+ def on_before_start(self):
+ self.server.start_dynamodb()
+ if config.DYNAMODB_REMOVE_EXPIRED_ITEMS:
+ self._expired_items_worker.start()
+ self._router_rules = ROUTER.add(DynamoDBDeveloperEndpoints())
+
+ def on_before_stop(self):
+ self._expired_items_worker.stop()
+ ROUTER.remove(self._router_rules)
+
+ def accept_state_visitor(self, visitor: StateVisitor):
+ visitor.visit(dynamodb_stores)
+ visitor.visit(dynamodbstreams_stores)
+ visitor.visit(AssetDirectory(self.service, os.path.join(config.dirs.data, self.service)))
+
+ def on_before_state_reset(self):
+ self.server.stop_dynamodb()
+
+ def on_before_state_load(self):
+ self.server.stop_dynamodb()
+
+ def on_after_state_reset(self):
+ self.server.start_dynamodb()
+
+ @staticmethod
+ def _new_dynamodb_server() -> DynamodbServer:
+ return DynamodbServer.get()
+
+ def on_after_state_load(self):
+ self.server.start_dynamodb()
+
+ def on_after_init(self):
+ # add response processor specific to ddblocal
+ handlers.modify_service_response.append(self.service, modify_ddblocal_arns)
+
+ # routes for the shell ui
+ ROUTER.add(
+ path="/shell",
+ endpoint=self.handle_shell_ui_redirect,
+ methods=["GET"],
+ )
+ ROUTER.add(
+ path="/shell/",
+ endpoint=self.handle_shell_ui_request,
+ )
+
+ def _forward_request(
+ self,
+ context: RequestContext,
+ region: str | None,
+ service_request: ServiceRequest | None = None,
+ ) -> ServiceResponse:
+ """
+ Modify the context region and then forward request to DynamoDB Local.
+
+ This is used for operations impacted by global tables. In LocalStack, a single copy of a global table
+ is kept, and any requests to replica tables are forwarded to this original table.
+ """
+ if region:
+ with modify_context_region(context, region):
+ return self.forward_request(context, service_request=service_request)
+ return self.forward_request(context, service_request=service_request)
+
+ def forward_request(
+ self, context: RequestContext, service_request: ServiceRequest = None
+ ) -> ServiceResponse:
+ """
+ Forward a request to DynamoDB Local.
+ """
+ self.check_provisioned_throughput(context.operation.name)
+ self.prepare_request_headers(
+ context.request.headers, account_id=context.account_id, region_name=context.region
+ )
+ return self.server.proxy(context, service_request)
+
+ def get_forward_url(self, account_id: str, region_name: str) -> str:
+ """Return the URL of the backend DynamoDBLocal server to forward requests to"""
+ return self.server.url
+
+ def handle_shell_ui_redirect(self, request: werkzeug.Request) -> Response:
+ headers = {"Refresh": f"0; url={config.external_service_url()}/shell/index.html"}
+ return Response("", headers=headers)
+
+ def handle_shell_ui_request(self, request: werkzeug.Request, req_path: str) -> Response:
+ # TODO: "DynamoDB Local Web Shell was deprecated with version 1.16.X and is not available any
+ # longer from 1.17.X to latest. There are no immediate plans for a new Web Shell to be introduced."
+ # -> keeping this for now, to allow configuring custom installs; should consider removing it in the future
+ # https://repost.aws/questions/QUHyIzoEDqQ3iOKlUEp1LPWQ#ANdBm9Nz9TRf6VqR3jZtcA1g
+ req_path = f"/{req_path}" if not req_path.startswith("/") else req_path
+ account_id = extract_account_id_from_headers(request.headers)
+ region_name = extract_region_from_headers(request.headers)
+ url = f"{self.get_forward_url(account_id, region_name)}/shell{req_path}"
+ result = requests.request(
+ method=request.method, url=url, headers=request.headers, data=request.data
+ )
+ return Response(result.content, headers=dict(result.headers), status=result.status_code)
+
+ #
+ # Table ops
+ #
+
+ @handler("CreateTable", expand=False)
+ def create_table(
+ self,
+ context: RequestContext,
+ create_table_input: CreateTableInput,
+ ) -> CreateTableOutput:
+ table_name = create_table_input["TableName"]
+
+ # Return this specific error message to keep parity with AWS
+ if self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceInUseException(f"Table already exists: {table_name}")
+
+ billing_mode = create_table_input.get("BillingMode")
+ provisioned_throughput = create_table_input.get("ProvisionedThroughput")
+ if billing_mode == BillingMode.PAY_PER_REQUEST and provisioned_throughput is not None:
+ raise ValidationException(
+ "One or more parameter values were invalid: Neither ReadCapacityUnits nor WriteCapacityUnits can be "
+ "specified when BillingMode is PAY_PER_REQUEST"
+ )
+
+ result = self.forward_request(context)
+
+ table_description = result["TableDescription"]
+ table_description["TableArn"] = table_arn = self.fix_table_arn(
+ context.account_id, context.region, table_description["TableArn"]
+ )
+
+ backend = get_store(context.account_id, context.region)
+ backend.table_definitions[table_name] = table_definitions = dict(create_table_input)
+ backend.TABLE_REGION[table_name] = context.region
+
+ if "TableId" not in table_definitions:
+ table_definitions["TableId"] = long_uid()
+
+ if "SSESpecification" in table_definitions:
+ sse_specification = table_definitions.pop("SSESpecification")
+ table_definitions["SSEDescription"] = SSEUtils.get_sse_description(
+ context.account_id, context.region, sse_specification
+ )
+
+ if table_definitions:
+ table_content = result.get("Table", {})
+ table_content.update(table_definitions)
+ table_description.update(table_content)
+
+ if "TableClass" in table_definitions:
+ table_class = table_description.pop("TableClass", None) or table_definitions.pop(
+ "TableClass"
+ )
+ table_description["TableClassSummary"] = {"TableClass": table_class}
+
+ if "GlobalSecondaryIndexes" in table_description:
+ gsis = copy.deepcopy(table_description["GlobalSecondaryIndexes"])
+ # patch in the missing values, as DynamoDB Local v2 has a regression around GSIs and no longer
+ # returns these fields
+ for gsi in gsis:
+ index_name = gsi.get("IndexName", "")
+ gsi.update(
+ {
+ "IndexArn": f"{table_arn}/index/{index_name}",
+ "IndexSizeBytes": 0,
+ "IndexStatus": "ACTIVE",
+ "ItemCount": 0,
+ }
+ )
+ gsi_provisioned_throughput = gsi.setdefault("ProvisionedThroughput", {})
+ gsi_provisioned_throughput["NumberOfDecreasesToday"] = 0
+
+ if billing_mode == BillingMode.PAY_PER_REQUEST:
+ gsi_provisioned_throughput["ReadCapacityUnits"] = 0
+ gsi_provisioned_throughput["WriteCapacityUnits"] = 0
+
+ # table_definitions["GlobalSecondaryIndexes"] = gsis
+ table_description["GlobalSecondaryIndexes"] = gsis
+
+ if "ProvisionedThroughput" in table_description:
+ if "NumberOfDecreasesToday" not in table_description["ProvisionedThroughput"]:
+ table_description["ProvisionedThroughput"]["NumberOfDecreasesToday"] = 0
+
+ tags = table_definitions.pop("Tags", [])
+ if tags:
+ get_store(context.account_id, context.region).TABLE_TAGS[table_arn] = {
+ tag["Key"]: tag["Value"] for tag in tags
+ }
+
+ # remove invalid attributes from result
+ table_description.pop("Tags", None)
+ table_description.pop("BillingMode", None)
+
+ return result
+
+ def delete_table(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DeleteTableOutput:
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ # Limitation note: On AWS, for a replicated table, if the source table is deleted, the replicated tables continue to exist.
+ # This is not the case for LocalStack, where all replicated tables will also be removed if source is deleted.
+
+ result = self._forward_request(context=context, region=global_table_region)
+
+ table_arn = result.get("TableDescription", {}).get("TableArn")
+ table_arn = self.fix_table_arn(context.account_id, context.region, table_arn)
+
+ store = get_store(context.account_id, context.region)
+ store.TABLE_TAGS.pop(table_arn, None)
+ store.REPLICAS.pop(table_name, None)
+
+ return result
+
+ def describe_table(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeTableOutput:
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ result = self._forward_request(context=context, region=global_table_region)
+ table_description: TableDescription = result["Table"]
+
+ # Update table properties from LocalStack stores
+ if table_props := get_store(context.account_id, context.region).table_properties.get(
+ table_name
+ ):
+ table_description.update(table_props)
+
+ store = get_store(context.account_id, context.region)
+
+ # Update replication details
+ replicas: Dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})
+
+ replica_description_list = []
+
+ if global_table_region != context.region:
+ replica_description_list.append(
+ ReplicaDescription(
+ RegionName=global_table_region, ReplicaStatus=ReplicaStatus.ACTIVE
+ )
+ )
+
+ for replica_region, replica_description in replicas.items():
+ # The replica in the region being queried must not be returned
+ if replica_region != context.region:
+ replica_description_list.append(replica_description)
+
+ if replica_description_list:
+ table_description.update({"Replicas": replica_description_list})
+
+ # update only TableId and SSEDescription if present
+ if table_definitions := store.table_definitions.get(table_name):
+ for key in ["TableId", "SSEDescription"]:
+ if table_definitions.get(key):
+ table_description[key] = table_definitions[key]
+ if "TableClass" in table_definitions:
+ table_description["TableClassSummary"] = {
+ "TableClass": table_definitions["TableClass"]
+ }
+
+ if "GlobalSecondaryIndexes" in table_description:
+ for gsi in table_description["GlobalSecondaryIndexes"]:
+ default_values = {
+ "NumberOfDecreasesToday": 0,
+ "ReadCapacityUnits": 0,
+ "WriteCapacityUnits": 0,
+ }
+ # even if the billing mode is PAY_PER_REQUEST, AWS returns the Read and Write Capacity Units
+ # Terraform depends on this parity for update operations
+ gsi["ProvisionedThroughput"] = default_values | gsi.get("ProvisionedThroughput", {})
+
+ return DescribeTableOutput(
+ Table=select_from_typed_dict(TableDescription, table_description)
+ )
+
+ @handler("UpdateTable", expand=False)
+ def update_table(
+ self, context: RequestContext, update_table_input: UpdateTableInput
+ ) -> UpdateTableOutput:
+ table_name = update_table_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ try:
+ self._forward_request(context=context, region=global_table_region)
+ except CommonServiceException as exc:
+ # DynamoDBLocal refuses to update certain table params and raises.
+ # But we still need to update this info in LocalStack stores
+ if not (exc.code == "ValidationException" and exc.message == "Nothing to update"):
+ raise
+
+ if table_class := update_table_input.get("TableClass"):
+ table_definitions = get_store(
+ context.account_id, context.region
+ ).table_definitions.setdefault(table_name, {})
+ table_definitions["TableClass"] = table_class
+
+ if replica_updates := update_table_input.get("ReplicaUpdates"):
+ store = get_store(context.account_id, global_table_region)
+
+ # Dict with source region to set of replicated regions
+ replicas: Dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})
+
+ for replica_update in replica_updates:
+ for key, details in replica_update.items():
+ # Replicated region
+ target_region = details.get("RegionName")
+
+ # Check if replicated region is valid
+ if target_region not in get_valid_regions_for_service("dynamodb"):
+ raise ValidationException(f"Region {target_region} is not supported")
+
+ match key:
+ case "Create":
+ if target_region in replicas.keys():
+ raise ValidationException(
+ f"Failed to create a the new replica of table with name: '{table_name}' because one or more replicas already existed as tables."
+ )
+ replicas[target_region] = ReplicaDescription(
+ RegionName=target_region,
+ KMSMasterKeyId=details.get("KMSMasterKeyId"),
+ ProvisionedThroughputOverride=details.get(
+ "ProvisionedThroughputOverride"
+ ),
+ GlobalSecondaryIndexes=details.get("GlobalSecondaryIndexes"),
+ ReplicaStatus=ReplicaStatus.ACTIVE,
+ )
+ case "Delete":
+ try:
+ replicas.pop(target_region)
+ except KeyError:
+ raise ValidationException(
+ "Update global table operation failed because one or more replicas were not part of the global table."
+ )
+
+ store.REPLICAS[table_name] = replicas
+
+ # update response content
+ SchemaExtractor.invalidate_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ schema = SchemaExtractor.get_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ if sse_specification_input := update_table_input.get("SSESpecification"):
+ # If SSESpecification is changed, update store and return the 'UPDATING' status in the response
+ table_definition = get_store(
+ context.account_id, context.region
+ ).table_definitions.setdefault(table_name, {})
+ if not sse_specification_input["Enabled"]:
+ table_definition.pop("SSEDescription", None)
+ schema["Table"]["SSEDescription"]["Status"] = "UPDATING"
+
+ return UpdateTableOutput(TableDescription=schema["Table"])
+
+ SchemaExtractor.invalidate_table_schema(table_name, context.account_id, global_table_region)
+
+ schema = SchemaExtractor.get_table_schema(
+ table_name, context.account_id, global_table_region
+ )
+
+ return UpdateTableOutput(TableDescription=schema["Table"])
+
+ def list_tables(
+ self,
+ context: RequestContext,
+ exclusive_start_table_name: TableName = None,
+ limit: ListTablesInputLimit = None,
+ **kwargs,
+ ) -> ListTablesOutput:
+ response = self.forward_request(context)
+
+ # Add replicated tables
+ replicas = get_store(context.account_id, context.region).REPLICAS
+ for replicated_table, replications in replicas.items():
+ for replica_region, replica_description in replications.items():
+ if context.region == replica_region:
+ response["TableNames"].append(replicated_table)
+
+ return response
+
+ #
+ # Item ops
+ #
+
+ @handler("PutItem", expand=False)
+ def put_item(self, context: RequestContext, put_item_input: PutItemInput) -> PutItemOutput:
+ table_name = put_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ return self._forward_request(context=context, region=global_table_region)
+
+ @handler("DeleteItem", expand=False)
+ def delete_item(
+ self,
+ context: RequestContext,
+ delete_item_input: DeleteItemInput,
+ ) -> DeleteItemOutput:
+ table_name = delete_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ return self._forward_request(context=context, region=global_table_region)
+
+ @handler("UpdateItem", expand=False)
+ def update_item(
+ self,
+ context: RequestContext,
+ update_item_input: UpdateItemInput,
+ ) -> UpdateItemOutput:
+ # TODO: for UpdateItem it is harder to use ReturnValues for Streams, because it needs both the Before and After images.
+ table_name = update_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+
+ return self._forward_request(context=context, region=global_table_region)
+
+ @handler("GetItem", expand=False)
+ def get_item(self, context: RequestContext, get_item_input: GetItemInput) -> GetItemOutput:
+ table_name = get_item_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ self.fix_consumed_capacity(get_item_input, result)
+ return result
+
+ #
+ # Queries
+ #
+
+ @handler("Query", expand=False)
+ def query(self, context: RequestContext, query_input: QueryInput) -> QueryOutput:
+ index_name = query_input.get("IndexName")
+ if index_name:
+ if not is_index_query_valid(context.account_id, context.region, query_input):
+ raise ValidationException(
+ "One or more parameter values were invalid: Select type ALL_ATTRIBUTES "
+ "is not supported for global secondary index id-index because its projection "
+ "type is not ALL",
+ )
+
+ table_name = query_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ self.fix_consumed_capacity(query_input, result)
+ return result
+
+ @handler("Scan", expand=False)
+ def scan(self, context: RequestContext, scan_input: ScanInput) -> ScanOutput:
+ table_name = scan_input["TableName"]
+ global_table_region = self.get_global_table_region(context, table_name)
+ result = self._forward_request(context=context, region=global_table_region)
+ return result
+
+ #
+ # Batch ops
+ #
+
+ @handler("BatchWriteItem", expand=False)
+ def batch_write_item(
+ self,
+ context: RequestContext,
+ batch_write_item_input: BatchWriteItemInput,
+ ) -> BatchWriteItemOutput:
+ # TODO: add global table support
+ # UnprocessedItems should have the same format as RequestItems
+ unprocessed_items = {}
+ request_items = batch_write_item_input["RequestItems"]
+
+ for table_name, items in sorted(request_items.items(), key=itemgetter(0)):
+ for request in items:
+ request: WriteRequest
+ for key, inner_request in request.items():
+ inner_request: PutRequest | DeleteRequest
+ if self.should_throttle("BatchWriteItem"):
+ unprocessed_items_for_table = unprocessed_items.setdefault(table_name, [])
+ unprocessed_items_for_table.append(request)
+
+ try:
+ result = self.forward_request(context)
+ except CommonServiceException as e:
+ # TODO: validate if DynamoDB still raises `One of the required keys was not given a value`
+ # for now, replace it with the schema validation error
+ if e.message == "One of the required keys was not given a value":
+ raise ValidationException("The provided key element does not match the schema")
+ raise e
+
+ # TODO: should unprocessed items which have been mutated by `prepare_batch_write_item_records` be returned?
+ for table_name, unprocessed_items_in_table in unprocessed_items.items():
+ unprocessed: dict = result["UnprocessedItems"]
+ result_unprocessed_table = unprocessed.setdefault(table_name, [])
+
+ # add the Unprocessed items to the response
+ # TODO: check before if the same request has not been Unprocessed by DDB local already?
+ # those might actually have been processed? shouldn't we remove them from the proxied request?
+ for request in unprocessed_items_in_table:
+ result_unprocessed_table.append(request)
+
+ # remove any table entry if it's empty
+ result["UnprocessedItems"] = {k: v for k, v in unprocessed.items() if v}
+
+ return result
+
+ @handler("BatchGetItem")
+ def batch_get_item(
+ self,
+ context: RequestContext,
+ request_items: BatchGetRequestMap,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ **kwargs,
+ ) -> BatchGetItemOutput:
+ # TODO: add global table support
+ return self.forward_request(context)
+
+ #
+ # Transactions
+ #
+
+ @handler("TransactWriteItems", expand=False)
+ def transact_write_items(
+ self,
+ context: RequestContext,
+ transact_write_items_input: TransactWriteItemsInput,
+ ) -> TransactWriteItemsOutput:
+ # TODO: add global table support
+ client_token: str | None = transact_write_items_input.get("ClientRequestToken")
+
+ if client_token:
+ # we sort the payload since an identical payload with a different key order could cause an
+ # IdempotentParameterMismatchException error if a client token is provided
+ context.request.data = to_bytes(canonical_json(json.loads(context.request.data)))
+
+ return self.forward_request(context)
+
+ @handler("TransactGetItems", expand=False)
+ def transact_get_items(
+ self,
+ context: RequestContext,
+ transact_items: TransactGetItemList,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ ) -> TransactGetItemsOutput:
+ return self.forward_request(context)
+
+ @handler("ExecuteTransaction", expand=False)
+ def execute_transaction(
+ self, context: RequestContext, execute_transaction_input: ExecuteTransactionInput
+ ) -> ExecuteTransactionOutput:
+ result = self.forward_request(context)
+ return result
+
+ @handler("ExecuteStatement", expand=False)
+ def execute_statement(
+ self,
+ context: RequestContext,
+ execute_statement_input: ExecuteStatementInput,
+ ) -> ExecuteStatementOutput:
+ # TODO: this operation is still really slow with streams enabled
+ # find a way to make it better, same way as the other operations, by using returnvalues
+ # see https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ql-reference.update.html
+
+ # We found out that 'Parameters' can be an empty list when the request comes from the AWS JS client.
+ if execute_statement_input.get("Parameters", None) == []: # noqa
+ raise ValidationException(
+ "1 validation error detected: Value '[]' at 'parameters' failed to satisfy constraint: Member must have length greater than or equal to 1"
+ )
+ return self.forward_request(context)
+
+ #
+ # Tags
+ #
+
+ def tag_resource(
+ self, context: RequestContext, resource_arn: ResourceArnString, tags: TagList, **kwargs
+ ) -> None:
+ table_tags = get_store(context.account_id, context.region).TABLE_TAGS
+ if resource_arn not in table_tags:
+ table_tags[resource_arn] = {}
+ table_tags[resource_arn].update({tag["Key"]: tag["Value"] for tag in tags})
+
+ def untag_resource(
+ self,
+ context: RequestContext,
+ resource_arn: ResourceArnString,
+ tag_keys: TagKeyList,
+ **kwargs,
+ ) -> None:
+ for tag_key in tag_keys or []:
+ get_store(context.account_id, context.region).TABLE_TAGS.get(resource_arn, {}).pop(
+ tag_key, None
+ )
+
+ def list_tags_of_resource(
+ self,
+ context: RequestContext,
+ resource_arn: ResourceArnString,
+ next_token: NextTokenString = None,
+ **kwargs,
+ ) -> ListTagsOfResourceOutput:
+ result = [
+ {"Key": k, "Value": v}
+ for k, v in get_store(context.account_id, context.region)
+ .TABLE_TAGS.get(resource_arn, {})
+ .items()
+ ]
+ return ListTagsOfResourceOutput(Tags=result)
+
+ #
+ # TTLs
+ #
+
+ def describe_time_to_live(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeTimeToLiveOutput:
+ if not self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceNotFoundException(
+ f"Requested resource not found: Table: {table_name} not found"
+ )
+
+ backend = get_store(context.account_id, context.region)
+ ttl_spec = backend.ttl_specifications.get(table_name)
+
+ result = {"TimeToLiveStatus": "DISABLED"}
+ if ttl_spec:
+ if ttl_spec.get("Enabled"):
+ ttl_status = "ENABLED"
+ else:
+ ttl_status = "DISABLED"
+ result = {
+ "AttributeName": ttl_spec.get("AttributeName"),
+ "TimeToLiveStatus": ttl_status,
+ }
+
+ return DescribeTimeToLiveOutput(TimeToLiveDescription=result)
+
+ def update_time_to_live(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ time_to_live_specification: TimeToLiveSpecification,
+ **kwargs,
+ ) -> UpdateTimeToLiveOutput:
+ if not self.table_exists(context.account_id, context.region, table_name):
+ raise ResourceNotFoundException(
+ f"Requested resource not found: Table: {table_name} not found"
+ )
+
+ # TODO: TTL status is maintained/mocked but no real expiry is happening for items
+ backend = get_store(context.account_id, context.region)
+ backend.ttl_specifications[table_name] = time_to_live_specification
+ return UpdateTimeToLiveOutput(TimeToLiveSpecification=time_to_live_specification)
+
+ #
+ # Global tables
+ #
+
+ def create_global_table(
+ self,
+ context: RequestContext,
+ global_table_name: TableName,
+ replication_group: ReplicaList,
+ **kwargs,
+ ) -> CreateGlobalTableOutput:
+ global_tables: Dict = get_store(context.account_id, context.region).GLOBAL_TABLES
+ if global_table_name in global_tables:
+ raise GlobalTableAlreadyExistsException("Global table with this name already exists")
+ replication_group = [grp.copy() for grp in replication_group or []]
+ data = {"GlobalTableName": global_table_name, "ReplicationGroup": replication_group}
+ global_tables[global_table_name] = data
+ for group in replication_group:
+ group["ReplicaStatus"] = "ACTIVE"
+ group["ReplicaStatusDescription"] = "Replica active"
+ return CreateGlobalTableOutput(GlobalTableDescription=data)
+
+ def describe_global_table(
+ self, context: RequestContext, global_table_name: TableName, **kwargs
+ ) -> DescribeGlobalTableOutput:
+ details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
+ if not details:
+ raise GlobalTableNotFoundException("Global table with this name does not exist")
+ return DescribeGlobalTableOutput(GlobalTableDescription=details)
+
+ def list_global_tables(
+ self,
+ context: RequestContext,
+ exclusive_start_global_table_name: TableName = None,
+ limit: PositiveIntegerObject = None,
+ region_name: RegionName = None,
+ **kwargs,
+ ) -> ListGlobalTablesOutput:
+ # TODO: add paging support
+ result = [
+ select_attributes(tab, ["GlobalTableName", "ReplicationGroup"])
+ for tab in get_store(context.account_id, context.region).GLOBAL_TABLES.values()
+ ]
+ return ListGlobalTablesOutput(GlobalTables=result)
+
+ def update_global_table(
+ self,
+ context: RequestContext,
+ global_table_name: TableName,
+ replica_updates: ReplicaUpdateList,
+ **kwargs,
+ ) -> UpdateGlobalTableOutput:
+ details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
+ if not details:
+ raise GlobalTableNotFoundException("Global table with this name does not exist")
+ for update in replica_updates or []:
+ repl_group = details["ReplicationGroup"]
+ # delete existing
+ delete = update.get("Delete")
+ if delete:
+ details["ReplicationGroup"] = [
+ g for g in repl_group if g["RegionName"] != delete["RegionName"]
+ ]
+ # create new
+ create = update.get("Create")
+ if create:
+ exists = [g for g in repl_group if g["RegionName"] == create["RegionName"]]
+ if exists:
+ continue
+ new_group = {
+ "RegionName": create["RegionName"],
+ "ReplicaStatus": "ACTIVE",
+ "ReplicaStatusDescription": "Replica active",
+ }
+ details["ReplicationGroup"].append(new_group)
+ return UpdateGlobalTableOutput(GlobalTableDescription=details)
+
+ #
+ # Kinesis Streaming
+ #
+
+ def enable_kinesis_streaming_destination(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ stream_arn: StreamArn,
+ enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
+ **kwargs,
+ ) -> KinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+
+ if not kinesis_stream_exists(stream_arn=stream_arn):
+ raise ValidationException("User does not have a permission to use kinesis stream")
+
+ table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
+ table_name, {}
+ )
+
+ dest_status = table_def.get("KinesisDataStreamDestinationStatus")
+ if dest_status not in ["DISABLED", "ENABLE_FAILED", None]:
+ raise ValidationException(
+ "Table is not in a valid state to enable Kinesis Streaming "
+ "Destination:EnableKinesisStreamingDestination must be DISABLED or ENABLE_FAILED "
+ "to perform ENABLE operation."
+ )
+
+ table_def["KinesisDataStreamDestinations"] = (
+ table_def.get("KinesisDataStreamDestinations") or []
+ )
+ # remove the stream destination if already present
+ table_def["KinesisDataStreamDestinations"] = [
+ t for t in table_def["KinesisDataStreamDestinations"] if t["StreamArn"] != stream_arn
+ ]
+ # append the active stream destination at the end of the list
+ table_def["KinesisDataStreamDestinations"].append(
+ {
+ "DestinationStatus": DestinationStatus.ACTIVE,
+ "DestinationStatusDescription": "Stream is active",
+ "StreamArn": stream_arn,
+ }
+ )
+ table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.ACTIVE
+ return KinesisStreamingDestinationOutput(
+ DestinationStatus=DestinationStatus.ACTIVE, StreamArn=stream_arn, TableName=table_name
+ )
+
+ def disable_kinesis_streaming_destination(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ stream_arn: StreamArn,
+ enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
+ **kwargs,
+ ) -> KinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+ if not kinesis_stream_exists(stream_arn):
+ raise ValidationException(
+ "User does not have a permission to use kinesis stream",
+ )
+
+ table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
+ table_name, {}
+ )
+
+ stream_destinations = table_def.get("KinesisDataStreamDestinations")
+ if stream_destinations:
+ if table_def["KinesisDataStreamDestinationStatus"] == DestinationStatus.ACTIVE:
+ for dest in stream_destinations:
+ if (
+ dest["StreamArn"] == stream_arn
+ and dest["DestinationStatus"] == DestinationStatus.ACTIVE
+ ):
+ dest["DestinationStatus"] = DestinationStatus.DISABLED
+ dest["DestinationStatusDescription"] = ("Stream is disabled",)
+ table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.DISABLED
+ return KinesisStreamingDestinationOutput(
+ DestinationStatus=DestinationStatus.DISABLED,
+ StreamArn=stream_arn,
+ TableName=table_name,
+ )
+ raise ValidationException(
+ "Table is not in a valid state to disable Kinesis Streaming Destination:"
+ "DisableKinesisStreamingDestination must be ACTIVE to perform DISABLE operation."
+ )
+
+ def describe_kinesis_streaming_destination(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeKinesisStreamingDestinationOutput:
+ self.ensure_table_exists(context.account_id, context.region, table_name)
+
+ table_def = (
+ get_store(context.account_id, context.region).table_definitions.get(table_name) or {}
+ )
+
+ stream_destinations = table_def.get("KinesisDataStreamDestinations") or []
+ return DescribeKinesisStreamingDestinationOutput(
+ KinesisDataStreamDestinations=stream_destinations, TableName=table_name
+ )
+
+ #
+ # Continuous Backups
+ #
+
+ def describe_continuous_backups(
+ self, context: RequestContext, table_name: TableName, **kwargs
+ ) -> DescribeContinuousBackupsOutput:
+ self.get_global_table_region(context, table_name)
+ store = get_store(context.account_id, context.region)
+ continuous_backup_description = (
+ store.table_properties.get(table_name, {}).get("ContinuousBackupsDescription")
+ ) or ContinuousBackupsDescription(
+ ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
+ PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
+ PointInTimeRecoveryStatus=PointInTimeRecoveryStatus.DISABLED
+ ),
+ )
+
+ return DescribeContinuousBackupsOutput(
+ ContinuousBackupsDescription=continuous_backup_description
+ )
+
+ def update_continuous_backups(
+ self,
+ context: RequestContext,
+ table_name: TableName,
+ point_in_time_recovery_specification: PointInTimeRecoverySpecification,
+ **kwargs,
+ ) -> UpdateContinuousBackupsOutput:
+ self.get_global_table_region(context, table_name)
+
+ store = get_store(context.account_id, context.region)
+ pit_recovery_status = (
+ PointInTimeRecoveryStatus.ENABLED
+ if point_in_time_recovery_specification["PointInTimeRecoveryEnabled"]
+ else PointInTimeRecoveryStatus.DISABLED
+ )
+ continuous_backup_description = ContinuousBackupsDescription(
+ ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
+ PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
+ PointInTimeRecoveryStatus=pit_recovery_status
+ ),
+ )
+ table_props = store.table_properties.setdefault(table_name, {})
+ table_props["ContinuousBackupsDescription"] = continuous_backup_description
+
+ return UpdateContinuousBackupsOutput(
+ ContinuousBackupsDescription=continuous_backup_description
+ )
+
+ #
+ # Helpers
+ #
+
+ @staticmethod
+ def ddb_region_name(region_name: str) -> str:
+ """Map `local` or `localhost` region to the us-east-1 region. These values are used by NoSQL Workbench."""
+ # TODO: could this be somehow moved into the request handler chain?
+ if region_name in ("local", "localhost"):
+ region_name = AWS_REGION_US_EAST_1
+
+ return region_name
+
+ @staticmethod
+ def table_exists(account_id: str, region_name: str, table_name: str) -> bool:
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+
+ client = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).dynamodb
+ return dynamodb_table_exists(table_name, client)
+
+ @staticmethod
+ def ensure_table_exists(account_id: str, region_name: str, table_name: str):
+ """
+ Raise ResourceNotFoundException if the given table does not exist.
+
+ :param account_id: account id
+ :param region_name: region name
+ :param table_name: table name
+ :raise: ResourceNotFoundException if table does not exist in DynamoDB Local
+ """
+ if not DynamoDBProvider.table_exists(account_id, region_name, table_name):
+ raise ResourceNotFoundException("Cannot do operations on a non-existent table")
+
+ @staticmethod
+ def get_global_table_region(context: RequestContext, table_name: str) -> str:
+ """
+ Return the table region considering that it might be a replicated table.
+
+ Replication in LocalStack works by keeping a single copy of a table and forwarding
+ requests to the region where this table exists.
+
+ This method does not check whether the table actually exists in DDBLocal.
+
+ :param context: request context
+ :param table_name: table name
+ :return: region
+ """
+ store = get_store(context.account_id, context.region)
+
+ table_region = store.TABLE_REGION.get(table_name)
+ replicated_at = store.REPLICAS.get(table_name, {}).keys()
+
+ if context.region == table_region or context.region in replicated_at:
+ return table_region
+
+ return context.region
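+ # Resolution sketch for the method above, with hypothetical store contents
+ # TABLE_REGION = {"orders": "us-east-1"} and REPLICAS = {"orders": {"eu-west-1": ...}}:
+ #   - a request from us-east-1 or eu-west-1 resolves to "us-east-1" (the single physical copy)
+ #   - a request from any other region resolves to that region (treated as a regular table)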
+
+ @staticmethod
+ def prepare_request_headers(headers: Dict, account_id: str, region_name: str):
+ """
+ Modify the Credentials field of Authorization header to achieve namespacing in DynamoDBLocal.
+ """
+ region_name = DynamoDBProvider.ddb_region_name(region_name)
+ key = get_ddb_access_key(account_id, region_name)
+
+ # DynamoDBLocal namespaces based on the value of Credentials
+ # Since we want to namespace by both account ID and region, use an aggregate key
+ # We also replace the region to keep compatibility with NoSQL Workbench
+ headers["Authorization"] = re.sub(
+ AUTH_CREDENTIAL_REGEX,
+ rf"Credential={key}/\2/{region_name}/\4/",
+ headers.get("Authorization") or "",
+ flags=re.IGNORECASE,
+ )
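+ # Rewrite sketch for the method above (hypothetical header value): an incoming
+ #   "Credential=AKIAEXAMPLE/20240101/localhost/dynamodb/aws4_request"
+ # becomes roughly
+ #   "Credential=<per-account/region key>/20240101/us-east-1/dynamodb/aws4_request"
+ # so that DynamoDB Local stores the data in a per-account, per-region namespace.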
+
+ def fix_consumed_capacity(self, request: Dict, result: Dict):
+ # make sure we append 'ConsumedCapacity', which is properly
+ # returned by dynalite, but not by AWS's DynamoDBLocal
+ table_name = request.get("TableName")
+ return_cap = request.get("ReturnConsumedCapacity")
+ if "ConsumedCapacity" not in result and return_cap in ["TOTAL", "INDEXES"]:
+ request["ConsumedCapacity"] = {
+ "TableName": table_name,
+ "CapacityUnits": 5, # TODO hardcoded
+ "ReadCapacityUnits": 2,
+ "WriteCapacityUnits": 3,
+ }
+
+ def fix_table_arn(self, account_id: str, region_name: str, arn: str) -> str:
+ """
+ Set the correct account ID and region in ARNs returned by DynamoDB Local.
+ """
+ partition = get_partition(region_name)
+ return (
+ arn.replace("arn:aws:", f"arn:{partition}:")
+ .replace(":ddblocal:", f":{region_name}:")
+ .replace(":000000000000:", f":{account_id}:")
+ )
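+ # Example (hypothetical ARN): DynamoDB Local returns
+ #   arn:aws:dynamodb:ddblocal:000000000000:table/orders
+ # which this helper rewrites to something like
+ #   arn:aws:dynamodb:eu-west-1:111111111111:table/orders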
+
+ def batch_execute_statement(
+ self,
+ context: RequestContext,
+ statements: PartiQLBatchRequest,
+ return_consumed_capacity: ReturnConsumedCapacity = None,
+ **kwargs,
+ ) -> BatchExecuteStatementOutput:
+ result = self.forward_request(context)
+ return result
+
+ @staticmethod
+ def get_record_template(region_name: str, stream_view_type: str | None = None) -> StreamRecord:
+ record = {
+ "eventID": short_uid(),
+ "eventVersion": "1.1",
+ "dynamodb": {
+ # expects nearest second rounded down
+ "ApproximateCreationDateTime": int(time.time()),
+ "SizeBytes": -1,
+ },
+ "awsRegion": region_name,
+ "eventSource": "aws:dynamodb",
+ }
+ if stream_view_type:
+ record["dynamodb"]["StreamViewType"] = stream_view_type
+
+ return record
+
+ def check_provisioned_throughput(self, action):
+ """
+ Check rate limiting for an API operation and raise an error if provisioned throughput is exceeded.
+ """
+ if self.should_throttle(action):
+ message = (
+ "The level of configured provisioned throughput for the table was exceeded. "
+ + "Consider increasing your provisioning level with the UpdateTable API"
+ )
+ raise ProvisionedThroughputExceededException(message)
+
+ def action_should_throttle(self, action, actions):
+ throttled = [f"{ACTION_PREFIX}{a}" for a in actions]
+ return (action in throttled) or (action in actions)
+
+ def should_throttle(self, action):
+ if (
+ not config.DYNAMODB_READ_ERROR_PROBABILITY
+ and not config.DYNAMODB_ERROR_PROBABILITY
+ and not config.DYNAMODB_WRITE_ERROR_PROBABILITY
+ ):
+ # early exit so we don't need to call random()
+ return False
+
+ rand = random.random()
+ if rand < config.DYNAMODB_READ_ERROR_PROBABILITY and self.action_should_throttle(
+ action, READ_THROTTLED_ACTIONS
+ ):
+ return True
+ elif rand < config.DYNAMODB_WRITE_ERROR_PROBABILITY and self.action_should_throttle(
+ action, WRITE_THROTTLED_ACTIONS
+ ):
+ return True
+ elif rand < config.DYNAMODB_ERROR_PROBABILITY and self.action_should_throttle(
+ action, THROTTLED_ACTIONS
+ ):
+ return True
+ return False
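+ # Fault-injection sketch: setting e.g. DYNAMODB_ERROR_PROBABILITY=0.5 (illustrative value) makes
+ # roughly half of the requests for throttled actions raise ProvisionedThroughputExceededException
+ # via check_provisioned_throughput() above.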
+
+
+# ---
+# Misc. util functions
+# ---
+
+
+def get_global_secondary_index(account_id: str, region_name: str, table_name: str, index_name: str):
+ schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
+ for index in schema["Table"].get("GlobalSecondaryIndexes", []):
+ if index["IndexName"] == index_name:
+ return index
+ raise ResourceNotFoundException("Index not found")
+
+
+def is_local_secondary_index(
+ account_id: str, region_name: str, table_name: str, index_name: str
+) -> bool:
+ schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
+ for index in schema["Table"].get("LocalSecondaryIndexes", []):
+ if index["IndexName"] == index_name:
+ return True
+ return False
+
+
+def is_index_query_valid(account_id: str, region_name: str, query_data: dict) -> bool:
+ table_name = to_str(query_data["TableName"])
+ index_name = to_str(query_data["IndexName"])
+ if is_local_secondary_index(account_id, region_name, table_name, index_name):
+ return True
+ index_query_type = query_data.get("Select")
+ index = get_global_secondary_index(account_id, region_name, table_name, index_name)
+ index_projection_type = index.get("Projection").get("ProjectionType")
+ if index_query_type == "ALL_ATTRIBUTES" and index_projection_type != "ALL":
+ return False
+ return True
+
+
+def kinesis_stream_exists(stream_arn):
+ account_id = extract_account_id_from_arn(stream_arn)
+ region_name = extract_region_from_arn(stream_arn)
+
+ kinesis = connect_to(
+ aws_access_key_id=account_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ region_name=region_name,
+ ).kinesis
+ stream_name_from_arn = stream_arn.split("/", 1)[1]
+ # check if the stream exists in kinesis for the user
+ filtered = list(
+ filter(
+ lambda stream_name: stream_name == stream_name_from_arn,
+ kinesis.list_streams()["StreamNames"],
+ )
+ )
+ return bool(filtered)
diff --git a/localstack-core/localstack/services/dynamodbstreams/__init__.py b/localstack-core/localstack/services/dynamodbstreams/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py b/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py
new file mode 100644
index 0000000000000..84079dbbf3d6f
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py
@@ -0,0 +1,213 @@
+import logging
+import threading
+from typing import TYPE_CHECKING, Dict
+
+from bson.json_util import dumps
+
+from localstack import config
+from localstack.aws.api.dynamodbstreams import StreamStatus, StreamViewType, TableName
+from localstack.aws.connect import connect_to
+from localstack.services.dynamodbstreams.models import DynamoDbStreamsStore, dynamodbstreams_stores
+from localstack.utils.aws import arns, resources
+from localstack.utils.common import now_utc
+from localstack.utils.threads import FuncThread
+
+if TYPE_CHECKING:
+ from mypy_boto3_kinesis import KinesisClient
+
+DDB_KINESIS_STREAM_NAME_PREFIX = "__ddb_stream_"
+
+LOG = logging.getLogger(__name__)
+
+_SEQUENCE_MTX = threading.RLock()
+_SEQUENCE_NUMBER_COUNTER = 1
+
+
+def get_dynamodbstreams_store(account_id: str, region: str) -> DynamoDbStreamsStore:
+ return dynamodbstreams_stores[account_id][region]
+
+
+def get_and_increment_sequence_number_counter() -> int:
+ global _SEQUENCE_NUMBER_COUNTER
+ with _SEQUENCE_MTX:
+ cnt = _SEQUENCE_NUMBER_COUNTER
+ _SEQUENCE_NUMBER_COUNTER += 1
+ return cnt
+
+
+def get_kinesis_client(account_id: str, region_name: str) -> "KinesisClient":
+ # explicitly specify the endpoint URL here to ensure we always hit the local Kinesis instance
+ return connect_to(
+ aws_access_key_id=account_id,
+ region_name=region_name,
+ endpoint_url=config.internal_service_url(),
+ ).kinesis
+
+
+def add_dynamodb_stream(
+ account_id: str,
+ region_name: str,
+ table_name: str,
+ latest_stream_label: str | None = None,
+ view_type: StreamViewType = StreamViewType.NEW_AND_OLD_IMAGES,
+ enabled: bool = True,
+) -> None:
+ if not enabled:
+ return
+
+ store = get_dynamodbstreams_store(account_id, region_name)
+ # create kinesis stream as a backend
+ stream_name = get_kinesis_stream_name(table_name)
+ resources.create_kinesis_stream(
+ get_kinesis_client(account_id, region_name),
+ stream_name=stream_name,
+ )
+ latest_stream_label = latest_stream_label or "latest"
+ stream = {
+ "StreamArn": arns.dynamodb_stream_arn(
+ table_name=table_name,
+ latest_stream_label=latest_stream_label,
+ account_id=account_id,
+ region_name=region_name,
+ ),
+ "TableName": table_name,
+ "StreamLabel": latest_stream_label,
+ "StreamStatus": StreamStatus.ENABLING,
+ "KeySchema": [],
+ "Shards": [],
+ "StreamViewType": view_type,
+ "shards_id_map": {},
+ }
+ store.ddb_streams[table_name] = stream
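+# Sketch of what add_dynamodb_stream() sets up for a hypothetical table "orders" in us-east-1,
+# account 000000000000: a backing Kinesis stream named "__ddb_stream_orders" plus a stream entry
+# whose StreamArn is "arn:aws:dynamodb:us-east-1:000000000000:table/orders/stream/latest".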
+
+
+def get_stream_for_table(account_id: str, region_name: str, table_arn: str) -> dict:
+ store = get_dynamodbstreams_store(account_id, region_name)
+ table_name = table_name_from_stream_arn(table_arn)
+ return store.ddb_streams.get(table_name)
+
+
+def _process_forwarded_records(
+ account_id: str, region_name: str, table_name: TableName, table_records: dict, kinesis
+) -> None:
+ records = table_records["records"]
+ stream_type = table_records["table_stream_type"]
+ # if the table does not have DynamoDB Streams enabled, skip publishing anything
+ if not stream_type.stream_view_type:
+ return
+
+ # in this case the forwarded record carries both OldImage and NewImage, so we need to filter it,
+ # as the StreamViewType settings can differ between DDB Streams and the Kinesis destination
+ if stream_type.is_kinesis and stream_type.stream_view_type != StreamViewType.NEW_AND_OLD_IMAGES:
+ kinesis_records = []
+
+ # StreamViewType determines what information is written to the stream
+ # when an item in the table is inserted, updated, or deleted
+ image_filter = set()
+ if stream_type.stream_view_type == StreamViewType.KEYS_ONLY:
+ image_filter = {"OldImage", "NewImage"}
+ elif stream_type.stream_view_type == StreamViewType.OLD_IMAGE:
+ image_filter = {"NewImage"}
+ elif stream_type.stream_view_type == StreamViewType.NEW_IMAGE:
+ image_filter = {"OldImage"}
+
+ for record in records:
+ record["dynamodb"] = {
+ k: v for k, v in record["dynamodb"].items() if k not in image_filter
+ }
+
+ if "SequenceNumber" not in record["dynamodb"]:
+ record["dynamodb"]["SequenceNumber"] = str(
+ get_and_increment_sequence_number_counter()
+ )
+
+ kinesis_records.append({"Data": dumps(record), "PartitionKey": "TODO"})
+
+ else:
+ kinesis_records = []
+ for record in records:
+ if "SequenceNumber" not in record["dynamodb"]:
+ # we can mutate the record for SequenceNumber, the Kinesis forwarding takes care of filtering it
+ record["dynamodb"]["SequenceNumber"] = str(
+ get_and_increment_sequence_number_counter()
+ )
+
+ # simply pass along the records, they already have the right format
+ kinesis_records.append({"Data": dumps(record), "PartitionKey": "TODO"})
+
+ stream_name = get_kinesis_stream_name(table_name)
+ kinesis.put_records(
+ StreamName=stream_name,
+ Records=kinesis_records,
+ )
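+# Filtering sketch for the forwarding above: for a Kinesis destination using the (hypothetical)
+# StreamViewType KEYS_ONLY, a record such as
+#   {"dynamodb": {"Keys": {...}, "OldImage": {...}, "NewImage": {...}}}
+# is trimmed to
+#   {"dynamodb": {"Keys": {...}}}
+# before being put onto the backing Kinesis stream.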
+
+
+def forward_events(account_id: str, region_name: str, records_map: dict[TableName, dict]) -> None:
+ kinesis = get_kinesis_client(account_id, region_name)
+
+ for table_name, table_records in records_map.items():
+ _process_forwarded_records(account_id, region_name, table_name, table_records, kinesis)
+
+
+def delete_streams(account_id: str, region_name: str, table_arn: str) -> None:
+ store = get_dynamodbstreams_store(account_id, region_name)
+ table_name = table_name_from_table_arn(table_arn)
+ if store.ddb_streams.pop(table_name, None):
+ stream_name = get_kinesis_stream_name(table_name)
+ # stream_arn = stream["StreamArn"]
+
+ # we delete the underlying stream asynchronously here; should this instead happen synchronously
+ # with the table deletion?
+ def _delete_stream(*args, **kwargs):
+ try:
+ kinesis_client = get_kinesis_client(account_id, region_name)
+ # needs to be active otherwise we can't delete it
+ kinesis_client.get_waiter("stream_exists").wait(StreamName=stream_name)
+ kinesis_client.delete_stream(StreamName=stream_name, EnforceConsumerDeletion=True)
+ kinesis_client.get_waiter("stream_not_exists").wait(StreamName=stream_name)
+ except Exception:
+ LOG.warning(
+ "Failed to delete underlying kinesis stream for dynamodb table table_arn=%s",
+ table_arn,
+ exc_info=LOG.isEnabledFor(logging.DEBUG),
+ )
+
+ FuncThread(_delete_stream).start() # fire & forget
+
+
+def get_kinesis_stream_name(table_name: str) -> str:
+ return DDB_KINESIS_STREAM_NAME_PREFIX + table_name
+
+
+def table_name_from_stream_arn(stream_arn: str) -> str:
+ return stream_arn.split(":table/", 1)[-1].split("/")[0]
+
+
+def table_name_from_table_arn(table_arn: str) -> str:
+ return table_name_from_stream_arn(table_arn)
+
+
+def stream_name_from_stream_arn(stream_arn: str) -> str:
+ table_name = table_name_from_stream_arn(stream_arn)
+ return get_kinesis_stream_name(table_name)
+
+
+def shard_id(kinesis_shard_id: str) -> str:
+ timestamp = str(int(now_utc()))
+ timestamp = f"{timestamp[:-5]}00000000".rjust(20, "0")
+ kinesis_shard_params = kinesis_shard_id.split("-")
+ return f"{kinesis_shard_params[0]}-{timestamp}-{kinesis_shard_params[-1][:32]}"
+
+
+def kinesis_shard_id(dynamodbstream_shard_id: str) -> str:
+ shard_params = dynamodbstream_shard_id.rsplit("-")
+ return f"{shard_params[0]}-{shard_params[-1]}"
+
+
+def get_shard_id(stream: Dict, kinesis_shard_id: str) -> str:
+ ddb_stream_shard_id = stream.get("shards_id_map", {}).get(kinesis_shard_id)
+ if not ddb_stream_shard_id:
+ ddb_stream_shard_id = shard_id(kinesis_shard_id)
+ stream["shards_id_map"][kinesis_shard_id] = ddb_stream_shard_id
+
+ return ddb_stream_shard_id
diff --git a/localstack-core/localstack/services/dynamodbstreams/models.py b/localstack-core/localstack/services/dynamodbstreams/models.py
new file mode 100644
index 0000000000000..a8a6672babf11
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodbstreams/models.py
@@ -0,0 +1,11 @@
+from typing import Dict
+
+from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute
+
+
+class DynamoDbStreamsStore(BaseStore):
+ # maps table names to DynamoDB stream descriptions
+ ddb_streams: Dict[str, dict] = LocalAttribute(default=dict)
+
+
+dynamodbstreams_stores = AccountRegionBundle("dynamodbstreams", DynamoDbStreamsStore)
diff --git a/localstack-core/localstack/services/dynamodbstreams/provider.py b/localstack-core/localstack/services/dynamodbstreams/provider.py
new file mode 100644
index 0000000000000..fc8d0050c4ea6
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodbstreams/provider.py
@@ -0,0 +1,155 @@
+import copy
+import logging
+
+from bson.json_util import loads
+
+from localstack.aws.api import RequestContext, handler
+from localstack.aws.api.dynamodbstreams import (
+ DescribeStreamOutput,
+ DynamodbstreamsApi,
+ ExpiredIteratorException,
+ GetRecordsInput,
+ GetRecordsOutput,
+ GetShardIteratorOutput,
+ ListStreamsOutput,
+ PositiveIntegerObject,
+ ResourceNotFoundException,
+ SequenceNumber,
+ ShardId,
+ ShardIteratorType,
+ Stream,
+ StreamArn,
+ StreamDescription,
+ StreamStatus,
+ TableName,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.dynamodbstreams.dynamodbstreams_api import (
+ get_dynamodbstreams_store,
+ get_kinesis_client,
+ get_kinesis_stream_name,
+ get_shard_id,
+ kinesis_shard_id,
+ stream_name_from_stream_arn,
+ table_name_from_stream_arn,
+)
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.collections import select_from_typed_dict
+
+LOG = logging.getLogger(__name__)
+
+STREAM_STATUS_MAP = {
+ "ACTIVE": StreamStatus.ENABLED,
+ "CREATING": StreamStatus.ENABLING,
+ "DELETING": StreamStatus.DISABLING,
+ "UPDATING": StreamStatus.ENABLING,
+}
+
+
+class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook):
+ def describe_stream(
+ self,
+ context: RequestContext,
+ stream_arn: StreamArn,
+ limit: PositiveIntegerObject = None,
+ exclusive_start_shard_id: ShardId = None,
+ **kwargs,
+ ) -> DescribeStreamOutput:
+ store = get_dynamodbstreams_store(context.account_id, context.region)
+ kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
+ for stream in store.ddb_streams.values():
+ if stream["StreamArn"] == stream_arn:
+ # get stream details
+ dynamodb = connect_to(
+ aws_access_key_id=context.account_id, region_name=context.region
+ ).dynamodb
+ table_name = table_name_from_stream_arn(stream["StreamArn"])
+ stream_name = get_kinesis_stream_name(table_name)
+ stream_details = kinesis.describe_stream(StreamName=stream_name)
+ table_details = dynamodb.describe_table(TableName=table_name)
+ stream["KeySchema"] = table_details["Table"]["KeySchema"]
+ stream["StreamStatus"] = STREAM_STATUS_MAP.get(
+ stream_details["StreamDescription"]["StreamStatus"]
+ )
+
+ # Replace Kinesis ShardIDs with ones that mimic actual
+ # DynamoDBStream ShardIDs.
+ stream_shards = copy.deepcopy(stream_details["StreamDescription"]["Shards"])
+ start_index = 0
+ for index, shard in enumerate(stream_shards):
+ shard["ShardId"] = get_shard_id(stream, shard["ShardId"])
+ shard.pop("HashKeyRange", None)
+ # we want to ignore the shards before the exclusive_start_shard_id parameter:
+ # we store the index where we encounter it, then slice the shards
+ if exclusive_start_shard_id and exclusive_start_shard_id == shard["ShardId"]:
+ start_index = index
+
+ if exclusive_start_shard_id:
+ # slice the resulting shards after the exclusive_start_shard_id parameter
+ stream_shards = stream_shards[start_index + 1 :]
+
+ stream["Shards"] = stream_shards
+ stream_description = select_from_typed_dict(StreamDescription, stream)
+ return DescribeStreamOutput(StreamDescription=stream_description)
+
+ raise ResourceNotFoundException(
+ f"Requested resource not found: Stream: {stream_arn} not found"
+ )
+
+ @handler("GetRecords", expand=False)
+ def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput:
+ kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
+ prefix, _, payload["ShardIterator"] = payload["ShardIterator"].rpartition("|")
+ try:
+ kinesis_records = kinesis.get_records(**payload)
+ except kinesis.exceptions.ExpiredIteratorException:
+ LOG.debug("Shard iterator for underlying kinesis stream expired")
+ raise ExpiredIteratorException("Shard iterator has expired")
+ result = {
+ "Records": [],
+ "NextShardIterator": f"{prefix}|{kinesis_records.get('NextShardIterator')}",
+ }
+ for record in kinesis_records["Records"]:
+ record_data = loads(record["Data"])
+ record_data["dynamodb"]["SequenceNumber"] = record["SequenceNumber"]
+ result["Records"].append(record_data)
+ return GetRecordsOutput(**result)
+
+ def get_shard_iterator(
+ self,
+ context: RequestContext,
+ stream_arn: StreamArn,
+ shard_id: ShardId,
+ shard_iterator_type: ShardIteratorType,
+ sequence_number: SequenceNumber = None,
+ **kwargs,
+ ) -> GetShardIteratorOutput:
+ stream_name = stream_name_from_stream_arn(stream_arn)
+ stream_shard_id = kinesis_shard_id(shard_id)
+ kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
+
+ kwargs = {"StartingSequenceNumber": sequence_number} if sequence_number else {}
+ result = kinesis.get_shard_iterator(
+ StreamName=stream_name,
+ ShardId=stream_shard_id,
+ ShardIteratorType=shard_iterator_type,
+ **kwargs,
+ )
+ del result["ResponseMetadata"]
+ # TODO: not quite clear what the |1| denotes exactly; on AWS it is sometimes a different number
+ result["ShardIterator"] = f"{stream_arn}|1|{result['ShardIterator']}"
+ return GetShardIteratorOutput(**result)
+
+ def list_streams(
+ self,
+ context: RequestContext,
+ table_name: TableName = None,
+ limit: PositiveIntegerObject = None,
+ exclusive_start_stream_arn: StreamArn = None,
+ **kwargs,
+ ) -> ListStreamsOutput:
+ store = get_dynamodbstreams_store(context.account_id, context.region)
+ result = [select_from_typed_dict(Stream, res) for res in store.ddb_streams.values()]
+ if table_name:
+ result = [res for res in result if res["TableName"] == table_name]
+ return ListStreamsOutput(Streams=result)
diff --git a/localstack-core/localstack/services/dynamodbstreams/v2/__init__.py b/localstack-core/localstack/services/dynamodbstreams/v2/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/dynamodbstreams/v2/provider.py b/localstack-core/localstack/services/dynamodbstreams/v2/provider.py
new file mode 100644
index 0000000000000..5f6a86150b315
--- /dev/null
+++ b/localstack-core/localstack/services/dynamodbstreams/v2/provider.py
@@ -0,0 +1,81 @@
+import logging
+
+from localstack.aws import handlers
+from localstack.aws.api import RequestContext, ServiceRequest, ServiceResponse, handler
+from localstack.aws.api.dynamodbstreams import (
+ DescribeStreamInput,
+ DescribeStreamOutput,
+ DynamodbstreamsApi,
+ GetRecordsInput,
+ GetRecordsOutput,
+ GetShardIteratorInput,
+ GetShardIteratorOutput,
+ ListStreamsInput,
+ ListStreamsOutput,
+)
+from localstack.services.dynamodb.server import DynamodbServer
+from localstack.services.dynamodb.utils import modify_ddblocal_arns
+from localstack.services.dynamodb.v2.provider import DynamoDBProvider
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.aws.arns import parse_arn
+
+LOG = logging.getLogger(__name__)
+
+
+class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook):
+ def __init__(self):
+ self.server = DynamodbServer.get()
+
+ def on_after_init(self):
+ # add response processor specific to ddblocal
+ handlers.modify_service_response.append(self.service, modify_ddblocal_arns)
+
+ def on_before_start(self):
+ self.server.start_dynamodb()
+
+ def forward_request(
+ self, context: RequestContext, service_request: ServiceRequest = None
+ ) -> ServiceResponse:
+ """
+ Forward a request to DynamoDB Local.
+ """
+ DynamoDBProvider.prepare_request_headers(
+ context.request.headers, account_id=context.account_id, region_name=context.region
+ )
+ return self.server.proxy(context, service_request)
+
+ def modify_stream_arn_for_ddb_local(self, stream_arn: str) -> str:
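+ """Rewrite the region and account of a stream ARN to the static values used by DynamoDB Local ("ddblocal" / 000000000000)."""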
+ parsed_arn = parse_arn(stream_arn)
+
+ return f"arn:aws:dynamodb:ddblocal:000000000000:{parsed_arn['resource']}"
+
+ @handler("DescribeStream", expand=False)
+ def describe_stream(
+ self,
+ context: RequestContext,
+ payload: DescribeStreamInput,
+ ) -> DescribeStreamOutput:
+ request = payload.copy()
+ request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", ""))
+ return self.forward_request(context, request)
+
+ @handler("GetRecords", expand=False)
+ def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput:
+ request = payload.copy()
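+ # the shard iterator is prefixed with the stream ARN, so the same region/account rewrite applies before forwarding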
+ request["ShardIterator"] = self.modify_stream_arn_for_ddb_local(
+ request.get("ShardIterator", "")
+ )
+ return self.forward_request(context, request)
+
+ @handler("GetShardIterator", expand=False)
+ def get_shard_iterator(
+ self, context: RequestContext, payload: GetShardIteratorInput
+ ) -> GetShardIteratorOutput:
+ request = payload.copy()
+ request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", ""))
+ return self.forward_request(context, request)
+
+ @handler("ListStreams", expand=False)
+ def list_streams(self, context: RequestContext, payload: ListStreamsInput) -> ListStreamsOutput:
+ # TODO: look into `ExclusiveStartStreamArn` param
+ return self.forward_request(context, payload)
diff --git a/localstack-core/localstack/services/ec2/__init__.py b/localstack-core/localstack/services/ec2/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/ec2/exceptions.py b/localstack-core/localstack/services/ec2/exceptions.py
new file mode 100644
index 0000000000000..cb968ba2e6e68
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/exceptions.py
@@ -0,0 +1,80 @@
+from localstack.aws.api import CommonServiceException
+
+
+class InternalError(CommonServiceException):
+ def __init__(self, message):
+ super().__init__(
+ code="InternalError",
+ message=message,
+ )
+
+
+class IncorrectInstanceStateError(CommonServiceException):
+ def __init__(self, instance_id):
+ super().__init__(
+ code="IncorrectInstanceState",
+ message=f"The instance '{instance_id}' is not in a state from which it can be started",
+ )
+
+
+class InvalidAMIIdError(CommonServiceException):
+ def __init__(self, ami_id):
+ super().__init__(
+ code="InvalidAMIID.NotFound", message=f"The image id '{ami_id}' does not exist"
+ )
+
+
+class InvalidInstanceIdError(CommonServiceException):
+ def __init__(self, instance_id):
+ super().__init__(
+ code="InvalidInstanceID.NotFound",
+ message=f"The instance ID '{instance_id}' does not exist",
+ )
+
+
+class MissingParameterError(CommonServiceException):
+ def __init__(self, parameter):
+ super().__init__(
+ code="MissingParameter",
+ message=f"The request must contain the parameter {parameter}",
+ )
+
+
+class InvalidLaunchTemplateNameError(CommonServiceException):
+ def __init__(self):
+ super().__init__(
+ code="InvalidLaunchTemplateName.MalformedException",
+ message="A launch template name must be between 3 and 128 characters, and may contain letters, numbers, and the following characters: - ( ) . / _.'",
+ )
+
+
+class InvalidLaunchTemplateIdError(CommonServiceException):
+ def __init__(self):
+ super().__init__(
+ code="InvalidLaunchTemplateId.VersionNotFound",
+ message="Could not find launch template version",
+ )
+
+
+class InvalidSubnetDuplicateCustomIdError(CommonServiceException):
+ def __init__(self, custom_id):
+ super().__init__(
+ code="InvalidSubnet.DuplicateCustomId",
+ message=f"Subnet with custom id '{custom_id}' already exists",
+ )
+
+
+class InvalidSecurityGroupDuplicateCustomIdError(CommonServiceException):
+ def __init__(self, custom_id):
+ super().__init__(
+ code="InvalidSecurityGroupId.DuplicateCustomId",
+ message=f"Security group with custom id '{custom_id}' already exists",
+ )
+
+
+class InvalidVpcDuplicateCustomIdError(CommonServiceException):
+ def __init__(self, custom_id):
+ super().__init__(
+ code="InvalidVpc.DuplicateCustomId",
+ message=f"VPC with custom id '{custom_id}' already exists",
+ )
diff --git a/localstack-core/localstack/services/ec2/models.py b/localstack-core/localstack/services/ec2/models.py
new file mode 100644
index 0000000000000..cf2bc854900da
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/models.py
@@ -0,0 +1,27 @@
+from moto.ec2 import ec2_backends
+from moto.ec2.models import EC2Backend
+from moto.ec2.models.subnets import Subnet
+
+
+def get_ec2_backend(account_id: str, region: str) -> EC2Backend:
+ return ec2_backends[account_id][region]
+
+
+#
+# Pickle patches
+#
+
+
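+# The subnet IP generator is a generator object and cannot be pickled; it is dropped when
+# serializing state and recreated from the CIDR block when the state is restored.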
+def set_state(self, state):
+ state["_subnet_ip_generator"] = state["cidr"].hosts()
+ self.__dict__.update(state)
+
+
+def get_state(self):
+ state = self.__dict__.copy()
+ state.pop("_subnet_ip_generator", None)
+ return state
+
+
+Subnet.__setstate__ = set_state
+Subnet.__getstate__ = get_state
diff --git a/localstack-core/localstack/services/ec2/patches.py b/localstack-core/localstack/services/ec2/patches.py
new file mode 100644
index 0000000000000..d9db4cad11e08
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/patches.py
@@ -0,0 +1,199 @@
+import logging
+from typing import Optional
+
+from moto.ec2 import models as ec2_models
+from moto.utilities.id_generator import TAG_KEY_CUSTOM_ID, Tags
+
+from localstack.services.ec2.exceptions import (
+ InvalidSecurityGroupDuplicateCustomIdError,
+ InvalidSubnetDuplicateCustomIdError,
+ InvalidVpcDuplicateCustomIdError,
+)
+from localstack.utils.id_generator import (
+ ExistingIds,
+ ResourceIdentifier,
+ localstack_id,
+)
+from localstack.utils.patch import patch
+
+LOG = logging.getLogger(__name__)
+
+
+@localstack_id
+def generate_vpc_id(
+ resource_identifier: ResourceIdentifier,
+ existing_ids: ExistingIds = None,
+ tags: Tags = None,
+) -> str:
+ # Return an empty string here to distinguish whether a custom ID was used or whether the ID was randomly generated by `moto`.
+ return ""
+
+
+@localstack_id
+def generate_subnet_id(
+ resource_identifier: ResourceIdentifier,
+ existing_ids: ExistingIds = None,
+ tags: Tags = None,
+) -> str:
+ # Return an empty string here to distinguish whether a custom ID was used or whether the ID was randomly generated by `moto`.
+ return ""
+
+
+class VpcIdentifier(ResourceIdentifier):
+ service = "ec2"
+ resource = "vpc"
+
+ def __init__(self, account_id: str, region: str, cidr_block: str):
+ super().__init__(account_id, region, name=cidr_block)
+
+ def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str:
+ return generate_vpc_id(
+ resource_identifier=self,
+ existing_ids=existing_ids,
+ tags=tags,
+ )
+
+
+class SubnetIdentifier(ResourceIdentifier):
+ service = "ec2"
+ resource = "subnet"
+
+ def __init__(self, account_id: str, region: str, vpc_id: str, cidr_block: str):
+ super().__init__(account_id, region, name=f"subnet-{vpc_id}-{cidr_block}")
+
+ def generate(self, existing_ids: ExistingIds = None, tags: Tags = None) -> str:
+ return generate_subnet_id(
+ resource_identifier=self,
+ existing_ids=existing_ids,
+ tags=tags,
+ )
+
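+# Usage sketch (assuming the TAG_KEY_CUSTOM_ID tag is honored by the id generator): creating a VPC
+# with a tag whose key is TAG_KEY_CUSTOM_ID and whose value is e.g. "vpc-12345678" should yield a
+# VPC with exactly that ID, while creating a second VPC with the same custom ID raises
+# InvalidVpcDuplicateCustomIdError.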
+
+def apply_patches():
+ @patch(ec2_models.subnets.SubnetBackend.create_subnet)
+ def ec2_create_subnet(
+ fn: ec2_models.subnets.SubnetBackend.create_subnet,
+ self: ec2_models.subnets.SubnetBackend,
+ *args,
+ tags: Optional[dict[str, str]] = None,
+ **kwargs,
+ ):
+ vpc_id: str = args[0] if len(args) >= 1 else kwargs["vpc_id"]
+ cidr_block: str = args[1] if len(args) >= 2 else kwargs["cidr_block"]
+ resource_identifier = SubnetIdentifier(
+ self.account_id, self.region_name, vpc_id, cidr_block
+ )
+ # tags has the format: {"subnet": {"Key": ..., "Value": ...}}
+ if tags is not None:
+ tags = tags.get("subnet", tags)
+ custom_id = resource_identifier.generate(tags=tags)
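+ # generate() only returns a non-empty ID when a custom ID was configured for this subnet (e.g. via tags)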
+
+ if custom_id:
+ # Check if custom id is unique within a given VPC
+ for az_subnets in self.subnets.values():
+ for subnet in az_subnets.values():
+ if subnet.vpc_id == vpc_id and subnet.id == custom_id:
+ raise InvalidSubnetDuplicateCustomIdError(custom_id)
+
+ # Generate subnet with moto library
+ result: ec2_models.subnets.Subnet = fn(self, *args, tags=tags, **kwargs)
+ availability_zone = result.availability_zone
+
+ if custom_id:
+ # Remove the subnet from the default dict and add it back with the custom id
+ self.subnets[availability_zone].pop(result.id)
+ result.id = custom_id
+ self.subnets[availability_zone][custom_id] = result
+
+ # Return the subnet with the patched custom id
+ return result
+
+ @patch(ec2_models.security_groups.SecurityGroupBackend.create_security_group)
+ def ec2_create_security_group(
+ fn: ec2_models.security_groups.SecurityGroupBackend.create_security_group,
+ self: ec2_models.security_groups.SecurityGroupBackend,
+ *args,
+ tags: Optional[dict[str, str]] = None,
+ force: bool = False,
+ **kwargs,
+ ):
+ # Extract tags and custom ID
+ tags: dict[str, str] = tags or {}
+ custom_id = tags.get(TAG_KEY_CUSTOM_ID)
+
+ if not force and self.get_security_group_from_id(custom_id):
+ raise InvalidSecurityGroupDuplicateCustomIdError(custom_id)
+
+ # Generate security group with moto library
+ result: ec2_models.security_groups.SecurityGroup = fn(
+ self, *args, tags=tags, force=force, **kwargs
+ )
+
+ if custom_id:
+ # Remove the security group from the default dict and add it back with the custom id
+ self.groups[result.vpc_id].pop(result.group_id)
+ result.group_id = result.id = custom_id
+ self.groups[result.vpc_id][custom_id] = result
+
+ return result
+
+ @patch(ec2_models.vpcs.VPCBackend.create_vpc)
+ def ec2_create_vpc(
+ fn: ec2_models.vpcs.VPCBackend.create_vpc,
+ self: ec2_models.vpcs.VPCBackend,
+ cidr_block: str,
+ *args,
+ tags: Optional[list[dict[str, str]]] = None,
+ is_default: bool = False,
+ **kwargs,
+ ):
+ resource_identifier = VpcIdentifier(self.account_id, self.region_name, cidr_block)
+ custom_id = resource_identifier.generate(tags=tags)
+
+ # Check if custom id is unique
+ if custom_id and custom_id in self.vpcs:
+ raise InvalidVpcDuplicateCustomIdError(custom_id)
+
+ # Generate VPC with moto library
+ result: ec2_models.vpcs.VPC = fn(
+ self, cidr_block, *args, tags=tags, is_default=is_default, **kwargs
+ )
+ vpc_id = result.id
+
+ if custom_id:
+ # Remove security group associated with unique non-custom VPC ID
+ default = self.get_security_group_from_name("default", vpc_id=vpc_id)
+ if not default:
+ self.delete_security_group(
+ name="default",
+ vpc_id=vpc_id,
+ )
+
+ # Delete route table if only main route table remains.
+ for route_table in self.describe_route_tables(filters={"vpc-id": vpc_id}):
+ self.delete_route_table(route_table.id) # type: ignore[attr-defined]
+
+ # Remove the VPC from the default dict and add it back with the custom id
+ self.vpcs.pop(vpc_id)
+ result.id = custom_id
+ self.vpcs[custom_id] = result
+
+ # Create default network ACL, route table, and security group for custom ID VPC
+ self.create_route_table(
+ vpc_id=custom_id,
+ main=True,
+ )
+ self.create_network_acl(
+ vpc_id=custom_id,
+ default=True,
+ )
+ # Associate default security group with custom ID VPC
+ if not default:
+ self.create_security_group(
+ name="default",
+ description="default VPC security group",
+ vpc_id=custom_id,
+ is_default=is_default,
+ )
+
+ return result
diff --git a/localstack-core/localstack/services/ec2/provider.py b/localstack-core/localstack/services/ec2/provider.py
new file mode 100644
index 0000000000000..59a560cd7295e
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/provider.py
@@ -0,0 +1,603 @@
+import copy
+import json
+import logging
+import re
+from abc import ABC
+from datetime import datetime, timezone
+
+from botocore.parsers import ResponseParserError
+from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase
+from moto.ec2.exceptions import InvalidVpcEndPointIdError
+from moto.ec2.models import (
+ EC2Backend,
+ FlowLogsBackend,
+ SubnetBackend,
+ TransitGatewayAttachmentBackend,
+ VPCBackend,
+ ec2_backends,
+)
+from moto.ec2.models.launch_templates import LaunchTemplate as MotoLaunchTemplate
+from moto.ec2.models.subnets import Subnet
+
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.ec2 import (
+ AvailabilityZone,
+ Boolean,
+ CreateFlowLogsRequest,
+ CreateFlowLogsResult,
+ CreateLaunchTemplateRequest,
+ CreateLaunchTemplateResult,
+ CreateSubnetRequest,
+ CreateSubnetResult,
+ CreateTransitGatewayRequest,
+ CreateTransitGatewayResult,
+ CurrencyCodeValues,
+ DescribeAvailabilityZonesRequest,
+ DescribeAvailabilityZonesResult,
+ DescribeReservedInstancesOfferingsRequest,
+ DescribeReservedInstancesOfferingsResult,
+ DescribeReservedInstancesRequest,
+ DescribeReservedInstancesResult,
+ DescribeSubnetsRequest,
+ DescribeSubnetsResult,
+ DescribeTransitGatewaysRequest,
+ DescribeTransitGatewaysResult,
+ DescribeVpcEndpointServicesRequest,
+ DescribeVpcEndpointServicesResult,
+ DescribeVpcEndpointsRequest,
+ DescribeVpcEndpointsResult,
+ DnsOptions,
+ DnsOptionsSpecification,
+ DnsRecordIpType,
+ Ec2Api,
+ InstanceType,
+ IpAddressType,
+ LaunchTemplate,
+ ModifyLaunchTemplateRequest,
+ ModifyLaunchTemplateResult,
+ ModifySubnetAttributeRequest,
+ ModifyVpcEndpointResult,
+ OfferingClassType,
+ OfferingTypeValues,
+ PricingDetail,
+ PurchaseReservedInstancesOfferingRequest,
+ PurchaseReservedInstancesOfferingResult,
+ RecurringCharge,
+ RecurringChargeFrequency,
+ ReservedInstances,
+ ReservedInstancesOffering,
+ ReservedInstanceState,
+ RevokeSecurityGroupEgressRequest,
+ RevokeSecurityGroupEgressResult,
+ RIProductDescription,
+ String,
+ SubnetConfigurationsList,
+ Tenancy,
+ UnsuccessfulItem,
+ UnsuccessfulItemError,
+ VpcEndpointId,
+ VpcEndpointRouteTableIdList,
+ VpcEndpointSecurityGroupIdList,
+ VpcEndpointSubnetIdList,
+ scope,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.ec2.exceptions import (
+ InvalidLaunchTemplateIdError,
+ InvalidLaunchTemplateNameError,
+ MissingParameterError,
+)
+from localstack.services.ec2.models import get_ec2_backend
+from localstack.services.ec2.patches import apply_patches
+from localstack.services.moto import call_moto, call_moto_with_request
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.patch import patch
+from localstack.utils.strings import first_char_to_upper, long_uid, short_uid
+
+LOG = logging.getLogger(__name__)
+
+# additional subnet attributes not yet supported upstream
+ADDITIONAL_SUBNET_ATTRS = ("private_dns_name_options_on_launch", "enable_dns64")
+
+
+class Ec2Provider(Ec2Api, ABC, ServiceLifecycleHook):
+ def on_after_init(self):
+ apply_patches()
+
+ @handler("DescribeAvailabilityZones", expand=False)
+ def describe_availability_zones(
+ self,
+ context: RequestContext,
+ describe_availability_zones_request: DescribeAvailabilityZonesRequest,
+ ) -> DescribeAvailabilityZonesResult:
+ backend = get_ec2_backend(context.account_id, context.region)
+ zone_names = describe_availability_zones_request.get("ZoneNames")
+ zone_ids = describe_availability_zones_request.get("ZoneIds")
+ if zone_names or zone_ids:
+ filters = {
+ "zone-name": zone_names,
+ "zone-id": zone_ids,
+ }
+ filtered_zones = backend.describe_availability_zones(filters)
+ availability_zones = [
+ AvailabilityZone(
+ State="available",
+ Messages=[],
+ RegionName=zone.region_name,
+ ZoneName=zone.name,
+ ZoneId=zone.zone_id,
+ ZoneType=zone.zone_type,
+ )
+ for zone in filtered_zones
+ ]
+ return DescribeAvailabilityZonesResult(AvailabilityZones=availability_zones)
+ return call_moto(context)
+
+ @handler("DescribeReservedInstancesOfferings", expand=False)
+ def describe_reserved_instances_offerings(
+ self,
+ context: RequestContext,
+ describe_reserved_instances_offerings_request: DescribeReservedInstancesOfferingsRequest,
+ ) -> DescribeReservedInstancesOfferingsResult:
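+ # Reserved instance offerings are not emulated; return a single static mock offering.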
+ return DescribeReservedInstancesOfferingsResult(
+ ReservedInstancesOfferings=[
+ ReservedInstancesOffering(
+ AvailabilityZone="eu-central-1a",
+ Duration=2628000,
+ FixedPrice=0.0,
+ InstanceType=InstanceType.t2_small,
+ ProductDescription=RIProductDescription.Linux_UNIX,
+ ReservedInstancesOfferingId=long_uid(),
+ UsagePrice=0.0,
+ CurrencyCode=CurrencyCodeValues.USD,
+ InstanceTenancy=Tenancy.default,
+ Marketplace=True,
+ PricingDetails=[PricingDetail(Price=0.0, Count=3)],
+ RecurringCharges=[
+ RecurringCharge(Amount=0.25, Frequency=RecurringChargeFrequency.Hourly)
+ ],
+ Scope=scope.Availability_Zone,
+ )
+ ]
+ )
+
+ @handler("DescribeReservedInstances", expand=False)
+ def describe_reserved_instances(
+ self,
+ context: RequestContext,
+ describe_reserved_instances_request: DescribeReservedInstancesRequest,
+ ) -> DescribeReservedInstancesResult:
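+ # Reserved instances are not tracked; return a single static mock entry.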
+ return DescribeReservedInstancesResult(
+ ReservedInstances=[
+ ReservedInstances(
+ AvailabilityZone="eu-central-1a",
+ Duration=2628000,
+ End=datetime(2016, 6, 30, tzinfo=timezone.utc),
+ FixedPrice=0.0,
+ InstanceCount=2,
+ InstanceType=InstanceType.t2_small,
+ ProductDescription=RIProductDescription.Linux_UNIX,
+ ReservedInstancesId=long_uid(),
+ Start=datetime(2016, 1, 1, tzinfo=timezone.utc),
+ State=ReservedInstanceState.active,
+ UsagePrice=0.05,
+ CurrencyCode=CurrencyCodeValues.USD,
+ InstanceTenancy=Tenancy.default,
+ OfferingClass=OfferingClassType.standard,
+ OfferingType=OfferingTypeValues.Partial_Upfront,
+ RecurringCharges=[
+ RecurringCharge(Amount=0.05, Frequency=RecurringChargeFrequency.Hourly)
+ ],
+ Scope=scope.Availability_Zone,
+ )
+ ]
+ )
+
+ @handler("PurchaseReservedInstancesOffering", expand=False)
+ def purchase_reserved_instances_offering(
+ self,
+ context: RequestContext,
+ purchase_reserved_instances_offerings_request: PurchaseReservedInstancesOfferingRequest,
+ ) -> PurchaseReservedInstancesOfferingResult:
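+ # Mocked implementation: simply return a freshly generated reserved instances ID.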
+ return PurchaseReservedInstancesOfferingResult(
+ ReservedInstancesId=long_uid(),
+ )
+
+ @handler("ModifyVpcEndpoint")
+ def modify_vpc_endpoint(
+ self,
+ context: RequestContext,
+ vpc_endpoint_id: VpcEndpointId,
+ dry_run: Boolean = None,
+ reset_policy: Boolean = None,
+ policy_document: String = None,
+ add_route_table_ids: VpcEndpointRouteTableIdList = None,
+ remove_route_table_ids: VpcEndpointRouteTableIdList = None,
+ add_subnet_ids: VpcEndpointSubnetIdList = None,
+ remove_subnet_ids: VpcEndpointSubnetIdList = None,
+ add_security_group_ids: VpcEndpointSecurityGroupIdList = None,
+ remove_security_group_ids: VpcEndpointSecurityGroupIdList = None,
+ ip_address_type: IpAddressType = None,
+ dns_options: DnsOptionsSpecification = None,
+ private_dns_enabled: Boolean = None,
+ subnet_configurations: SubnetConfigurationsList = None,
+ **kwargs,
+ ) -> ModifyVpcEndpointResult:
+ backend = get_ec2_backend(context.account_id, context.region)
+
+ vpc_endpoint = backend.vpc_end_points.get(vpc_endpoint_id)
+ if not vpc_endpoint:
+ raise InvalidVpcEndPointIdError(vpc_endpoint_id)
+
+ if policy_document is not None:
+ vpc_endpoint.policy_document = policy_document
+
+ if add_route_table_ids is not None:
+ vpc_endpoint.route_table_ids.extend(add_route_table_ids)
+
+ if remove_route_table_ids is not None:
+ vpc_endpoint.route_table_ids = [
+ id_ for id_ in vpc_endpoint.route_table_ids if id_ not in remove_route_table_ids
+ ]
+
+ if add_subnet_ids is not None:
+ vpc_endpoint.subnet_ids.extend(add_subnet_ids)
+
+ if remove_subnet_ids is not None:
+ vpc_endpoint.subnet_ids = [
+ id_ for id_ in vpc_endpoint.subnet_ids if id_ not in remove_subnet_ids
+ ]
+
+ if private_dns_enabled is not None:
+ vpc_endpoint.private_dns_enabled = private_dns_enabled
+
+ return ModifyVpcEndpointResult(Return=True)
+
+ @handler("ModifySubnetAttribute", expand=False)
+ def modify_subnet_attribute(
+ self, context: RequestContext, request: ModifySubnetAttributeRequest
+ ) -> None:
+ try:
+ return call_moto(context)
+ except Exception as e:
+ if not isinstance(e, ResponseParserError) and "InvalidParameterValue" not in str(e):
+ raise
+
+ backend = get_ec2_backend(context.account_id, context.region)
+
+ # set subnet attributes that are currently not supported by moto upstream
+ subnet_id = request["SubnetId"]
+ host_type = request.get("PrivateDnsHostnameTypeOnLaunch")
+ a_record_on_launch = request.get("EnableResourceNameDnsARecordOnLaunch")
+ aaaa_record_on_launch = request.get("EnableResourceNameDnsAAAARecordOnLaunch")
+ enable_dns64 = request.get("EnableDns64")
+
+ if host_type:
+ attr_name = camelcase_to_underscores("PrivateDnsNameOptionsOnLaunch")
+ value = {"HostnameType": host_type}
+ backend.modify_subnet_attribute(subnet_id, attr_name, value)
+ # explicitly check for None here, since the value could legitimately be False
+ if aaaa_record_on_launch is not None:
+ attr_name = camelcase_to_underscores("PrivateDnsNameOptionsOnLaunch")
+ value = {"EnableResourceNameDnsAAAARecord": aaaa_record_on_launch["Value"]}
+ backend.modify_subnet_attribute(subnet_id, attr_name, value)
+ if a_record_on_launch is not None:
+ attr_name = camelcase_to_underscores("PrivateDnsNameOptionsOnLaunch")
+ value = {"EnableResourceNameDnsARecord": a_record_on_launch["Value"]}
+ backend.modify_subnet_attribute(subnet_id, attr_name, value)
+ if enable_dns64 is not None:
+ attr_name = camelcase_to_underscores("EnableDns64")
+ backend.modify_subnet_attribute(subnet_id, attr_name, enable_dns64["Value"])
+
+ @handler("CreateSubnet", expand=False)
+ def create_subnet(
+ self, context: RequestContext, request: CreateSubnetRequest
+ ) -> CreateSubnetResult:
+ response = call_moto(context)
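+ # ensure the subnet has PrivateDnsNameOptionsOnLaunch populated, defaulting HostnameType to "ip-name"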
+ backend = get_ec2_backend(context.account_id, context.region)
+ subnet_id = response["Subnet"]["SubnetId"]
+ host_type = request.get("PrivateDnsHostnameTypeOnLaunch", "ip-name")
+ attr_name = camelcase_to_underscores("PrivateDnsNameOptionsOnLaunch")
+ value = {"HostnameType": host_type}
+ backend.modify_subnet_attribute(subnet_id, attr_name, value)
+ return response
+
+ @handler("RevokeSecurityGroupEgress", expand=False)
+ def revoke_security_group_egress(
+ self,
+ context: RequestContext,
+ revoke_security_group_egress_request: RevokeSecurityGroupEgressRequest,
+ ) -> RevokeSecurityGroupEgressResult:
+ try:
+ return call_moto(context)
+ except Exception as e:
+ if "specified rule does not exist" in str(e):
+ backend = get_ec2_backend(context.account_id, context.region)
+ group_id = revoke_security_group_egress_request["GroupId"]
+ group = backend.get_security_group_by_name_or_id(group_id)
+ if group and not group.egress_rules:
+ return RevokeSecurityGroupEgressResult(Return=True)
+ raise
+
+ @handler("DescribeSubnets", expand=False)
+ def describe_subnets(
+ self,
+ context: RequestContext,
+ request: DescribeSubnetsRequest,
+ ) -> DescribeSubnetsResult:
+ result = call_moto(context)
+ backend = get_ec2_backend(context.account_id, context.region)
+ # add missing attributes to the subnet responses
+ for subnet in result.get("Subnets", []):
+ subnet_obj = backend.subnets[subnet["AvailabilityZone"]].get(subnet["SubnetId"])
+ for attr in ADDITIONAL_SUBNET_ATTRS:
+ if hasattr(subnet_obj, attr):
+ attr_name = first_char_to_upper(underscores_to_camelcase(attr))
+ if attr_name not in subnet:
+ subnet[attr_name] = getattr(subnet_obj, attr)
+ return result
+
+ @handler("CreateTransitGateway", expand=False)
+ def create_transit_gateway(
+ self,
+ context: RequestContext,
+ request: CreateTransitGatewayRequest,
+ ) -> CreateTransitGatewayResult:
+ result = call_moto(context)
+ backend = get_ec2_backend(context.account_id, context.region)
+ transit_gateway_id = result["TransitGateway"]["TransitGatewayId"]
+ transit_gateway = backend.transit_gateways.get(transit_gateway_id)
+ result.get("TransitGateway").get("Options").update(transit_gateway.options)
+ return result
+
+ @handler("DescribeTransitGateways", expand=False)
+ def describe_transit_gateways(
+ self,
+ context: RequestContext,
+ request: DescribeTransitGatewaysRequest,
+ ) -> DescribeTransitGatewaysResult:
+ result = call_moto(context)
+ backend = get_ec2_backend(context.account_id, context.region)
+ for transit_gateway in result.get("TransitGateways", []):
+ transit_gateway_id = transit_gateway["TransitGatewayId"]
+ tgw = backend.transit_gateways.get(transit_gateway_id)
+ transit_gateway["Options"].update(tgw.options)
+ return result
+
+ @handler("CreateLaunchTemplate", expand=False)
+ def create_launch_template(
+ self,
+ context: RequestContext,
+ request: CreateLaunchTemplateRequest,
+ ) -> CreateLaunchTemplateResult:
+ # parameter validation
+ if not request["LaunchTemplateData"]:
+ raise MissingParameterError(parameter="LaunchTemplateData")
+
+ name = request["LaunchTemplateName"]
+ if len(name) < 3 or len(name) > 128 or not re.fullmatch(r"[a-zA-Z0-9.\-_()/]*", name):
+ raise InvalidLaunchTemplateNameError()
+
+ return call_moto(context)
+
+ @handler("ModifyLaunchTemplate", expand=False)
+ def modify_launch_template(
+ self,
+ context: RequestContext,
+ request: ModifyLaunchTemplateRequest,
+ ) -> ModifyLaunchTemplateResult:
+ backend = get_ec2_backend(context.account_id, context.region)
+ template_id = (
+ request["LaunchTemplateId"]
+ or backend.launch_template_name_to_ids[request["LaunchTemplateName"]]
+ )
+ template: MotoLaunchTemplate = backend.launch_templates[template_id]
+
+ # check if defaultVersion exists
+ if request["DefaultVersion"]:
+ try:
+ template.versions[int(request["DefaultVersion"]) - 1]
+ except IndexError:
+ raise InvalidLaunchTemplateIdError()
+
+ template.default_version_number = int(request["DefaultVersion"])
+
+ return ModifyLaunchTemplateResult(
+ LaunchTemplate=LaunchTemplate(
+ LaunchTemplateId=template.id,
+ LaunchTemplateName=template.name,
+ CreateTime=template.create_time,
+ DefaultVersionNumber=template.default_version_number,
+ LatestVersionNumber=template.latest_version_number,
+ Tags=template.tags,
+ )
+ )
+
+ @handler("DescribeVpcEndpointServices", expand=False)
+ def describe_vpc_endpoint_services(
+ self,
+ context: RequestContext,
+ request: DescribeVpcEndpointServicesRequest,
+ ) -> DescribeVpcEndpointServicesResult:
+ ep_services = VPCBackend._collect_default_endpoint_services(
+ account_id=context.account_id, region=context.region
+ )
+
+ moto_backend = get_moto_backend(context)
+ service_names = [s["ServiceName"] for s in ep_services]
+ execute_api_name = f"com.amazonaws.{context.region}.execute-api"
+
+ if execute_api_name not in service_names:
+ # ensure that the service entry for execute-api exists
+ zones = moto_backend.describe_availability_zones()
+ zones = [zone.name for zone in zones]
+ private_dns_name = f"*.execute-api.{context.region}.amazonaws.com"
+ service = {
+ "ServiceName": execute_api_name,
+ "ServiceId": f"vpce-svc-{short_uid()}",
+ "ServiceType": [{"ServiceType": "Interface"}],
+ "AvailabilityZones": zones,
+ "Owner": "amazon",
+ "BaseEndpointDnsNames": [f"execute-api.{context.region}.vpce.amazonaws.com"],
+ "PrivateDnsName": private_dns_name,
+ "PrivateDnsNames": [{"PrivateDnsName": private_dns_name}],
+ "VpcEndpointPolicySupported": True,
+ "AcceptanceRequired": False,
+ "ManagesVpcEndpoints": False,
+ "PrivateDnsNameVerificationState": "verified",
+ "SupportedIpAddressTypes": ["ipv4"],
+ }
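+ # ep_services points to the list moto collected internally (presumably a cached reference),
+ # so appending here makes the execute-api entry visible to the call_moto(context) call below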
+ ep_services.append(service)
+
+ return call_moto(context)
+
+ @handler("DescribeVpcEndpoints", expand=False)
+ def describe_vpc_endpoints(
+ self,
+ context: RequestContext,
+ request: DescribeVpcEndpointsRequest,
+ ) -> DescribeVpcEndpointsResult:
+ result: DescribeVpcEndpointsResult = call_moto(context)
+
+ for endpoint in result.get("VpcEndpoints"):
+ endpoint.setdefault("DnsOptions", DnsOptions(DnsRecordIpType=DnsRecordIpType.ipv4))
+ endpoint.setdefault("IpAddressType", IpAddressType.ipv4)
+ endpoint.setdefault("RequesterManaged", False)
+ endpoint.setdefault("RouteTableIds", [])
+ # AWS parity: Version should not be contained in the policy response
+ policy = endpoint.get("PolicyDocument")
+ if policy and '"Version":' in policy:
+ policy = json.loads(policy)
+ policy.pop("Version", None)
+ endpoint["PolicyDocument"] = json.dumps(policy)
+
+ return result
+
+ @handler("CreateFlowLogs", expand=False)
+ def create_flow_logs(
+ self,
+ context: RequestContext,
+ request: CreateFlowLogsRequest,
+ **kwargs,
+ ) -> CreateFlowLogsResult:
+ if request.get("LogDestination") and request.get("LogGroupName"):
+ raise CommonServiceException(
+ code="InvalidParameter",
+ message="Please only provide LogGroupName or only provide LogDestination.",
+ )
+ if request.get("LogDestinationType") == "s3":
+ if request.get("LogGroupName"):
+ raise CommonServiceException(
+ code="InvalidParameter",
+ message="LogDestination type must be cloud-watch-logs if LogGroupName is provided.",
+ )
+ elif not (bucket_arn := request.get("LogDestination")):
+ raise CommonServiceException(
+ code="InvalidParameter",
+ message="LogDestination can't be empty if LogGroupName is not provided.",
+ )
+
+ # Moto checks whether the bucket exists in its own in-memory state. We modify the request so
+ # that no destination is sent, which skips that check, and validate the bucket against the
+ # S3 provider ourselves below.
+ service_request = copy.deepcopy(request)
+ service_request["LogDestinationType"] = "__placeholder__"
+ bucket_name = bucket_arn.split(":", 5)[5].split("/")[0]
+ # TODO: validate how IAM is enforced? probably with DeliverLogsPermissionArn
+ s3_client = connect_to().s3
+ try:
+ s3_client.head_bucket(Bucket=bucket_name)
+ except Exception as e:
+ LOG.debug(
+ "An exception occurred when trying to create FlowLogs with S3 destination: %s",
+ e,
+ )
+ return CreateFlowLogsResult(
+ FlowLogIds=[],
+ Unsuccessful=[
+ UnsuccessfulItem(
+ Error=UnsuccessfulItemError(
+ Code="400",
+ Message=f"LogDestination: {bucket_name} does not exist",
+ ),
+ ResourceId=resource_id,
+ )
+ for resource_id in request.get("ResourceIds", [])
+ ],
+ )
+
+ response: CreateFlowLogsResult = call_moto_with_request(context, service_request)
+ moto_backend = get_moto_backend(context)
+ for flow_log_id in response["FlowLogIds"]:
+ if flow_log := moto_backend.flow_logs.get(flow_log_id):
+ # to be sure not to override another value, only replace it if it is still the placeholder
+ flow_log.log_destination_type = flow_log.log_destination_type.replace(
+ "__placeholder__", "s3"
+ )
+ else:
+ response = call_moto(context)
+
+ return response
+
+
+@patch(SubnetBackend.modify_subnet_attribute)
+def modify_subnet_attribute(fn, self, subnet_id, attr_name, attr_value):
+ subnet = self.get_subnet(subnet_id)
+ if attr_name in ADDITIONAL_SUBNET_ATTRS:
+ # private_dns_name_options_on_launch holds a dict with the keys EnableResourceNameDnsARecord, EnableResourceNameDnsAAAARecord, and HostnameType
+ if attr_name == "private_dns_name_options_on_launch":
+ if hasattr(subnet, attr_name):
+ getattr(subnet, attr_name).update(attr_value)
+ return
+ else:
+ setattr(subnet, attr_name, attr_value)
+ return
+ setattr(subnet, attr_name, attr_value)
+ return
+ return fn(self, subnet_id, attr_name, attr_value)
+
+
+def get_moto_backend(context: RequestContext) -> EC2Backend:
+ """Get the moto EC2 backend for the given request context"""
+ return ec2_backends[context.account_id][context.region]
+
+
+@patch(Subnet.get_filter_value)
+def get_filter_value(fn, self, filter_name):
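+ # resolve IPv6 CIDR block association filters for subnets before falling back to moto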
+ if filter_name in (
+ "ipv6CidrBlockAssociationSet.associationId",
+ "ipv6-cidr-block-association.association-id",
+ ):
+ return self.ipv6_cidr_block_associations
+ return fn(self, filter_name)
+
+
+@patch(TransitGatewayAttachmentBackend.delete_transit_gateway_vpc_attachment)
+def delete_transit_gateway_vpc_attachment(fn, self, transit_gateway_attachment_id, **kwargs):
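+ # patched to mark the attachment as deleted directly instead of going through moto's deletion logic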
+ transit_gateway_attachment = self.transit_gateway_attachments.get(transit_gateway_attachment_id)
+ transit_gateway_attachment.state = "deleted"
+ return transit_gateway_attachment
+
+
+@patch(FlowLogsBackend._validate_request)
+def _validate_request(
+ fn,
+ self,
+ log_group_name: str,
+ log_destination: str,
+ log_destination_type: str,
+ max_aggregation_interval: str,
+ deliver_logs_permission_arn: str,
+) -> None:
+ if not log_destination_type and log_destination:
+ # this fixes the S3 destination issue; the validation happens in the provider instead
+ return
+
+ fn(
+ self,
+ log_group_name,
+ log_destination,
+ log_destination_type,
+ max_aggregation_interval,
+ deliver_logs_permission_arn,
+ )
diff --git a/localstack-core/localstack/services/ec2/resource_providers/__init__.py b/localstack-core/localstack/services/ec2/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.py
new file mode 100644
index 0000000000000..03665a7c45fb6
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.py
@@ -0,0 +1,149 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2DHCPOptionsProperties(TypedDict):
+ DhcpOptionsId: Optional[str]
+ DomainName: Optional[str]
+ DomainNameServers: Optional[list[str]]
+ NetbiosNameServers: Optional[list[str]]
+ NetbiosNodeType: Optional[int]
+ NtpServers: Optional[list[str]]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2DHCPOptionsProvider(ResourceProvider[EC2DHCPOptionsProperties]):
+ TYPE = "AWS::EC2::DHCPOptions" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2DHCPOptionsProperties],
+ ) -> ProgressEvent[EC2DHCPOptionsProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/DhcpOptionsId
+
+
+
+ Create-only properties:
+ - /properties/NetbiosNameServers
+ - /properties/NetbiosNodeType
+ - /properties/NtpServers
+ - /properties/DomainName
+ - /properties/DomainNameServers
+
+ Read-only properties:
+ - /properties/DhcpOptionsId
+
+ IAM permissions required:
+ - ec2:CreateDhcpOptions
+ - ec2:DescribeDhcpOptions
+ - ec2:CreateTags
+
+ """
+ model = request.desired_state
+
+ dhcp_configurations = []
+ if model.get("DomainName"):
+ dhcp_configurations.append({"Key": "domain-name", "Values": [model["DomainName"]]})
+ if model.get("DomainNameServers"):
+ dhcp_configurations.append(
+ {"Key": "domain-name-servers", "Values": model["DomainNameServers"]}
+ )
+ if model.get("NetbiosNameServers"):
+ dhcp_configurations.append(
+ {"Key": "netbios-name-servers", "Values": model["NetbiosNameServers"]}
+ )
+ if model.get("NetbiosNodeType"):
+ dhcp_configurations.append(
+ {"Key": "netbios-node-type", "Values": [str(model["NetbiosNodeType"])]}
+ )
+ if model.get("NtpServers"):
+ dhcp_configurations.append({"Key": "ntp-servers", "Values": model["NtpServers"]})
+
+ create_params = {
+ "DhcpConfigurations": dhcp_configurations,
+ }
+ if model.get("Tags"):
+ tags = [{"Key": str(tag["Key"]), "Value": str(tag["Value"])} for tag in model["Tags"]]
+ else:
+ tags = []
+
+ default_tags = [
+ {"Key": "aws:cloudformation:logical-id", "Value": request.logical_resource_id},
+ {"Key": "aws:cloudformation:stack-id", "Value": request.stack_id},
+ {"Key": "aws:cloudformation:stack-name", "Value": request.stack_name},
+ ]
+
+ create_params["TagSpecifications"] = [
+ {"ResourceType": "dhcp-options", "Tags": (tags + default_tags)}
+ ]
+
+ result = request.aws_client_factory.ec2.create_dhcp_options(**create_params)
+ model["DhcpOptionsId"] = result["DhcpOptions"]["DhcpOptionsId"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[EC2DHCPOptionsProperties],
+ ) -> ProgressEvent[EC2DHCPOptionsProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeDhcpOptions
+ - ec2:DescribeTags
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2DHCPOptionsProperties],
+ ) -> ProgressEvent[EC2DHCPOptionsProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteDhcpOptions
+ - ec2:DeleteTags
+ """
+ model = request.desired_state
+ request.aws_client_factory.ec2.delete_dhcp_options(DhcpOptionsId=model["DhcpOptionsId"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[EC2DHCPOptionsProperties],
+ ) -> ProgressEvent[EC2DHCPOptionsProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:CreateTags
+ - ec2:DescribeDhcpOptions
+ - ec2:DeleteTags
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.schema.json
new file mode 100644
index 0000000000000..93e8fd3d62171
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions.schema.json
@@ -0,0 +1,120 @@
+{
+ "typeName": "AWS::EC2::DHCPOptions",
+ "description": "Resource Type definition for AWS::EC2::DHCPOptions",
+ "additionalProperties": false,
+ "properties": {
+ "DhcpOptionsId": {
+ "type": "string"
+ },
+ "DomainName": {
+ "type": "string",
+ "description": "This value is used to complete unqualified DNS hostnames."
+ },
+ "DomainNameServers": {
+ "type": "array",
+ "description": "The IPv4 addresses of up to four domain name servers, or AmazonProvidedDNS.",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "NetbiosNameServers": {
+ "type": "array",
+ "description": "The IPv4 addresses of up to four NetBIOS name servers.",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "NetbiosNodeType": {
+ "type": "integer",
+ "description": "The NetBIOS node type (1, 2, 4, or 8)."
+ },
+ "NtpServers": {
+ "type": "array",
+ "description": "The IPv4 addresses of up to four Network Time Protocol (NTP) servers.",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "description": "Any tags assigned to the DHCP options set.",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "taggable": true,
+ "createOnlyProperties": [
+ "/properties/NetbiosNameServers",
+ "/properties/NetbiosNodeType",
+ "/properties/NtpServers",
+ "/properties/DomainName",
+ "/properties/DomainNameServers"
+ ],
+ "readOnlyProperties": [
+ "/properties/DhcpOptionsId"
+ ],
+ "primaryIdentifier": [
+ "/properties/DhcpOptionsId"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateDhcpOptions",
+ "ec2:DescribeDhcpOptions",
+ "ec2:CreateTags"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeDhcpOptions",
+ "ec2:DescribeTags"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:CreateTags",
+ "ec2:DescribeDhcpOptions",
+ "ec2:DeleteTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteDhcpOptions",
+ "ec2:DeleteTags"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeDhcpOptions"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions_plugin.py
new file mode 100644
index 0000000000000..c3ac8bb5a5827
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_dhcpoptions_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2DHCPOptionsProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::DHCPOptions"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_dhcpoptions import (
+ EC2DHCPOptionsProvider,
+ )
+
+ self.factory = EC2DHCPOptionsProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.py
new file mode 100644
index 0000000000000..8c33cde7b2ab8
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.py
@@ -0,0 +1,342 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import base64
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.strings import to_str
+
+
+class EC2InstanceProperties(TypedDict):
+ AdditionalInfo: Optional[str]
+ Affinity: Optional[str]
+ AvailabilityZone: Optional[str]
+ BlockDeviceMappings: Optional[list[BlockDeviceMapping]]
+ CpuOptions: Optional[CpuOptions]
+ CreditSpecification: Optional[CreditSpecification]
+ DisableApiTermination: Optional[bool]
+ EbsOptimized: Optional[bool]
+ ElasticGpuSpecifications: Optional[list[ElasticGpuSpecification]]
+ ElasticInferenceAccelerators: Optional[list[ElasticInferenceAccelerator]]
+ EnclaveOptions: Optional[EnclaveOptions]
+ HibernationOptions: Optional[HibernationOptions]
+ HostId: Optional[str]
+ HostResourceGroupArn: Optional[str]
+ IamInstanceProfile: Optional[str]
+ Id: Optional[str]
+ ImageId: Optional[str]
+ InstanceInitiatedShutdownBehavior: Optional[str]
+ InstanceType: Optional[str]
+ Ipv6AddressCount: Optional[int]
+ Ipv6Addresses: Optional[list[InstanceIpv6Address]]
+ KernelId: Optional[str]
+ KeyName: Optional[str]
+ LaunchTemplate: Optional[LaunchTemplateSpecification]
+ LicenseSpecifications: Optional[list[LicenseSpecification]]
+ Monitoring: Optional[bool]
+ NetworkInterfaces: Optional[list[NetworkInterface]]
+ PlacementGroupName: Optional[str]
+ PrivateDnsName: Optional[str]
+ PrivateDnsNameOptions: Optional[PrivateDnsNameOptions]
+ PrivateIp: Optional[str]
+ PrivateIpAddress: Optional[str]
+ PropagateTagsToVolumeOnCreation: Optional[bool]
+ PublicDnsName: Optional[str]
+ PublicIp: Optional[str]
+ RamdiskId: Optional[str]
+ SecurityGroupIds: Optional[list[str]]
+ SecurityGroups: Optional[list[str]]
+ SourceDestCheck: Optional[bool]
+ SsmAssociations: Optional[list[SsmAssociation]]
+ SubnetId: Optional[str]
+ Tags: Optional[list[Tag]]
+ Tenancy: Optional[str]
+ UserData: Optional[str]
+ Volumes: Optional[list[Volume]]
+
+
+class Ebs(TypedDict):
+ DeleteOnTermination: Optional[bool]
+ Encrypted: Optional[bool]
+ Iops: Optional[int]
+ KmsKeyId: Optional[str]
+ SnapshotId: Optional[str]
+ VolumeSize: Optional[int]
+ VolumeType: Optional[str]
+
+
+class BlockDeviceMapping(TypedDict):
+ DeviceName: Optional[str]
+ Ebs: Optional[Ebs]
+ NoDevice: Optional[dict]
+ VirtualName: Optional[str]
+
+
+class InstanceIpv6Address(TypedDict):
+ Ipv6Address: Optional[str]
+
+
+class ElasticGpuSpecification(TypedDict):
+ Type: Optional[str]
+
+
+class ElasticInferenceAccelerator(TypedDict):
+ Type: Optional[str]
+ Count: Optional[int]
+
+
+class Volume(TypedDict):
+ Device: Optional[str]
+ VolumeId: Optional[str]
+
+
+class LaunchTemplateSpecification(TypedDict):
+ Version: Optional[str]
+ LaunchTemplateId: Optional[str]
+ LaunchTemplateName: Optional[str]
+
+
+class EnclaveOptions(TypedDict):
+ Enabled: Optional[bool]
+
+
+class PrivateIpAddressSpecification(TypedDict):
+ Primary: Optional[bool]
+ PrivateIpAddress: Optional[str]
+
+
+class NetworkInterface(TypedDict):
+ DeviceIndex: Optional[str]
+ AssociateCarrierIpAddress: Optional[bool]
+ AssociatePublicIpAddress: Optional[bool]
+ DeleteOnTermination: Optional[bool]
+ Description: Optional[str]
+ GroupSet: Optional[list[str]]
+ Ipv6AddressCount: Optional[int]
+ Ipv6Addresses: Optional[list[InstanceIpv6Address]]
+ NetworkInterfaceId: Optional[str]
+ PrivateIpAddress: Optional[str]
+ PrivateIpAddresses: Optional[list[PrivateIpAddressSpecification]]
+ SecondaryPrivateIpAddressCount: Optional[int]
+ SubnetId: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class HibernationOptions(TypedDict):
+ Configured: Optional[bool]
+
+
+class LicenseSpecification(TypedDict):
+ LicenseConfigurationArn: Optional[str]
+
+
+class CpuOptions(TypedDict):
+ CoreCount: Optional[int]
+ ThreadsPerCore: Optional[int]
+
+
+class PrivateDnsNameOptions(TypedDict):
+ EnableResourceNameDnsAAAARecord: Optional[bool]
+ EnableResourceNameDnsARecord: Optional[bool]
+ HostnameType: Optional[str]
+
+
+class AssociationParameter(TypedDict):
+ Key: Optional[str]
+ Value: Optional[list[str]]
+
+
+class SsmAssociation(TypedDict):
+ DocumentName: Optional[str]
+ AssociationParameters: Optional[list[AssociationParameter]]
+
+
+class CreditSpecification(TypedDict):
+ CPUCredits: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
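+# Marker stored in the custom context after RunInstances has been issued, so that subsequent
+# invocations of create() only poll the instance state instead of launching again.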
+
+
+class EC2InstanceProvider(ResourceProvider[EC2InstanceProperties]):
+ TYPE = "AWS::EC2::Instance" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2InstanceProperties],
+ ) -> ProgressEvent[EC2InstanceProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+
+
+ Create-only properties:
+ - /properties/ElasticGpuSpecifications
+ - /properties/Ipv6Addresses
+ - /properties/PlacementGroupName
+ - /properties/HostResourceGroupArn
+ - /properties/ImageId
+ - /properties/CpuOptions
+ - /properties/PrivateIpAddress
+ - /properties/ElasticInferenceAccelerators
+ - /properties/EnclaveOptions
+ - /properties/HibernationOptions
+ - /properties/KeyName
+ - /properties/LicenseSpecifications
+ - /properties/NetworkInterfaces
+ - /properties/AvailabilityZone
+ - /properties/SubnetId
+ - /properties/LaunchTemplate
+ - /properties/SecurityGroups
+ - /properties/Ipv6AddressCount
+
+ Read-only properties:
+ - /properties/PublicIp
+ - /properties/Id
+ - /properties/PublicDnsName
+ - /properties/PrivateDnsName
+ - /properties/PrivateIp
+
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO: validations
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: idempotency
+ params = util.select_attributes(
+ model,
+ ["InstanceType", "SecurityGroups", "KeyName", "ImageId", "MaxCount", "MinCount"],
+ )
+
+ # These parameters are not defined in the schema but are required by the API
+ params["MaxCount"] = 1
+ params["MinCount"] = 1
+
+ if model.get("UserData"):
+ params["UserData"] = to_str(base64.b64decode(model["UserData"]))
+
+ response = ec2.run_instances(**params)
+ model["Id"] = response["Instances"][0]["InstanceId"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ response = ec2.describe_instances(InstanceIds=[model["Id"]])
+ instance = response["Reservations"][0]["Instances"][0]
+ if instance["State"]["Name"] != "running":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ model["PrivateIp"] = instance["PrivateIpAddress"]
+ model["PrivateDnsName"] = instance["PrivateDnsName"]
+ model["AvailabilityZone"] = instance["Placement"]["AvailabilityZone"]
+
+ # PublicIp is not guaranteed to be returned by the request:
+ # https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_ec2.Instance.html#instancepublicip
+ # The docs say it should be an empty string in that case, but adding an output that references the
+ # value fails with: `Attribute 'PublicIp' does not exist`
+ if public_ip := instance.get("PublicIpAddress"):
+ model["PublicIp"] = public_ip
+
+ if public_dns_name := instance.get("PublicDnsName"):
+ model["PublicDnsName"] = public_dns_name
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2InstanceProperties],
+ ) -> ProgressEvent[EC2InstanceProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2InstanceProperties],
+ ) -> ProgressEvent[EC2InstanceProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ ec2.terminate_instances(InstanceIds=[model["Id"]])
+ # TODO add checking of ec2 instance state
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2InstanceProperties],
+ ) -> ProgressEvent[EC2InstanceProperties]:
+ """
+ Update a resource
+
+
+ """
+ desired_state = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ groups = desired_state.get("SecurityGroups", desired_state.get("SecurityGroupIds"))
+
+ kwargs = {}
+ if groups:
+ kwargs["Groups"] = groups
+ ec2.modify_instance_attribute(
+ InstanceId=desired_state["Id"],
+ InstanceType={"Value": desired_state["InstanceType"]},
+ **kwargs,
+ )
+
+ response = ec2.describe_instances(InstanceIds=[desired_state["Id"]])
+ instance = response["Reservations"][0]["Instances"][0]
+ if instance["State"]["Name"] != "running":
+ return ProgressEvent(
+ status=OperationStatus.PENDING,
+ resource_model=desired_state,
+ custom_context=request.custom_context,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=desired_state,
+ custom_context=request.custom_context,
+ )
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.schema.json
new file mode 100644
index 0000000000000..85ff4e3fd9d10
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance.schema.json
@@ -0,0 +1,540 @@
+{
+ "typeName": "AWS::EC2::Instance",
+ "description": "Resource Type definition for AWS::EC2::Instance",
+ "additionalProperties": false,
+ "properties": {
+ "Tenancy": {
+ "type": "string"
+ },
+ "SecurityGroups": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "PrivateDnsName": {
+ "type": "string"
+ },
+ "PrivateIpAddress": {
+ "type": "string"
+ },
+ "UserData": {
+ "type": "string"
+ },
+ "BlockDeviceMappings": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/BlockDeviceMapping"
+ }
+ },
+ "IamInstanceProfile": {
+ "type": "string"
+ },
+ "Ipv6Addresses": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/InstanceIpv6Address"
+ }
+ },
+ "KernelId": {
+ "type": "string"
+ },
+ "SubnetId": {
+ "type": "string"
+ },
+ "EbsOptimized": {
+ "type": "boolean"
+ },
+ "PropagateTagsToVolumeOnCreation": {
+ "type": "boolean"
+ },
+ "ElasticGpuSpecifications": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/ElasticGpuSpecification"
+ }
+ },
+ "ElasticInferenceAccelerators": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/ElasticInferenceAccelerator"
+ }
+ },
+ "Volumes": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Volume"
+ }
+ },
+ "PrivateIp": {
+ "type": "string"
+ },
+ "Ipv6AddressCount": {
+ "type": "integer"
+ },
+ "LaunchTemplate": {
+ "$ref": "#/definitions/LaunchTemplateSpecification"
+ },
+ "EnclaveOptions": {
+ "$ref": "#/definitions/EnclaveOptions"
+ },
+ "NetworkInterfaces": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/NetworkInterface"
+ }
+ },
+ "ImageId": {
+ "type": "string"
+ },
+ "InstanceType": {
+ "type": "string"
+ },
+ "Monitoring": {
+ "type": "boolean"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "AdditionalInfo": {
+ "type": "string"
+ },
+ "HibernationOptions": {
+ "$ref": "#/definitions/HibernationOptions"
+ },
+ "LicenseSpecifications": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/LicenseSpecification"
+ }
+ },
+ "PublicIp": {
+ "type": "string"
+ },
+ "InstanceInitiatedShutdownBehavior": {
+ "type": "string"
+ },
+ "CpuOptions": {
+ "$ref": "#/definitions/CpuOptions"
+ },
+ "AvailabilityZone": {
+ "type": "string"
+ },
+ "PrivateDnsNameOptions": {
+ "$ref": "#/definitions/PrivateDnsNameOptions"
+ },
+ "HostId": {
+ "type": "string"
+ },
+ "HostResourceGroupArn": {
+ "type": "string"
+ },
+ "PublicDnsName": {
+ "type": "string"
+ },
+ "SecurityGroupIds": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "DisableApiTermination": {
+ "type": "boolean"
+ },
+ "KeyName": {
+ "type": "string"
+ },
+ "RamdiskId": {
+ "type": "string"
+ },
+ "SourceDestCheck": {
+ "type": "boolean"
+ },
+ "PlacementGroupName": {
+ "type": "string"
+ },
+ "SsmAssociations": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/SsmAssociation"
+ }
+ },
+ "Affinity": {
+ "type": "string"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "CreditSpecification": {
+ "$ref": "#/definitions/CreditSpecification"
+ }
+ },
+ "definitions": {
+ "LaunchTemplateSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "LaunchTemplateName": {
+ "type": "string"
+ },
+ "LaunchTemplateId": {
+ "type": "string"
+ },
+ "Version": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Version"
+ ]
+ },
+ "HibernationOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Configured": {
+ "type": "boolean"
+ }
+ }
+ },
+ "LicenseSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "LicenseConfigurationArn": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "LicenseConfigurationArn"
+ ]
+ },
+ "CpuOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ThreadsPerCore": {
+ "type": "integer"
+ },
+ "CoreCount": {
+ "type": "integer"
+ }
+ }
+ },
+ "NoDevice": {
+ "type": "object",
+ "additionalProperties": false
+ },
+ "InstanceIpv6Address": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Ipv6Address": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Ipv6Address"
+ ]
+ },
+ "NetworkInterface": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Description": {
+ "type": "string"
+ },
+ "PrivateIpAddress": {
+ "type": "string"
+ },
+ "PrivateIpAddresses": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/PrivateIpAddressSpecification"
+ }
+ },
+ "SecondaryPrivateIpAddressCount": {
+ "type": "integer"
+ },
+ "DeviceIndex": {
+ "type": "string"
+ },
+ "GroupSet": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Ipv6Addresses": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/InstanceIpv6Address"
+ }
+ },
+ "SubnetId": {
+ "type": "string"
+ },
+ "AssociatePublicIpAddress": {
+ "type": "boolean"
+ },
+ "NetworkInterfaceId": {
+ "type": "string"
+ },
+ "AssociateCarrierIpAddress": {
+ "type": "boolean"
+ },
+ "Ipv6AddressCount": {
+ "type": "integer"
+ },
+ "DeleteOnTermination": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "DeviceIndex"
+ ]
+ },
+ "PrivateDnsNameOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "HostnameType": {
+ "type": "string"
+ },
+ "EnableResourceNameDnsAAAARecord": {
+ "type": "boolean"
+ },
+ "EnableResourceNameDnsARecord": {
+ "type": "boolean"
+ }
+ }
+ },
+ "ElasticGpuSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Type": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Type"
+ ]
+ },
+ "ElasticInferenceAccelerator": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Type": {
+ "type": "string"
+ },
+ "Count": {
+ "type": "integer"
+ }
+ },
+ "required": [
+ "Type"
+ ]
+ },
+ "SsmAssociation": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AssociationParameters": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/AssociationParameter"
+ }
+ },
+ "DocumentName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "DocumentName"
+ ]
+ },
+ "AssociationParameter": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "PrivateIpAddressSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PrivateIpAddress": {
+ "type": "string"
+ },
+ "Primary": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "PrivateIpAddress",
+ "Primary"
+ ]
+ },
+ "Volume": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "VolumeId": {
+ "type": "string"
+ },
+ "Device": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "VolumeId",
+ "Device"
+ ]
+ },
+ "EnclaveOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ }
+ }
+ },
+ "Ebs": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "SnapshotId": {
+ "type": "string"
+ },
+ "VolumeType": {
+ "type": "string"
+ },
+ "KmsKeyId": {
+ "type": "string"
+ },
+ "Encrypted": {
+ "type": "boolean"
+ },
+ "Iops": {
+ "type": "integer"
+ },
+ "VolumeSize": {
+ "type": "integer"
+ },
+ "DeleteOnTermination": {
+ "type": "boolean"
+ }
+ }
+ },
+ "BlockDeviceMapping": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "NoDevice": {
+ "$ref": "#/definitions/NoDevice"
+ },
+ "VirtualName": {
+ "type": "string"
+ },
+ "Ebs": {
+ "$ref": "#/definitions/Ebs"
+ },
+ "DeviceName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "DeviceName"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "CreditSpecification": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CPUCredits": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/ElasticGpuSpecifications",
+ "/properties/Ipv6Addresses",
+ "/properties/PlacementGroupName",
+ "/properties/HostResourceGroupArn",
+ "/properties/ImageId",
+ "/properties/CpuOptions",
+ "/properties/PrivateIpAddress",
+ "/properties/ElasticInferenceAccelerators",
+ "/properties/EnclaveOptions",
+ "/properties/HibernationOptions",
+ "/properties/KeyName",
+ "/properties/LicenseSpecifications",
+ "/properties/NetworkInterfaces",
+ "/properties/AvailabilityZone",
+ "/properties/SubnetId",
+ "/properties/LaunchTemplate",
+ "/properties/SecurityGroups",
+ "/properties/Ipv6AddressCount"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/PublicIp",
+ "/properties/Id",
+ "/properties/PublicDnsName",
+ "/properties/PrivateDnsName",
+ "/properties/PrivateIp"
+ ]
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance_plugin.py
new file mode 100644
index 0000000000000..60f400297a47f
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_instance_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2InstanceProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::Instance"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_instance import EC2InstanceProvider
+
+ self.factory = EC2InstanceProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.py
new file mode 100644
index 0000000000000..1ad0d6981b9c0
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.py
@@ -0,0 +1,116 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2InternetGatewayProperties(TypedDict):
+ InternetGatewayId: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2InternetGatewayProvider(ResourceProvider[EC2InternetGatewayProperties]):
+ TYPE = "AWS::EC2::InternetGateway" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2InternetGatewayProperties],
+ ) -> ProgressEvent[EC2InternetGatewayProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/InternetGatewayId
+
+ Read-only properties:
+ - /properties/InternetGatewayId
+
+ IAM permissions required:
+ - ec2:CreateInternetGateway
+ - ec2:CreateTags
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
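+ # EC2 expects tags to be wrapped in a TagSpecifications entry with the matching resource type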
+ tags = [{"ResourceType": "'internet-gateway'", "Tags": model.get("Tags", [])}]
+
+ response = ec2.create_internet_gateway(TagSpecifications=tags)
+ model["InternetGatewayId"] = response["InternetGateway"]["InternetGatewayId"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2InternetGatewayProperties],
+ ) -> ProgressEvent[EC2InternetGatewayProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeInternetGateways
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2InternetGatewayProperties],
+ ) -> ProgressEvent[EC2InternetGatewayProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteInternetGateway
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ # detach the gateway from any attached VPCs before deleting it
+ response = ec2.describe_internet_gateways(InternetGatewayIds=[model["InternetGatewayId"]])
+
+ for gateway in response.get("InternetGateways", []):
+ for attachment in gateway.get("Attachments", []):
+ ec2.detach_internet_gateway(
+ InternetGatewayId=model["InternetGatewayId"], VpcId=attachment["VpcId"]
+ )
+ ec2.delete_internet_gateway(InternetGatewayId=model["InternetGatewayId"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2InternetGatewayProperties],
+ ) -> ProgressEvent[EC2InternetGatewayProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:DeleteTags
+ - ec2:CreateTags
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.schema.json
new file mode 100644
index 0000000000000..62fd843a46c3f
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway.schema.json
@@ -0,0 +1,78 @@
+{
+ "typeName": "AWS::EC2::InternetGateway",
+ "description": "Resource Type definition for AWS::EC2::InternetGateway",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "additionalProperties": false,
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "properties": {
+ "InternetGatewayId": {
+ "description": "ID of internet gateway.",
+ "type": "string"
+ },
+ "Tags": {
+ "description": "Any tags to assign to the internet gateway.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "taggable": true,
+ "readOnlyProperties": [
+ "/properties/InternetGatewayId"
+ ],
+ "primaryIdentifier": [
+ "/properties/InternetGatewayId"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateInternetGateway",
+ "ec2:CreateTags"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeInternetGateways"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteInternetGateway"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:DeleteTags",
+ "ec2:CreateTags"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeInternetGateways"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway_plugin.py
new file mode 100644
index 0000000000000..51c889fae01a0
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_internetgateway_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2InternetGatewayProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::InternetGateway"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_internetgateway import (
+ EC2InternetGatewayProvider,
+ )
+
+ self.factory = EC2InternetGatewayProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.py
new file mode 100644
index 0000000000000..8c03d6bc738b5
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.py
@@ -0,0 +1,148 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2KeyPairProperties(TypedDict):
+ KeyName: Optional[str]
+ KeyFingerprint: Optional[str]
+ KeyFormat: Optional[str]
+ KeyPairId: Optional[str]
+ KeyType: Optional[str]
+ PublicKeyMaterial: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2KeyPairProvider(ResourceProvider[EC2KeyPairProperties]):
+ TYPE = "AWS::EC2::KeyPair" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2KeyPairProperties],
+ ) -> ProgressEvent[EC2KeyPairProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/KeyName
+
+ Required properties:
+ - KeyName
+
+ Create-only properties:
+ - /properties/KeyName
+ - /properties/KeyType
+ - /properties/KeyFormat
+ - /properties/PublicKeyMaterial
+ - /properties/Tags
+
+ Read-only properties:
+ - /properties/KeyPairId
+ - /properties/KeyFingerprint
+
+ IAM permissions required:
+ - ec2:CreateKeyPair
+ - ec2:ImportKeyPair
+ - ec2:CreateTags
+ - ssm:PutParameter
+
+ """
+ model = request.desired_state
+
+ if "KeyName" not in model:
+ raise ValueError("Property 'KeyName' is required")
+
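+ # import the provided public key if given; otherwise let EC2 generate a new key pair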
+ if public_key_material := model.get("PublicKeyMaterial"):
+ response = request.aws_client_factory.ec2.import_key_pair(
+ KeyName=model["KeyName"],
+ PublicKeyMaterial=public_key_material,
+ )
+ else:
+ create_params = util.select_attributes(model, ["KeyName", "KeyType", "KeyFormat"])
+ # tags must be passed via TagSpecifications rather than a top-level Tags parameter
+ if model.get("Tags"):
+ create_params["TagSpecifications"] = [
+ {"ResourceType": "key-pair", "Tags": model["Tags"]}
+ ]
+ response = request.aws_client_factory.ec2.create_key_pair(**create_params)
+
+ model["KeyPairId"] = response["KeyPairId"]
+ model["KeyFingerprint"] = response["KeyFingerprint"]
+
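+ # AWS keeps key pair material in SSM under /ec2/keypair/<KeyPairId>; this provider stores the key name there as a stand-in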
+ request.aws_client_factory.ssm.put_parameter(
+ Name=f"/ec2/keypair/{model['KeyPairId']}",
+ Value=model["KeyName"],
+ Type="String",
+ Overwrite=True,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2KeyPairProperties],
+ ) -> ProgressEvent[EC2KeyPairProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeKeyPairs
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2KeyPairProperties],
+ ) -> ProgressEvent[EC2KeyPairProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteKeyPair
+ - ssm:DeleteParameter
+ - ec2:DescribeKeyPairs
+ """
+
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ ec2.delete_key_pair(KeyName=model["KeyName"])
+
+ request.aws_client_factory.ssm.delete_parameter(
+ Name=f"/ec2/keypair/{model['KeyPairId']}",
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2KeyPairProperties],
+ ) -> ProgressEvent[EC2KeyPairProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.schema.json
new file mode 100644
index 0000000000000..d5b65ffc19a74
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair.schema.json
@@ -0,0 +1,133 @@
+{
+ "typeName": "AWS::EC2::KeyPair",
+ "description": "The AWS::EC2::KeyPair creates an SSH key pair",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "definitions": {
+ "Tag": {
+ "description": "A key-value pair to associate with a resource.",
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "minLength": 0,
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ],
+ "additionalProperties": false
+ }
+ },
+ "properties": {
+ "KeyName": {
+ "description": "The name of the SSH key pair",
+ "type": "string"
+ },
+ "KeyType": {
+ "description": "The crypto-system used to generate a key pair.",
+ "type": "string",
+ "default": "rsa",
+ "enum": [
+ "rsa",
+ "ed25519"
+ ]
+ },
+ "KeyFormat": {
+ "description": "The format of the private key",
+ "type": "string",
+ "default": "pem",
+ "enum": [
+ "pem",
+ "ppk"
+ ]
+ },
+ "PublicKeyMaterial": {
+ "description": "Plain text public key to import",
+ "type": "string"
+ },
+ "KeyFingerprint": {
+ "description": "A short sequence of bytes used for public key verification",
+ "type": "string"
+ },
+ "KeyPairId": {
+ "description": "An AWS generated ID for the key pair",
+ "type": "string"
+ },
+ "Tags": {
+ "description": "An array of key-value pairs to apply to this resource.",
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "KeyName"
+ ],
+ "primaryIdentifier": [
+ "/properties/KeyName"
+ ],
+ "additionalIdentifiers": [
+ [
+ "/properties/KeyPairId"
+ ]
+ ],
+ "createOnlyProperties": [
+ "/properties/KeyName",
+ "/properties/KeyType",
+ "/properties/KeyFormat",
+ "/properties/PublicKeyMaterial",
+ "/properties/Tags"
+ ],
+ "writeOnlyProperties": [
+ "/properties/KeyFormat"
+ ],
+ "readOnlyProperties": [
+ "/properties/KeyPairId",
+ "/properties/KeyFingerprint"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagUpdatable": false,
+ "cloudFormationSystemTags": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateKeyPair",
+ "ec2:ImportKeyPair",
+ "ec2:CreateTags",
+ "ssm:PutParameter"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeKeyPairs"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeKeyPairs"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteKeyPair",
+ "ssm:DeleteParameter",
+ "ec2:DescribeKeyPairs"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair_plugin.py
new file mode 100644
index 0000000000000..5bb9524b1f667
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_keypair_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2KeyPairProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::KeyPair"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_keypair import EC2KeyPairProvider
+
+ self.factory = EC2KeyPairProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.py
new file mode 100644
index 0000000000000..de03079d89699
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.py
@@ -0,0 +1,183 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2NatGatewayProperties(TypedDict):
+ SubnetId: Optional[str]
+ AllocationId: Optional[str]
+ ConnectivityType: Optional[str]
+ MaxDrainDurationSeconds: Optional[int]
+ NatGatewayId: Optional[str]
+ PrivateIpAddress: Optional[str]
+ SecondaryAllocationIds: Optional[list[str]]
+ SecondaryPrivateIpAddressCount: Optional[int]
+ SecondaryPrivateIpAddresses: Optional[list[str]]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2NatGatewayProvider(ResourceProvider[EC2NatGatewayProperties]):
+ TYPE = "AWS::EC2::NatGateway" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2NatGatewayProperties],
+ ) -> ProgressEvent[EC2NatGatewayProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/NatGatewayId
+
+ Required properties:
+ - SubnetId
+
+ Create-only properties:
+ - /properties/SubnetId
+ - /properties/ConnectivityType
+ - /properties/AllocationId
+ - /properties/PrivateIpAddress
+
+ Read-only properties:
+ - /properties/NatGatewayId
+
+ IAM permissions required:
+ - ec2:CreateNatGateway
+ - ec2:DescribeNatGateways
+ - ec2:CreateTags
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ # TODO: validations
+ # TODO: add tests for this resource; it is currently not covered
+
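+ # first invocation: issue CreateNatGateway and return IN_PROGRESS; subsequent invocations
+ # poll DescribeNatGateways until the gateway leaves the "pending" state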
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: defaults
+ # TODO: idempotency
+ params = util.select_attributes(
+ model,
+ ["SubnetId", "AllocationId"],
+ )
+
+ if model.get("Tags"):
+ tags = [{"ResourceType": "natgateway", "Tags": model.get("Tags")}]
+ params["TagSpecifications"] = tags
+
+ response = ec2.create_nat_gateway(**params)
+ model["NatGatewayId"] = response["NatGateway"]["NatGatewayId"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ response = ec2.describe_nat_gateways(NatGatewayIds=[model["NatGatewayId"]])
+ if response["NatGateways"][0]["State"] == "pending":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ # TODO add handling for failed events
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2NatGatewayProperties],
+ ) -> ProgressEvent[EC2NatGatewayProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeNatGateways
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2NatGatewayProperties],
+ ) -> ProgressEvent[EC2NatGatewayProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteNatGateway
+ - ec2:DescribeNatGateways
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
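+ # deletion mirrors creation: the first invocation issues DeleteNatGateway, later invocations
+ # poll until the gateway is no longer reported as "deleting"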
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ request.custom_context[REPEATED_INVOCATION] = True
+ ec2.delete_nat_gateway(NatGatewayId=model["NatGatewayId"])
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ is_deleting = False
+ try:
+ response = ec2.describe_nat_gateways(NatGatewayIds=[model["NatGatewayId"]])
+ is_deleting = response["NatGateways"][0]["State"] == "deleting"
+ except ec2.exceptions.ClientError:
+ pass
+
+ if is_deleting:
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2NatGatewayProperties],
+ ) -> ProgressEvent[EC2NatGatewayProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:DescribeNatGateways
+ - ec2:CreateTags
+ - ec2:DeleteTags
+ - ec2:AssociateNatGatewayAddress
+ - ec2:DisassociateNatGatewayAddress
+ - ec2:AssignPrivateNatGatewayAddress
+ - ec2:UnassignPrivateNatGatewayAddress
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.schema.json
new file mode 100644
index 0000000000000..99f268a2dfc29
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway.schema.json
@@ -0,0 +1,131 @@
+{
+ "typeName": "AWS::EC2::NatGateway",
+ "description": "Resource Type definition for AWS::EC2::NatGateway",
+ "additionalProperties": false,
+ "properties": {
+ "SubnetId": {
+ "type": "string"
+ },
+ "NatGatewayId": {
+ "type": "string"
+ },
+ "ConnectivityType": {
+ "type": "string"
+ },
+ "PrivateIpAddress": {
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "AllocationId": {
+ "type": "string"
+ },
+ "SecondaryAllocationIds": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "SecondaryPrivateIpAddresses": {
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "SecondaryPrivateIpAddressCount": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "MaxDrainDurationSeconds": {
+ "type": "integer"
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "SubnetId"
+ ],
+ "createOnlyProperties": [
+ "/properties/SubnetId",
+ "/properties/ConnectivityType",
+ "/properties/AllocationId",
+ "/properties/PrivateIpAddress"
+ ],
+ "primaryIdentifier": [
+ "/properties/NatGatewayId"
+ ],
+ "readOnlyProperties": [
+ "/properties/NatGatewayId"
+ ],
+ "writeOnlyProperties": [
+ "/properties/MaxDrainDurationSeconds"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateNatGateway",
+ "ec2:DescribeNatGateways",
+ "ec2:CreateTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteNatGateway",
+ "ec2:DescribeNatGateways"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeNatGateways"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeNatGateways"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:DescribeNatGateways",
+ "ec2:CreateTags",
+ "ec2:DeleteTags",
+ "ec2:AssociateNatGatewayAddress",
+ "ec2:DisassociateNatGatewayAddress",
+ "ec2:AssignPrivateNatGatewayAddress",
+ "ec2:UnassignPrivateNatGatewayAddress"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway_plugin.py
new file mode 100644
index 0000000000000..e8036702f5e79
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_natgateway_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2NatGatewayProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::NatGateway"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_natgateway import (
+ EC2NatGatewayProvider,
+ )
+
+ self.factory = EC2NatGatewayProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.py
new file mode 100644
index 0000000000000..47d36951d0068
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.py
@@ -0,0 +1,116 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2NetworkAclProperties(TypedDict):
+ VpcId: Optional[str]
+ Id: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2NetworkAclProvider(ResourceProvider[EC2NetworkAclProperties]):
+ TYPE = "AWS::EC2::NetworkAcl" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2NetworkAclProperties],
+ ) -> ProgressEvent[EC2NetworkAclProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - VpcId
+
+ Create-only properties:
+ - /properties/VpcId
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - ec2:CreateNetworkAcl
+ - ec2:DescribeNetworkAcls
+
+ """
+ model = request.desired_state
+
+ create_params = {
+ "VpcId": model["VpcId"],
+ }
+
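+ # apply any tags atomically with CreateNetworkAcl via TagSpecifications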
+ if model.get("Tags"):
+ create_params["TagSpecifications"] = [
+ {
+ "ResourceType": "network-acl",
+ "Tags": [{"Key": tag["Key"], "Value": tag["Value"]} for tag in model["Tags"]],
+ }
+ ]
+
+ response = request.aws_client_factory.ec2.create_network_acl(**create_params)
+ model["Id"] = response["NetworkAcl"]["NetworkAclId"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[EC2NetworkAclProperties],
+ ) -> ProgressEvent[EC2NetworkAclProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeNetworkAcls
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2NetworkAclProperties],
+ ) -> ProgressEvent[EC2NetworkAclProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteNetworkAcl
+ - ec2:DescribeNetworkAcls
+ """
+ model = request.desired_state
+ request.aws_client_factory.ec2.delete_network_acl(NetworkAclId=model["Id"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[EC2NetworkAclProperties],
+ ) -> ProgressEvent[EC2NetworkAclProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:DescribeNetworkAcls
+ - ec2:DeleteTags
+ - ec2:CreateTags
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.schema.json
new file mode 100644
index 0000000000000..52bdc7cdca1ca
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl.schema.json
@@ -0,0 +1,92 @@
+{
+ "typeName": "AWS::EC2::NetworkAcl",
+ "description": "Resource Type definition for AWS::EC2::NetworkAcl",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-ec2.git",
+ "additionalProperties": false,
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Tags": {
+ "description": "The tags to assign to the network ACL.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "VpcId": {
+ "description": "The ID of the VPC.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "VpcId"
+ ],
+ "createOnlyProperties": [
+ "/properties/VpcId"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateNetworkAcl",
+ "ec2:DescribeNetworkAcls"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeNetworkAcls"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:DescribeNetworkAcls",
+ "ec2:DeleteTags",
+ "ec2:CreateTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteNetworkAcl",
+ "ec2:DescribeNetworkAcls"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeNetworkAcls"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl_plugin.py
new file mode 100644
index 0000000000000..0f24a9cd40adc
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_networkacl_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2NetworkAclProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::NetworkAcl"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_networkacl import (
+ EC2NetworkAclProvider,
+ )
+
+ self.factory = EC2NetworkAclProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.py
new file mode 100644
index 0000000000000..8308fb5bfa990
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.py
@@ -0,0 +1,167 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2PrefixListProperties(TypedDict):
+ AddressFamily: Optional[str]
+ MaxEntries: Optional[int]
+ PrefixListName: Optional[str]
+ Arn: Optional[str]
+ Entries: Optional[list[Entry]]
+ OwnerId: Optional[str]
+ PrefixListId: Optional[str]
+ Tags: Optional[list[Tag]]
+ Version: Optional[int]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class Entry(TypedDict):
+ Cidr: Optional[str]
+ Description: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2PrefixListProvider(ResourceProvider[EC2PrefixListProperties]):
+ TYPE = "AWS::EC2::PrefixList" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2PrefixListProperties],
+ ) -> ProgressEvent[EC2PrefixListProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/PrefixListId
+
+ Required properties:
+ - PrefixListName
+ - MaxEntries
+ - AddressFamily
+
+
+
+ Read-only properties:
+ - /properties/PrefixListId
+ - /properties/OwnerId
+ - /properties/Version
+ - /properties/Arn
+
+ IAM permissions required:
+ - EC2:CreateManagedPrefixList
+ - EC2:DescribeManagedPrefixLists
+ - EC2:CreateTags
+
+ """
+ model = request.desired_state
+
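+ # two-phase create: the first invocation calls CreateManagedPrefixList and returns IN_PROGRESS;
+ # the follow-up invocation verifies the list is visible via DescribeManagedPrefixLists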
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ create_params = util.select_attributes(
+ model, ["PrefixListName", "Entries", "MaxEntries", "AddressFamily", "Tags"]
+ )
+
+ if "Tags" in create_params:
+ create_params["TagSpecifications"] = [
+ {"ResourceType": "prefix-list", "Tags": create_params.pop("Tags")}
+ ]
+
+ response = request.aws_client_factory.ec2.create_managed_prefix_list(**create_params)
+ model["Arn"] = response["PrefixList"]["PrefixListId"]
+ model["OwnerId"] = response["PrefixList"]["OwnerId"]
+ model["PrefixListId"] = response["PrefixList"]["PrefixListId"]
+ model["Version"] = response["PrefixList"]["Version"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ response = request.aws_client_factory.ec2.describe_managed_prefix_lists(
+ PrefixListIds=[model["PrefixListId"]]
+ )
+ if not response["PrefixLists"]:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ message="Resource not found after creation",
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2PrefixListProperties],
+ ) -> ProgressEvent[EC2PrefixListProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - EC2:GetManagedPrefixListEntries
+ - EC2:DescribeManagedPrefixLists
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2PrefixListProperties],
+ ) -> ProgressEvent[EC2PrefixListProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - EC2:DeleteManagedPrefixList
+ - EC2:DescribeManagedPrefixLists
+ """
+
+ model = request.previous_state
+ response = request.aws_client_factory.ec2.describe_managed_prefix_lists(
+ PrefixListIds=[model["PrefixListId"]]
+ )
+
+ if not response["PrefixLists"]:
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ request.aws_client_factory.ec2.delete_managed_prefix_list(
+ PrefixListId=request.previous_state["PrefixListId"]
+ )
+ return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[EC2PrefixListProperties],
+ ) -> ProgressEvent[EC2PrefixListProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - EC2:DescribeManagedPrefixLists
+ - EC2:GetManagedPrefixListEntries
+ - EC2:ModifyManagedPrefixList
+ - EC2:CreateTags
+ - EC2:DeleteTags
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.schema.json
new file mode 100644
index 0000000000000..cb27aefee2bd3
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist.schema.json
@@ -0,0 +1,152 @@
+{
+ "typeName": "AWS::EC2::PrefixList",
+ "description": "Resource schema of AWS::EC2::PrefixList Type",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Key"
+ ],
+ "additionalProperties": false
+ },
+ "Entry": {
+ "type": "object",
+ "properties": {
+ "Cidr": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 46
+ },
+ "Description": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 255
+ }
+ },
+ "required": [
+ "Cidr"
+ ],
+ "additionalProperties": false
+ }
+ },
+ "properties": {
+ "PrefixListName": {
+ "description": "Name of Prefix List.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 255
+ },
+ "PrefixListId": {
+ "description": "Id of Prefix List.",
+ "type": "string"
+ },
+ "OwnerId": {
+ "description": "Owner Id of Prefix List.",
+ "type": "string"
+ },
+ "AddressFamily": {
+ "description": "Ip Version of Prefix List.",
+ "type": "string",
+ "enum": [
+ "IPv4",
+ "IPv6"
+ ]
+ },
+ "MaxEntries": {
+ "description": "Max Entries of Prefix List.",
+ "type": "integer",
+ "minimum": 1
+ },
+ "Version": {
+ "description": "Version of Prefix List.",
+ "type": "integer"
+ },
+ "Tags": {
+ "description": "Tags for Prefix List",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Entries": {
+ "description": "Entries of Prefix List.",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Entry"
+ }
+ },
+ "Arn": {
+ "description": "The Amazon Resource Name (ARN) of the Prefix List.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "PrefixListName",
+ "MaxEntries",
+ "AddressFamily"
+ ],
+ "readOnlyProperties": [
+ "/properties/PrefixListId",
+ "/properties/OwnerId",
+ "/properties/Version",
+ "/properties/Arn"
+ ],
+ "primaryIdentifier": [
+ "/properties/PrefixListId"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "EC2:CreateManagedPrefixList",
+ "EC2:DescribeManagedPrefixLists",
+ "EC2:CreateTags"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "EC2:GetManagedPrefixListEntries",
+ "EC2:DescribeManagedPrefixLists"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "EC2:DescribeManagedPrefixLists",
+ "EC2:GetManagedPrefixListEntries",
+ "EC2:ModifyManagedPrefixList",
+ "EC2:CreateTags",
+ "EC2:DeleteTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "EC2:DeleteManagedPrefixList",
+ "EC2:DescribeManagedPrefixLists"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "EC2:DescribeManagedPrefixLists",
+ "EC2:GetManagedPrefixListEntries"
+ ]
+ }
+ },
+ "additionalProperties": false
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist_plugin.py
new file mode 100644
index 0000000000000..5d8b993d28409
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_prefixlist_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2PrefixListProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::PrefixList"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_prefixlist import (
+ EC2PrefixListProvider,
+ )
+
+ self.factory = EC2PrefixListProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.py
new file mode 100644
index 0000000000000..c779541d04229
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.py
@@ -0,0 +1,137 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+from moto.ec2.utils import generate_route_id
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2RouteProperties(TypedDict):
+ RouteTableId: Optional[str]
+ CarrierGatewayId: Optional[str]
+ DestinationCidrBlock: Optional[str]
+ DestinationIpv6CidrBlock: Optional[str]
+ EgressOnlyInternetGatewayId: Optional[str]
+ GatewayId: Optional[str]
+ Id: Optional[str]
+ InstanceId: Optional[str]
+ LocalGatewayId: Optional[str]
+ NatGatewayId: Optional[str]
+ NetworkInterfaceId: Optional[str]
+ TransitGatewayId: Optional[str]
+ VpcEndpointId: Optional[str]
+ VpcPeeringConnectionId: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2RouteProvider(ResourceProvider[EC2RouteProperties]):
+ TYPE = "AWS::EC2::Route" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2RouteProperties],
+ ) -> ProgressEvent[EC2RouteProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - RouteTableId
+
+ Create-only properties:
+ - /properties/RouteTableId
+ - /properties/DestinationCidrBlock
+
+ Read-only properties:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ cidr_block = model.get("DestinationCidrBlock")
+ ipv6_cidr_block = model.get("DestinationIpv6CidrBlock", "")
+
+ ec2.create_route(
+ DestinationCidrBlock=cidr_block,
+ DestinationIpv6CidrBlock=ipv6_cidr_block,
+ RouteTableId=model["RouteTableId"],
+ )
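+ # routes have no server-assigned ID, so derive a deterministic physical ID
+ # from the route table and destination CIDRs using moto's helper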
+ model["Id"] = generate_route_id(
+ model["RouteTableId"],
+ cidr_block,
+ ipv6_cidr_block,
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2RouteProperties],
+ ) -> ProgressEvent[EC2RouteProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2RouteProperties],
+ ) -> ProgressEvent[EC2RouteProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ cidr_block = model.get("DestinationCidrBlock")
+ ipv6_cidr_block = model.get("DestinationIpv6CidrBlock", "")
+
+ try:
+ ec2.delete_route(
+ DestinationCidrBlock=cidr_block,
+ DestinationIpv6CidrBlock=ipv6_cidr_block,
+ RouteTableId=model["RouteTableId"],
+ )
+ except ec2.exceptions.ClientError:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2RouteProperties],
+ ) -> ProgressEvent[EC2RouteProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.schema.json
new file mode 100644
index 0000000000000..151c2d115972e
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route.schema.json
@@ -0,0 +1,62 @@
+{
+ "typeName": "AWS::EC2::Route",
+ "description": "Resource Type definition for AWS::EC2::Route",
+ "additionalProperties": false,
+ "properties": {
+ "DestinationIpv6CidrBlock": {
+ "type": "string"
+ },
+ "RouteTableId": {
+ "type": "string"
+ },
+ "InstanceId": {
+ "type": "string"
+ },
+ "LocalGatewayId": {
+ "type": "string"
+ },
+ "CarrierGatewayId": {
+ "type": "string"
+ },
+ "DestinationCidrBlock": {
+ "type": "string"
+ },
+ "GatewayId": {
+ "type": "string"
+ },
+ "NetworkInterfaceId": {
+ "type": "string"
+ },
+ "VpcEndpointId": {
+ "type": "string"
+ },
+ "TransitGatewayId": {
+ "type": "string"
+ },
+ "VpcPeeringConnectionId": {
+ "type": "string"
+ },
+ "EgressOnlyInternetGatewayId": {
+ "type": "string"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "NatGatewayId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "RouteTableId"
+ ],
+ "createOnlyProperties": [
+ "/properties/RouteTableId",
+ "/properties/DestinationCidrBlock"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route_plugin.py
new file mode 100644
index 0000000000000..abd759b08aaca
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_route_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2RouteProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::Route"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_route import EC2RouteProvider
+
+ self.factory = EC2RouteProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.py
new file mode 100644
index 0000000000000..618c3fad99c08
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.py
@@ -0,0 +1,123 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2RouteTableProperties(TypedDict):
+ VpcId: Optional[str]
+ RouteTableId: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2RouteTableProvider(ResourceProvider[EC2RouteTableProperties]):
+ TYPE = "AWS::EC2::RouteTable" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2RouteTableProperties],
+ ) -> ProgressEvent[EC2RouteTableProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RouteTableId
+
+ Required properties:
+ - VpcId
+
+ Create-only properties:
+ - /properties/VpcId
+
+ Read-only properties:
+ - /properties/RouteTableId
+
+ IAM permissions required:
+ - ec2:CreateRouteTable
+ - ec2:CreateTags
+ - ec2:DescribeRouteTables
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO: validations
+ params = util.select_attributes(model, ["VpcId", "Tags"])
+
+ tags = [{"ResourceType": "route-table", "Tags": params.get("Tags", [])}]
+
+ response = ec2.create_route_table(VpcId=params["VpcId"], TagSpecifications=tags)
+ model["RouteTableId"] = response["RouteTable"]["RouteTableId"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2RouteTableProperties],
+ ) -> ProgressEvent[EC2RouteTableProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeRouteTables
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2RouteTableProperties],
+ ) -> ProgressEvent[EC2RouteTableProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DescribeRouteTables
+ - ec2:DeleteRouteTable
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ try:
+ ec2.delete_route_table(RouteTableId=model["RouteTableId"])
+ except ec2.exceptions.ClientError:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2RouteTableProperties],
+ ) -> ProgressEvent[EC2RouteTableProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:CreateTags
+ - ec2:DeleteTags
+ - ec2:DescribeRouteTables
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.schema.json
new file mode 100644
index 0000000000000..491be25027a62
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable.schema.json
@@ -0,0 +1,94 @@
+{
+ "typeName": "AWS::EC2::RouteTable",
+ "description": "Resource Type definition for AWS::EC2::RouteTable",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-ec2",
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "properties": {
+ "RouteTableId": {
+ "description": "The route table ID.",
+ "type": "string"
+ },
+ "Tags": {
+ "description": "Any tags assigned to the route table.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "VpcId": {
+ "description": "The ID of the VPC.",
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "VpcId"
+ ],
+ "createOnlyProperties": [
+ "/properties/VpcId"
+ ],
+ "readOnlyProperties": [
+ "/properties/RouteTableId"
+ ],
+ "primaryIdentifier": [
+ "/properties/RouteTableId"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateRouteTable",
+ "ec2:CreateTags",
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:CreateTags",
+ "ec2:DeleteTags",
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DescribeRouteTables",
+ "ec2:DeleteRouteTable"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeRouteTables"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable_plugin.py
new file mode 100644
index 0000000000000..07396c832bf66
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_routetable_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2RouteTableProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::RouteTable"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_routetable import (
+ EC2RouteTableProvider,
+ )
+
+ self.factory = EC2RouteTableProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.py
new file mode 100644
index 0000000000000..8e9b54a35bb0f
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.py
@@ -0,0 +1,228 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ Properties,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2SecurityGroupProperties(TypedDict):
+ GroupDescription: Optional[str]
+ GroupId: Optional[str]
+ GroupName: Optional[str]
+ Id: Optional[str]
+ SecurityGroupEgress: Optional[list[Egress]]
+ SecurityGroupIngress: Optional[list[Ingress]]
+ Tags: Optional[list[Tag]]
+ VpcId: Optional[str]
+
+
+class Ingress(TypedDict):
+ IpProtocol: Optional[str]
+ CidrIp: Optional[str]
+ CidrIpv6: Optional[str]
+ Description: Optional[str]
+ FromPort: Optional[int]
+ SourcePrefixListId: Optional[str]
+ SourceSecurityGroupId: Optional[str]
+ SourceSecurityGroupName: Optional[str]
+ SourceSecurityGroupOwnerId: Optional[str]
+ ToPort: Optional[int]
+
+
+class Egress(TypedDict):
+ IpProtocol: Optional[str]
+ CidrIp: Optional[str]
+ CidrIpv6: Optional[str]
+ Description: Optional[str]
+ DestinationPrefixListId: Optional[str]
+ DestinationSecurityGroupId: Optional[str]
+ FromPort: Optional[int]
+ ToPort: Optional[int]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+def model_from_description(sg_description: dict) -> dict:
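+ # translate a DescribeSecurityGroups entry into the CloudFormation resource model shape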
+ model = {
+ "Id": sg_description.get("GroupId"),
+ "GroupId": sg_description.get("GroupId"),
+ "GroupName": sg_description.get("GroupName"),
+ "GroupDescription": sg_description.get("Description"),
+ "SecurityGroupEgress": [],
+ "SecurityGroupIngress": [],
+ }
+
+ for egress in sg_description.get("IpPermissionsEgress", []):
+ for ip_range in egress.get("IpRanges", []):
+ model["SecurityGroupEgress"].append(
+ {
+ "CidrIp": ip_range.get("CidrIp"),
+ "FromPort": egress.get("FromPort", -1),
+ "IpProtocol": egress.get("IpProtocol", "-1"),
+ "ToPort": egress.get("ToPort", -1),
+ }
+ )
+
+ for ingress in sg_description.get("IpPermissions", []):
+ for ip_range in ingress.get("IpRanges", []):
+ model["SecurityGroupIngress"].append(
+ {
+ "CidrIp": ip_range.get("CidrIp"),
+ "FromPort": ingress.get("FromPort", -1),
+ "IpProtocol": ingress.get("IpProtocol", "-1"),
+ "ToPort": ingress.get("ToPort", -1),
+ }
+ )
+
+ model["VpcId"] = sg_description.get("VpcId")
+ return model
+
+
+class EC2SecurityGroupProvider(ResourceProvider[EC2SecurityGroupProperties]):
+ TYPE = "AWS::EC2::SecurityGroup" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2SecurityGroupProperties],
+ ) -> ProgressEvent[EC2SecurityGroupProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - GroupDescription
+
+ Create-only properties:
+ - /properties/GroupDescription
+ - /properties/GroupName
+ - /properties/VpcId
+
+ Read-only properties:
+ - /properties/Id
+ - /properties/GroupId
+
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ params = {}
+
+ if not model.get("GroupName"):
+ params["GroupName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+ else:
+ params["GroupName"] = model["GroupName"]
+
+ if vpc_id := model.get("VpcId"):
+ params["VpcId"] = vpc_id
+
+ params["Description"] = model.get("GroupDescription", "")
+
+ tags = [
+ {"Key": "aws:cloudformation:logical-id", "Value": request.logical_resource_id},
+ {"Key": "aws:cloudformation:stack-id", "Value": request.stack_id},
+ {"Key": "aws:cloudformation:stack-name", "Value": request.stack_name},
+ ]
+
+ if model_tags := model.get("Tags"):
+ tags += model_tags
+
+ params["TagSpecifications"] = [{"ResourceType": "security-group", "Tags": tags}]
+
+ response = ec2.create_security_group(**params)
+ model["GroupId"] = response["GroupId"]
+
+ # When you pass the logical ID of this resource to the intrinsic Ref function,
+ # Ref returns the ID of the security group if you specified the VpcId property.
+ # Otherwise, it returns the name of the security group.
+ # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-securitygroup.html#aws-resource-ec2-securitygroup-return-values-ref
+ if "VpcId" in model:
+ model["Id"] = response["GroupId"]
+ else:
+ model["Id"] = params["GroupName"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2SecurityGroupProperties],
+ ) -> ProgressEvent[EC2SecurityGroupProperties]:
+ """
+ Fetch resource information
+ """
+
+ model = request.desired_state
+
+ security_group = request.aws_client_factory.ec2.describe_security_groups(
+ GroupIds=[model["Id"]]
+ )["SecurityGroups"][0]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model_from_description(security_group),
+ )
+
+ def list(self, request: ResourceRequest[Properties]) -> ProgressEvent[Properties]:
+ security_groups = request.aws_client_factory.ec2.describe_security_groups()[
+ "SecurityGroups"
+ ]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[{"Id": description["GroupId"]} for description in security_groups],
+ )
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2SecurityGroupProperties],
+ ) -> ProgressEvent[EC2SecurityGroupProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ ec2.delete_security_group(GroupId=model["GroupId"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2SecurityGroupProperties],
+ ) -> ProgressEvent[EC2SecurityGroupProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
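
A minimal sketch of how model_from_description flattens a DescribeSecurityGroups entry into the CloudFormation-shaped model defined above; the sample response dict is hand-written for illustration, and only IPv4 IpRanges are carried over:

from localstack.services.ec2.resource_providers.aws_ec2_securitygroup import (
    model_from_description,
)

# illustrative DescribeSecurityGroups entry (not captured from a real API call)
sample_description = {
    "GroupId": "sg-0123456789abcdef0",
    "GroupName": "web-sg",
    "Description": "allow inbound HTTP",
    "VpcId": "vpc-12345678",
    "IpPermissions": [
        {
            "IpProtocol": "tcp",
            "FromPort": 80,
            "ToPort": 80,
            "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
        }
    ],
    "IpPermissionsEgress": [{"IpProtocol": "-1", "IpRanges": [{"CidrIp": "0.0.0.0/0"}]}],
}

model = model_from_description(sample_description)
# model["SecurityGroupIngress"] -> [{"CidrIp": "0.0.0.0/0", "FromPort": 80,
#                                    "IpProtocol": "tcp", "ToPort": 80}]
# the egress rule has no FromPort/ToPort, so both default to -1 in the model
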
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.schema.json
new file mode 100644
index 0000000000000..5ccdf924ac598
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup.schema.json
@@ -0,0 +1,148 @@
+{
+ "typeName": "AWS::EC2::SecurityGroup",
+ "description": "Resource Type definition for AWS::EC2::SecurityGroup",
+ "additionalProperties": false,
+ "properties": {
+ "GroupDescription": {
+ "type": "string"
+ },
+ "GroupName": {
+ "type": "string"
+ },
+ "VpcId": {
+ "type": "string"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "SecurityGroupIngress": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Ingress"
+ }
+ },
+ "SecurityGroupEgress": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Egress"
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "GroupId": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "Ingress": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CidrIp": {
+ "type": "string"
+ },
+ "CidrIpv6": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "FromPort": {
+ "type": "integer"
+ },
+ "SourceSecurityGroupName": {
+ "type": "string"
+ },
+ "ToPort": {
+ "type": "integer"
+ },
+ "SourceSecurityGroupOwnerId": {
+ "type": "string"
+ },
+ "IpProtocol": {
+ "type": "string"
+ },
+ "SourceSecurityGroupId": {
+ "type": "string"
+ },
+ "SourcePrefixListId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "IpProtocol"
+ ]
+ },
+ "Egress": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CidrIp": {
+ "type": "string"
+ },
+ "CidrIpv6": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "FromPort": {
+ "type": "integer"
+ },
+ "ToPort": {
+ "type": "integer"
+ },
+ "IpProtocol": {
+ "type": "string"
+ },
+ "DestinationSecurityGroupId": {
+ "type": "string"
+ },
+ "DestinationPrefixListId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "IpProtocol"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "GroupDescription"
+ ],
+ "createOnlyProperties": [
+ "/properties/GroupDescription",
+ "/properties/GroupName",
+ "/properties/VpcId"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id",
+ "/properties/GroupId"
+ ]
+}
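
The file above is a CloudFormation resource type schema: keywords such as createOnlyProperties, readOnlyProperties, and primaryIdentifier are CloudFormation extensions, while the rest is ordinary JSON Schema. A rough validation sketch, assuming the jsonschema package is available and the file is read from this directory (plain JSON Schema validators simply ignore the CloudFormation-specific keywords):

import json
from pathlib import Path

from jsonschema import Draft7Validator

schema = json.loads(Path("aws_ec2_securitygroup.schema.json").read_text())
validator = Draft7Validator(schema)

model = {"GroupName": "web-sg"}  # missing the required GroupDescription
print([error.message for error in validator.iter_errors(model)])
# e.g. ["'GroupDescription' is a required property"]
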
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup_plugin.py
new file mode 100644
index 0000000000000..176bddb74e703
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_securitygroup_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2SecurityGroupProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::SecurityGroup"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_securitygroup import (
+ EC2SecurityGroupProvider,
+ )
+
+ self.factory = EC2SecurityGroupProvider
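
These *_plugin.py modules all follow the same shape: the provider class is only imported inside load(), so merely discovering the plugin does not import the provider module itself. A small usage sketch, assuming the provider class can be instantiated without arguments (which is how the deployer uses the factory):

from localstack.services.ec2.resource_providers.aws_ec2_securitygroup_plugin import (
    EC2SecurityGroupProviderPlugin,
)

plugin = EC2SecurityGroupProviderPlugin()
assert plugin.factory is None  # nothing imported yet beyond the plugin stub
plugin.load()                  # lazily imports EC2SecurityGroupProvider
provider = plugin.factory()    # instantiate the actual resource provider
print(provider.TYPE)           # AWS::EC2::SecurityGroup
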
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.py
new file mode 100644
index 0000000000000..e7c82a0d3669c
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.py
@@ -0,0 +1,248 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.strings import str_to_bool
+
+
+class EC2SubnetProperties(TypedDict):
+ VpcId: Optional[str]
+ AssignIpv6AddressOnCreation: Optional[bool]
+ AvailabilityZone: Optional[str]
+ AvailabilityZoneId: Optional[str]
+ CidrBlock: Optional[str]
+ EnableDns64: Optional[bool]
+ Ipv6CidrBlock: Optional[str]
+ Ipv6CidrBlocks: Optional[list[str]]
+ Ipv6Native: Optional[bool]
+ MapPublicIpOnLaunch: Optional[bool]
+ NetworkAclAssociationId: Optional[str]
+ OutpostArn: Optional[str]
+ PrivateDnsNameOptionsOnLaunch: Optional[dict]
+ SubnetId: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+def generate_subnet_read_payload(
+ ec2_client, schema, subnet_ids: Optional[list[str]] = None
+) -> list[EC2SubnetProperties]:
+ kwargs = {}
+ if subnet_ids:
+ kwargs["SubnetIds"] = subnet_ids
+ subnets = ec2_client.describe_subnets(**kwargs)["Subnets"]
+
+ models = []
+ for subnet in subnets:
+ subnet_id = subnet["SubnetId"]
+
+ model = EC2SubnetProperties(**util.select_attributes(subnet, schema))
+
+ if "Tags" not in model:
+ model["Tags"] = []
+
+ if "EnableDns64" not in model:
+ model["EnableDns64"] = False
+
+ private_dns_name_options = model.setdefault("PrivateDnsNameOptionsOnLaunch", {})
+
+ if "HostnameType" not in private_dns_name_options:
+ private_dns_name_options["HostnameType"] = "ip-name"
+
+ optional_bool_attrs = ["EnableResourceNameDnsAAAARecord", "EnableResourceNameDnsARecord"]
+ for attr in optional_bool_attrs:
+ if attr not in private_dns_name_options:
+ private_dns_name_options[attr] = False
+
+ network_acl_associations = ec2_client.describe_network_acls(
+ Filters=[{"Name": "association.subnet-id", "Values": [subnet_id]}]
+ )
+ model["NetworkAclAssociationId"] = network_acl_associations["NetworkAcls"][0][
+ "NetworkAclId"
+ ]
+ models.append(model)
+
+ return models
+
+
+class EC2SubnetProvider(ResourceProvider[EC2SubnetProperties]):
+ TYPE = "AWS::EC2::Subnet" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2SubnetProperties],
+ ) -> ProgressEvent[EC2SubnetProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/SubnetId
+
+ Required properties:
+ - VpcId
+
+ Create-only properties:
+ - /properties/VpcId
+ - /properties/AvailabilityZone
+ - /properties/AvailabilityZoneId
+ - /properties/CidrBlock
+ - /properties/OutpostArn
+ - /properties/Ipv6Native
+
+ Read-only properties:
+ - /properties/NetworkAclAssociationId
+ - /properties/SubnetId
+ - /properties/Ipv6CidrBlocks
+
+ IAM permissions required:
+ - ec2:DescribeSubnets
+ - ec2:CreateSubnet
+ - ec2:CreateTags
+ - ec2:ModifySubnetAttribute
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ params = util.select_attributes(
+ model,
+ [
+ "AvailabilityZone",
+ "AvailabilityZoneId",
+ "CidrBlock",
+ "Ipv6CidrBlock",
+ "Ipv6Native",
+ "OutpostArn",
+ "VpcId",
+ ],
+ )
+ if model.get("Tags"):
+ tags = [{"ResourceType": "subnet", "Tags": model.get("Tags")}]
+ params["TagSpecifications"] = tags
+
+ response = ec2.create_subnet(**params)
+ model["SubnetId"] = response["Subnet"]["SubnetId"]
+ bool_attrs = [
+ "AssignIpv6AddressOnCreation",
+ "EnableDns64",
+ "MapPublicIpOnLaunch",
+ ]
+ custom_attrs = bool_attrs + ["PrivateDnsNameOptionsOnLaunch"]
+ if not any(attr in model for attr in custom_attrs):
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ # update boolean attributes
+ for attr in bool_attrs:
+ if attr in model:
+ kwargs = {attr: {"Value": str_to_bool(model[attr])}}
+ ec2.modify_subnet_attribute(SubnetId=model["SubnetId"], **kwargs)
+
+ # determine DNS hostname type on launch
+ dns_options = model.get("PrivateDnsNameOptionsOnLaunch")
+ if dns_options:
+ if isinstance(dns_options, str):
+ dns_options = json.loads(dns_options)
+ if dns_options.get("HostnameType"):
+ ec2.modify_subnet_attribute(
+ SubnetId=model["SubnetId"],
+ PrivateDnsHostnameTypeOnLaunch=dns_options.get("HostnameType"),
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2SubnetProperties],
+ ) -> ProgressEvent[EC2SubnetProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeSubnets
+ - ec2:DescribeNetworkAcls
+ """
+ models = generate_subnet_read_payload(
+ ec2_client=request.aws_client_factory.ec2,
+ schema=self.SCHEMA["properties"],
+ subnet_ids=[request.desired_state["SubnetId"]],
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=models[0],
+ custom_context=request.custom_context,
+ )
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2SubnetProperties],
+ ) -> ProgressEvent[EC2SubnetProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DescribeSubnets
+ - ec2:DeleteSubnet
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ ec2.delete_subnet(SubnetId=model["SubnetId"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[EC2SubnetProperties],
+ ) -> ProgressEvent[EC2SubnetProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:DescribeSubnets
+ - ec2:ModifySubnetAttribute
+ - ec2:CreateTags
+ - ec2:DeleteTags
+ - ec2:AssociateSubnetCidrBlock
+ - ec2:DisassociateSubnetCidrBlock
+ """
+ raise NotImplementedError
+
+ def list(
+ self, request: ResourceRequest[EC2SubnetProperties]
+ ) -> ProgressEvent[EC2SubnetProperties]:
+ """
+ List resources
+
+ IAM permissions required:
+ - ec2:DescribeSubnets
+ - ec2:DescribeNetworkAcls
+ """
+ models = generate_subnet_read_payload(
+ request.aws_client_factory.ec2, self.SCHEMA["properties"]
+ )
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_models=models)
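
Since CreateSubnet accepts only a subset of the properties, the create() handler above applies the boolean flags and the hostname type through separate ModifySubnetAttribute calls. A hedged boto3 sketch of that sequence, assuming a LocalStack endpoint on localhost:4566 with dummy credentials and a placeholder VPC ID:

import boto3

ec2 = boto3.client(
    "ec2",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

subnet = ec2.create_subnet(
    VpcId="vpc-12345678",
    CidrBlock="10.0.1.0/24",
    TagSpecifications=[{"ResourceType": "subnet", "Tags": [{"Key": "Name", "Value": "demo"}]}],
)["Subnet"]

# boolean attributes are applied one ModifySubnetAttribute call at a time
ec2.modify_subnet_attribute(SubnetId=subnet["SubnetId"], MapPublicIpOnLaunch={"Value": True})
# the HostnameType from PrivateDnsNameOptionsOnLaunch is applied separately
ec2.modify_subnet_attribute(
    SubnetId=subnet["SubnetId"], PrivateDnsHostnameTypeOnLaunch="ip-name"
)
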
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.schema.json
new file mode 100644
index 0000000000000..806f82f3ed8c7
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet.schema.json
@@ -0,0 +1,157 @@
+{
+ "typeName": "AWS::EC2::Subnet",
+ "description": "Resource Type definition for AWS::EC2::Subnet",
+ "additionalProperties": false,
+ "properties": {
+ "AssignIpv6AddressOnCreation": {
+ "type": "boolean"
+ },
+ "VpcId": {
+ "type": "string"
+ },
+ "MapPublicIpOnLaunch": {
+ "type": "boolean"
+ },
+ "NetworkAclAssociationId": {
+ "type": "string"
+ },
+ "AvailabilityZone": {
+ "type": "string"
+ },
+ "AvailabilityZoneId": {
+ "type": "string"
+ },
+ "CidrBlock": {
+ "type": "string"
+ },
+ "SubnetId": {
+ "type": "string"
+ },
+ "Ipv6CidrBlocks": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Ipv6CidrBlock": {
+ "type": "string"
+ },
+ "OutpostArn": {
+ "type": "string"
+ },
+ "Ipv6Native": {
+ "type": "boolean"
+ },
+ "EnableDns64": {
+ "type": "boolean"
+ },
+ "PrivateDnsNameOptionsOnLaunch": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "HostnameType": {
+ "type": "string"
+ },
+ "EnableResourceNameDnsARecord": {
+ "type": "boolean"
+ },
+ "EnableResourceNameDnsAAAARecord": {
+ "type": "boolean"
+ }
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "VpcId"
+ ],
+ "createOnlyProperties": [
+ "/properties/VpcId",
+ "/properties/AvailabilityZone",
+ "/properties/AvailabilityZoneId",
+ "/properties/CidrBlock",
+ "/properties/OutpostArn",
+ "/properties/Ipv6Native"
+ ],
+ "conditionalCreateOnlyProperties": [
+ "/properties/Ipv6CidrBlock"
+ ],
+ "primaryIdentifier": [
+ "/properties/SubnetId"
+ ],
+ "readOnlyProperties": [
+ "/properties/NetworkAclAssociationId",
+ "/properties/SubnetId",
+ "/properties/Ipv6CidrBlocks"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:DescribeSubnets",
+ "ec2:CreateSubnet",
+ "ec2:CreateTags",
+ "ec2:ModifySubnetAttribute"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeSubnets",
+ "ec2:DescribeNetworkAcls"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:DescribeSubnets",
+ "ec2:ModifySubnetAttribute",
+ "ec2:CreateTags",
+ "ec2:DeleteTags",
+ "ec2:AssociateSubnetCidrBlock",
+ "ec2:DisassociateSubnetCidrBlock"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DescribeSubnets",
+ "ec2:DeleteSubnet"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeSubnets",
+ "ec2:DescribeNetworkAcls"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet_plugin.py
new file mode 100644
index 0000000000000..65349afd2f656
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnet_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2SubnetProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::Subnet"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_subnet import EC2SubnetProvider
+
+ self.factory = EC2SubnetProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.py
new file mode 100644
index 0000000000000..d07bbdcb6665e
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.py
@@ -0,0 +1,142 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2SubnetRouteTableAssociationProperties(TypedDict):
+ RouteTableId: Optional[str]
+ SubnetId: Optional[str]
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2SubnetRouteTableAssociationProvider(
+ ResourceProvider[EC2SubnetRouteTableAssociationProperties]
+):
+ TYPE = "AWS::EC2::SubnetRouteTableAssociation" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2SubnetRouteTableAssociationProperties],
+ ) -> ProgressEvent[EC2SubnetRouteTableAssociationProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - RouteTableId
+ - SubnetId
+
+ Create-only properties:
+ - /properties/SubnetId
+ - /properties/RouteTableId
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - ec2:AssociateRouteTable
+ - ec2:ReplaceRouteTableAssociation
+ - ec2:DescribeSubnets
+ - ec2:DescribeRouteTables
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ # TODO: validations
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: defaults
+ # TODO: idempotency
+ model["Id"] = ec2.associate_route_table(
+ RouteTableId=model["RouteTableId"], SubnetId=model["SubnetId"]
+ )["AssociationId"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ # we need to check association status
+ route_table = ec2.describe_route_tables(RouteTableIds=[model["RouteTableId"]])[
+ "RouteTables"
+ ][0]
+ for association in route_table["Associations"]:
+ if association["RouteTableAssociationId"] == model["Id"]:
+ # if it is showing up here, it's associated
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2SubnetRouteTableAssociationProperties],
+ ) -> ProgressEvent[EC2SubnetRouteTableAssociationProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeRouteTables
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2SubnetRouteTableAssociationProperties],
+ ) -> ProgressEvent[EC2SubnetRouteTableAssociationProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DisassociateRouteTable
+ - ec2:DescribeSubnets
+ - ec2:DescribeRouteTables
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO add async
+ try:
+ ec2.disassociate_route_table(AssociationId=model["Id"])
+ except ec2.exceptions.ClientError:
+ pass
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2SubnetRouteTableAssociationProperties],
+ ) -> ProgressEvent[EC2SubnetRouteTableAssociationProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
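
The create() handler above is a re-entrant callback: the first invocation starts the association and returns IN_PROGRESS, and subsequent invocations poll DescribeRouteTables until the association shows up. A stripped-down driver loop for that contract, where invoke_create is a hypothetical stand-in for however the deployment engine calls provider.create() with a ResourceRequest:

import time

from localstack.services.cloudformation.resource_provider import OperationStatus


def deploy_until_done(invoke_create, max_attempts: int = 30, delay: float = 1.0):
    """Re-invoke a create handler until it reports SUCCESS (or give up)."""
    custom_context: dict = {}
    for _ in range(max_attempts):
        event = invoke_create(custom_context)  # calls provider.create(...) under the hood
        if event.status == OperationStatus.SUCCESS:
            return event.resource_model
        custom_context = event.custom_context or custom_context  # carry state across calls
        time.sleep(delay)  # IN_PROGRESS: try again shortly
    raise TimeoutError("resource did not reach SUCCESS in time")
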
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.schema.json
new file mode 100644
index 0000000000000..d0dab1cba2a02
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation.schema.json
@@ -0,0 +1,64 @@
+{
+ "typeName": "AWS::EC2::SubnetRouteTableAssociation",
+ "description": "Resource Type definition for AWS::EC2::SubnetRouteTableAssociation",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-ec2.git",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "RouteTableId": {
+ "type": "string"
+ },
+ "SubnetId": {
+ "type": "string"
+ }
+ },
+ "tagging": {
+ "taggable": false,
+ "tagOnCreate": false,
+ "tagUpdatable": false,
+ "cloudFormationSystemTags": false
+ },
+ "required": [
+ "RouteTableId",
+ "SubnetId"
+ ],
+ "createOnlyProperties": [
+ "/properties/SubnetId",
+ "/properties/RouteTableId"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:AssociateRouteTable",
+ "ec2:ReplaceRouteTableAssociation",
+ "ec2:DescribeSubnets",
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DisassociateRouteTable",
+ "ec2:DescribeSubnets",
+ "ec2:DescribeRouteTables"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeRouteTables"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation_plugin.py
new file mode 100644
index 0000000000000..6841f27741847
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_subnetroutetableassociation_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2SubnetRouteTableAssociationProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::SubnetRouteTableAssociation"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_subnetroutetableassociation import (
+ EC2SubnetRouteTableAssociationProvider,
+ )
+
+ self.factory = EC2SubnetRouteTableAssociationProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.py
new file mode 100644
index 0000000000000..4a4b5825966cc
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.py
@@ -0,0 +1,144 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2TransitGatewayProperties(TypedDict):
+ AmazonSideAsn: Optional[int]
+ AssociationDefaultRouteTableId: Optional[str]
+ AutoAcceptSharedAttachments: Optional[str]
+ DefaultRouteTableAssociation: Optional[str]
+ DefaultRouteTablePropagation: Optional[str]
+ Description: Optional[str]
+ DnsSupport: Optional[str]
+ Id: Optional[str]
+ MulticastSupport: Optional[str]
+ PropagationDefaultRouteTableId: Optional[str]
+ Tags: Optional[list[Tag]]
+ TransitGatewayCidrBlocks: Optional[list[str]]
+ VpnEcmpSupport: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2TransitGatewayProvider(ResourceProvider[EC2TransitGatewayProperties]):
+ TYPE = "AWS::EC2::TransitGateway" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2TransitGatewayProperties],
+ ) -> ProgressEvent[EC2TransitGatewayProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+
+
+ Create-only properties:
+ - /properties/AmazonSideAsn
+ - /properties/MulticastSupport
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - ec2:CreateTransitGateway
+ - ec2:CreateTags
+
+ """
+ model = request.desired_state
+ create_params = {
+ "Options": util.select_attributes(
+ model,
+ [
+ "AmazonSideAsn",
+ "AssociationDefaultRouteTableId",
+ "AutoAcceptSharedAttachments",
+ "DefaultRouteTableAssociation",
+ "DefaultRouteTablePropagation",
+ "DnsSupport",
+ "MulticastSupport",
+ "PropagationDefaultRouteTableId",
+ "TransitGatewayCidrBlocks",
+ "VpnEcmpSupport",
+ ],
+ )
+ }
+
+ if model.get("Description"):
+ create_params["Description"] = model["Description"]
+
+ if model.get("Tags", []):
+ create_params["TagSpecifications"] = [
+ {"ResourceType": "transit-gateway", "Tags": model["Tags"]}
+ ]
+
+ response = request.aws_client_factory.ec2.create_transit_gateway(**create_params)
+ model["Id"] = response["TransitGateway"]["TransitGatewayId"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[EC2TransitGatewayProperties],
+ ) -> ProgressEvent[EC2TransitGatewayProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeTransitGateways
+ - ec2:DescribeTags
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2TransitGatewayProperties],
+ ) -> ProgressEvent[EC2TransitGatewayProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteTransitGateway
+ - ec2:DeleteTags
+ """
+ model = request.desired_state
+ request.aws_client_factory.ec2.delete_transit_gateway(TransitGatewayId=model["Id"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2TransitGatewayProperties],
+ ) -> ProgressEvent[EC2TransitGatewayProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:ModifyTransitGateway
+ - ec2:DeleteTags
+ - ec2:CreateTags
+ - ec2:ModifyTransitGatewayOptions
+ """
+ raise NotImplementedError
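
The transit gateway settings selected from the model are passed through as the Options dict of CreateTransitGateway. A rough boto3 equivalent of what create() sends, with placeholder endpoint, credentials, and values:

import boto3

ec2 = boto3.client(
    "ec2",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

response = ec2.create_transit_gateway(
    Description="demo transit gateway",
    Options={
        "AmazonSideAsn": 64512,
        "DnsSupport": "enable",
        "DefaultRouteTableAssociation": "enable",
    },
    TagSpecifications=[
        {"ResourceType": "transit-gateway", "Tags": [{"Key": "Name", "Value": "demo"}]}
    ],
)
print(response["TransitGateway"]["TransitGatewayId"])
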
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.schema.json
new file mode 100644
index 0000000000000..afa8ae6ecd09f
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway.schema.json
@@ -0,0 +1,118 @@
+{
+ "typeName": "AWS::EC2::TransitGateway",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-transitgateway",
+ "description": "Resource Type definition for AWS::EC2::TransitGateway",
+ "additionalProperties": false,
+ "properties": {
+ "DefaultRouteTablePropagation": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "AutoAcceptSharedAttachments": {
+ "type": "string"
+ },
+ "DefaultRouteTableAssociation": {
+ "type": "string"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "VpnEcmpSupport": {
+ "type": "string"
+ },
+ "DnsSupport": {
+ "type": "string"
+ },
+ "MulticastSupport": {
+ "type": "string"
+ },
+ "AmazonSideAsn": {
+ "type": "integer",
+ "format": "int64"
+ },
+ "TransitGatewayCidrBlocks": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "AssociationDefaultRouteTableId": {
+ "type": "string"
+ },
+ "PropagationDefaultRouteTableId": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "createOnlyProperties": [
+ "/properties/AmazonSideAsn",
+ "/properties/MulticastSupport"
+ ],
+ "taggable": true,
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateTransitGateway",
+ "ec2:CreateTags"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeTransitGateways",
+ "ec2:DescribeTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteTransitGateway",
+ "ec2:DeleteTags"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:ModifyTransitGateway",
+ "ec2:DeleteTags",
+ "ec2:CreateTags",
+ "ec2:ModifyTransitGatewayOptions"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeTransitGateways",
+ "ec2:DescribeTags"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway_plugin.py
new file mode 100644
index 0000000000000..eac947d512bd5
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgateway_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2TransitGatewayProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::TransitGateway"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_transitgateway import (
+ EC2TransitGatewayProvider,
+ )
+
+ self.factory = EC2TransitGatewayProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.py
new file mode 100644
index 0000000000000..59aac3a6a15d4
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.py
@@ -0,0 +1,131 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2TransitGatewayAttachmentProperties(TypedDict):
+ SubnetIds: Optional[list[str]]
+ TransitGatewayId: Optional[str]
+ VpcId: Optional[str]
+ Id: Optional[str]
+ Options: Optional[dict]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2TransitGatewayAttachmentProvider(ResourceProvider[EC2TransitGatewayAttachmentProperties]):
+ TYPE = "AWS::EC2::TransitGatewayAttachment" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2TransitGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2TransitGatewayAttachmentProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - VpcId
+ - SubnetIds
+ - TransitGatewayId
+
+ Create-only properties:
+ - /properties/TransitGatewayId
+ - /properties/VpcId
+
+ Read-only properties:
+ - /properties/Id
+
+ IAM permissions required:
+ - ec2:CreateTransitGatewayVpcAttachment
+ - ec2:CreateTags
+
+ """
+ model = request.desired_state
+ create_params = util.select_attributes(
+ model, ["SubnetIds", "TransitGatewayId", "VpcId", "Options"]
+ )
+
+ if model.get("Tags", []):
+ create_params["TagSpecifications"] = [
+ {"ResourceType": "transit-gateway-attachment", "Tags": model["Tags"]}
+ ]
+
+ result = request.aws_client_factory.ec2.create_transit_gateway_vpc_attachment(
+ **create_params
+ )
+ model["Id"] = result["TransitGatewayVpcAttachment"]["TransitGatewayAttachmentId"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2TransitGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2TransitGatewayAttachmentProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeTransitGatewayAttachments
+ - ec2:DescribeTransitGatewayVpcAttachments
+ - ec2:DescribeTags
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2TransitGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2TransitGatewayAttachmentProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteTransitGatewayVpcAttachment
+ - ec2:DeleteTags
+ """
+ model = request.desired_state
+ request.aws_client_factory.ec2.delete_transit_gateway_vpc_attachment(
+ TransitGatewayAttachmentId=model["Id"]
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2TransitGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2TransitGatewayAttachmentProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:ModifyTransitGatewayVpcAttachment
+ - ec2:DescribeTransitGatewayVpcAttachments
+ - ec2:DeleteTags
+ - ec2:CreateTags
+ """
+ raise NotImplementedError
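
Like the other providers in this directory, the attachment provider turns the CloudFormation-style Tags list into a single TagSpecifications entry for the EC2 API. A tiny helper capturing that repeated pattern, shown for illustration only (the providers above inline it instead):

from __future__ import annotations


def tag_specifications(resource_type: str, tags: list[dict] | None) -> list[dict]:
    """Wrap CloudFormation tags into the TagSpecifications shape the EC2 API expects."""
    if not tags:
        return []
    return [{"ResourceType": resource_type, "Tags": tags}]


specs = tag_specifications("transit-gateway-attachment", [{"Key": "env", "Value": "dev"}])
# -> [{"ResourceType": "transit-gateway-attachment", "Tags": [{"Key": "env", "Value": "dev"}]}]
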
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.schema.json
new file mode 100644
index 0000000000000..075af98c71c9a
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment.schema.json
@@ -0,0 +1,128 @@
+{
+ "typeName": "AWS::EC2::TransitGatewayAttachment",
+ "description": "Resource Type definition for AWS::EC2::TransitGatewayAttachment",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-transitgateway",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "TransitGatewayId": {
+ "type": "string"
+ },
+ "VpcId": {
+ "type": "string"
+ },
+ "SubnetIds": {
+ "type": "array",
+ "insertionOrder": false,
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Tags": {
+ "type": "array",
+ "insertionOrder": false,
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Options": {
+ "description": "The options for the transit gateway vpc attachment.",
+ "type": "object",
+ "properties": {
+ "DnsSupport": {
+ "description": "Indicates whether to enable DNS Support for Vpc Attachment. Valid Values: enable | disable",
+ "type": "string"
+ },
+ "Ipv6Support": {
+ "description": "Indicates whether to enable Ipv6 Support for Vpc Attachment. Valid Values: enable | disable",
+ "type": "string"
+ },
+ "ApplianceModeSupport": {
+          "description": "Indicates whether to enable Appliance Mode Support for Vpc Attachment. Valid Values: enable | disable",
+ "type": "string"
+ }
+ },
+ "additionalProperties": false
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "VpcId",
+ "SubnetIds",
+ "TransitGatewayId"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": false,
+ "tagProperty": "/properties/Tags"
+ },
+ "createOnlyProperties": [
+ "/properties/TransitGatewayId",
+ "/properties/VpcId"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateTransitGatewayVpcAttachment",
+ "ec2:CreateTags"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeTransitGatewayAttachments",
+ "ec2:DescribeTransitGatewayVpcAttachments",
+ "ec2:DescribeTags"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteTransitGatewayVpcAttachment",
+ "ec2:DeleteTags"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeTransitGatewayAttachments",
+ "ec2:DescribeTransitGatewayVpcAttachments",
+ "ec2:DescribeTags"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:ModifyTransitGatewayVpcAttachment",
+ "ec2:DescribeTransitGatewayVpcAttachments",
+ "ec2:DeleteTags",
+ "ec2:CreateTags"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment_plugin.py
new file mode 100644
index 0000000000000..7b34a535f56e6
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_transitgatewayattachment_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2TransitGatewayAttachmentProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::TransitGatewayAttachment"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_transitgatewayattachment import (
+ EC2TransitGatewayAttachmentProvider,
+ )
+
+ self.factory = EC2TransitGatewayAttachmentProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.py
new file mode 100644
index 0000000000000..3244a72b8b863
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.py
@@ -0,0 +1,242 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+LOG = logging.getLogger(__name__)
+
+
+class EC2VPCProperties(TypedDict):
+ CidrBlock: Optional[str]
+ CidrBlockAssociations: Optional[list[str]]
+ DefaultNetworkAcl: Optional[str]
+ DefaultSecurityGroup: Optional[str]
+ EnableDnsHostnames: Optional[bool]
+ EnableDnsSupport: Optional[bool]
+ InstanceTenancy: Optional[str]
+ Ipv4IpamPoolId: Optional[str]
+ Ipv4NetmaskLength: Optional[int]
+ Ipv6CidrBlocks: Optional[list[str]]
+ Tags: Optional[list[Tag]]
+ VpcId: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+def _get_default_security_group_for_vpc(ec2_client, vpc_id: str) -> str:
+ sgs = ec2_client.describe_security_groups(
+ Filters=[
+ {"Name": "group-name", "Values": ["default"]},
+ {"Name": "vpc-id", "Values": [vpc_id]},
+ ]
+ )["SecurityGroups"]
+ if len(sgs) != 1:
+ raise Exception(f"There should only be one default group for this VPC ({vpc_id=})")
+ return sgs[0]["GroupId"]
+
+
+def _get_default_acl_for_vpc(ec2_client, vpc_id: str) -> str:
+ acls = ec2_client.describe_network_acls(
+ Filters=[
+ {"Name": "default", "Values": ["true"]},
+ {"Name": "vpc-id", "Values": [vpc_id]},
+ ]
+ )["NetworkAcls"]
+ if len(acls) != 1:
+ raise Exception(f"There should only be one default network ACL for this VPC ({vpc_id=})")
+ return acls[0]["NetworkAclId"]
+
+
+def generate_vpc_read_payload(ec2_client, vpc_id: str) -> EC2VPCProperties:
+ vpc = ec2_client.describe_vpcs(VpcIds=[vpc_id])["Vpcs"][0]
+
+ model = EC2VPCProperties(
+ **util.select_attributes(vpc, EC2VPCProvider.SCHEMA["properties"].keys())
+ )
+ model["CidrBlockAssociations"] = [
+ cba["AssociationId"] for cba in vpc["CidrBlockAssociationSet"]
+ ]
+ model["Ipv6CidrBlocks"] = [
+ ipv6_ass["Ipv6CidrBlock"] for ipv6_ass in vpc.get("Ipv6CidrBlockAssociationSet", [])
+ ]
+ model["DefaultNetworkAcl"] = _get_default_acl_for_vpc(ec2_client, model["VpcId"])
+ model["DefaultSecurityGroup"] = _get_default_security_group_for_vpc(ec2_client, model["VpcId"])
+ model["EnableDnsHostnames"] = ec2_client.describe_vpc_attribute(
+ Attribute="enableDnsHostnames", VpcId=vpc_id
+ )["EnableDnsHostnames"]["Value"]
+ model["EnableDnsSupport"] = ec2_client.describe_vpc_attribute(
+ Attribute="enableDnsSupport", VpcId=vpc_id
+ )["EnableDnsSupport"]["Value"]
+
+ return model
+
+
+class EC2VPCProvider(ResourceProvider[EC2VPCProperties]):
+ TYPE = "AWS::EC2::VPC" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2VPCProperties],
+ ) -> ProgressEvent[EC2VPCProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/VpcId
+
+ Create-only properties:
+ - /properties/CidrBlock
+ - /properties/Ipv4IpamPoolId
+ - /properties/Ipv4NetmaskLength
+
+ Read-only properties:
+ - /properties/CidrBlockAssociations
+ - /properties/DefaultNetworkAcl
+ - /properties/DefaultSecurityGroup
+ - /properties/Ipv6CidrBlocks
+ - /properties/VpcId
+
+ IAM permissions required:
+ - ec2:CreateVpc
+ - ec2:DescribeVpcs
+ - ec2:ModifyVpcAttribute
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO: validations
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: defaults
+ # TODO: idempotency
+ params = util.select_attributes(
+ model,
+ ["CidrBlock", "InstanceTenancy"],
+ )
+ if model.get("Tags"):
+ tags = [{"ResourceType": "vpc", "Tags": model.get("Tags")}]
+ params["TagSpecifications"] = tags
+
+ response = ec2.create_vpc(**params)
+
+ request.custom_context[REPEATED_INVOCATION] = True
+ model = generate_vpc_read_payload(ec2, response["Vpc"]["VpcId"])
+
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ response = ec2.describe_vpcs(VpcIds=[model["VpcId"]])["Vpcs"][0]
+ if response["State"] == "pending":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2VPCProperties],
+ ) -> ProgressEvent[EC2VPCProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeVpcs
+ - ec2:DescribeSecurityGroups
+ - ec2:DescribeNetworkAcls
+ - ec2:DescribeVpcAttribute
+ """
+ ec2 = request.aws_client_factory.ec2
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=generate_vpc_read_payload(ec2, request.desired_state["VpcId"]),
+ custom_context=request.custom_context,
+ )
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2VPCProperties],
+ ) -> ProgressEvent[EC2VPCProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteVpc
+ - ec2:DescribeVpcs
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+
+ # remove routes and route tables first
+ resp = ec2.describe_route_tables(
+ Filters=[
+ {"Name": "vpc-id", "Values": [model["VpcId"]]},
+ {"Name": "association.main", "Values": ["false"]},
+ ]
+ )
+ for rt in resp["RouteTables"]:
+ for assoc in rt.get("Associations", []):
+ # skipping Main association (upstream moto includes default association that cannot be deleted)
+ if assoc.get("Main"):
+ continue
+ ec2.disassociate_route_table(AssociationId=assoc["RouteTableAssociationId"])
+ ec2.delete_route_table(RouteTableId=rt["RouteTableId"])
+
+ # TODO security groups, gateways and other attached resources need to be deleted as well
+ ec2.delete_vpc(VpcId=model["VpcId"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[EC2VPCProperties],
+ ) -> ProgressEvent[EC2VPCProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:CreateTags
+ - ec2:ModifyVpcAttribute
+ - ec2:DeleteTags
+ - ec2:ModifyVpcTenancy
+ """
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[EC2VPCProperties],
+ ) -> ProgressEvent[EC2VPCProperties]:
+ resources = request.aws_client_factory.ec2.describe_vpcs()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ EC2VPCProperties(VpcId=resource["VpcId"]) for resource in resources["Vpcs"]
+ ],
+ )
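
DescribeVpcs does not return the DNS flags, so generate_vpc_read_payload issues one DescribeVpcAttribute call per flag. A short boto3 sketch of those lookups, again with a placeholder endpoint and credentials for a local setup:

import boto3

ec2 = boto3.client(
    "ec2",
    endpoint_url="http://localhost:4566",
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)

vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"]

dns_hostnames = ec2.describe_vpc_attribute(Attribute="enableDnsHostnames", VpcId=vpc_id)
dns_support = ec2.describe_vpc_attribute(Attribute="enableDnsSupport", VpcId=vpc_id)
print(dns_hostnames["EnableDnsHostnames"]["Value"], dns_support["EnableDnsSupport"]["Value"])
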
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.schema.json
new file mode 100644
index 0000000000000..0f8838c52d008
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc.schema.json
@@ -0,0 +1,155 @@
+{
+ "typeName": "AWS::EC2::VPC",
+ "description": "Resource Type definition for AWS::EC2::VPC",
+ "additionalProperties": false,
+ "properties": {
+ "VpcId": {
+ "type": "string",
+ "description": "The Id for the model."
+ },
+ "CidrBlock": {
+ "type": "string",
+ "description": "The primary IPv4 CIDR block for the VPC."
+ },
+ "CidrBlockAssociations": {
+ "type": "array",
+ "description": "A list of IPv4 CIDR block association IDs for the VPC.",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "DefaultNetworkAcl": {
+ "type": "string",
+ "insertionOrder": false,
+ "description": "The default network ACL ID that is associated with the VPC."
+ },
+ "DefaultSecurityGroup": {
+ "type": "string",
+ "insertionOrder": false,
+ "description": "The default security group ID that is associated with the VPC."
+ },
+ "Ipv6CidrBlocks": {
+ "type": "array",
+ "description": "A list of IPv6 CIDR blocks that are associated with the VPC.",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "EnableDnsHostnames": {
+ "type": "boolean",
+ "description": "Indicates whether the instances launched in the VPC get DNS hostnames. If enabled, instances in the VPC get DNS hostnames; otherwise, they do not. Disabled by default for nondefault VPCs."
+ },
+ "EnableDnsSupport": {
+ "type": "boolean",
+ "description": "Indicates whether the DNS resolution is supported for the VPC. If enabled, queries to the Amazon provided DNS server at the 169.254.169.253 IP address, or the reserved IP address at the base of the VPC network range \"plus two\" succeed. If disabled, the Amazon provided DNS service in the VPC that resolves public DNS hostnames to IP addresses is not enabled. Enabled by default."
+ },
+ "InstanceTenancy": {
+ "type": "string",
+ "description": "The allowed tenancy of instances launched into the VPC.\n\n\"default\": An instance launched into the VPC runs on shared hardware by default, unless you explicitly specify a different tenancy during instance launch.\n\n\"dedicated\": An instance launched into the VPC is a Dedicated Instance by default, unless you explicitly specify a tenancy of host during instance launch. You cannot specify a tenancy of default during instance launch.\n\nUpdating InstanceTenancy requires no replacement only if you are updating its value from \"dedicated\" to \"default\". Updating InstanceTenancy from \"default\" to \"dedicated\" requires replacement."
+ },
+ "Ipv4IpamPoolId": {
+ "type": "string",
+ "description": "The ID of an IPv4 IPAM pool you want to use for allocating this VPC's CIDR"
+ },
+ "Ipv4NetmaskLength": {
+ "type": "integer",
+ "description": "The netmask length of the IPv4 CIDR you want to allocate to this VPC from an Amazon VPC IP Address Manager (IPAM) pool"
+ },
+ "Tags": {
+ "type": "array",
+ "description": "The tags for the VPC.",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": true,
+ "tagProperty": "/properties/Tags"
+ },
+ "createOnlyProperties": [
+ "/properties/CidrBlock",
+ "/properties/Ipv4IpamPoolId",
+ "/properties/Ipv4NetmaskLength"
+ ],
+ "conditionalCreateOnlyProperties": [
+ "/properties/InstanceTenancy"
+ ],
+ "readOnlyProperties": [
+ "/properties/CidrBlockAssociations",
+ "/properties/DefaultNetworkAcl",
+ "/properties/DefaultSecurityGroup",
+ "/properties/Ipv6CidrBlocks",
+ "/properties/VpcId"
+ ],
+ "primaryIdentifier": [
+ "/properties/VpcId"
+ ],
+ "writeOnlyProperties": [
+ "/properties/Ipv4IpamPoolId",
+ "/properties/Ipv4NetmaskLength"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateVpc",
+ "ec2:DescribeVpcs",
+ "ec2:ModifyVpcAttribute"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeVpcs",
+ "ec2:DescribeSecurityGroups",
+ "ec2:DescribeNetworkAcls",
+ "ec2:DescribeVpcAttribute"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:CreateTags",
+ "ec2:ModifyVpcAttribute",
+ "ec2:DeleteTags",
+ "ec2:ModifyVpcTenancy"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteVpc",
+ "ec2:DescribeVpcs"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeVpcs"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc_plugin.py
new file mode 100644
index 0000000000000..3f4aea38386f0
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpc_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2VPCProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::VPC"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_vpc import EC2VPCProvider
+
+ self.factory = EC2VPCProvider
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.py
new file mode 100644
index 0000000000000..420efcb8029ee
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.py
@@ -0,0 +1,180 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2VPCEndpointProperties(TypedDict):
+ ServiceName: Optional[str]
+ VpcId: Optional[str]
+ CreationTimestamp: Optional[str]
+ DnsEntries: Optional[list[str]]
+ Id: Optional[str]
+ NetworkInterfaceIds: Optional[list[str]]
+ PolicyDocument: Optional[str | dict]
+ PrivateDnsEnabled: Optional[bool]
+ RouteTableIds: Optional[list[str]]
+ SecurityGroupIds: Optional[list[str]]
+ SubnetIds: Optional[list[str]]
+ VpcEndpointType: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2VPCEndpointProvider(ResourceProvider[EC2VPCEndpointProperties]):
+ TYPE = "AWS::EC2::VPCEndpoint" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2VPCEndpointProperties],
+ ) -> ProgressEvent[EC2VPCEndpointProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - VpcId
+ - ServiceName
+
+ Create-only properties:
+ - /properties/ServiceName
+ - /properties/VpcEndpointType
+ - /properties/VpcId
+
+ Read-only properties:
+ - /properties/NetworkInterfaceIds
+ - /properties/CreationTimestamp
+ - /properties/DnsEntries
+ - /properties/Id
+
+ IAM permissions required:
+ - ec2:CreateVpcEndpoint
+ - ec2:DescribeVpcEndpoints
+
+ """
+ model = request.desired_state
+ create_params = util.select_attributes(
+ model,
+ [
+                "PolicyDocument",
+                "PrivateDnsEnabled",
+                "RouteTableIds",
+ "SecurityGroupIds",
+ "ServiceName",
+ "SubnetIds",
+ "VpcEndpointType",
+ "VpcId",
+ ],
+ )
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ response = request.aws_client_factory.ec2.create_vpc_endpoint(**create_params)
+ model["Id"] = response["VpcEndpoint"]["VpcEndpointId"]
+ model["DnsEntries"] = response["VpcEndpoint"]["DnsEntries"]
+ model["CreationTimestamp"] = response["VpcEndpoint"]["CreationTimestamp"]
+ model["NetworkInterfaceIds"] = response["VpcEndpoint"]["NetworkInterfaceIds"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ response = request.aws_client_factory.ec2.describe_vpc_endpoints(
+ VpcEndpointIds=[model["Id"]]
+ )
+ if not response["VpcEndpoints"]:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ custom_context=request.custom_context,
+ message="Resource not found after creation",
+ )
+
+        # the API documents capitalized states, but the service returns lowercase values
+        state = response["VpcEndpoints"][0]["State"].lower()
+ match state:
+ case "available":
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+ case "pending":
+ return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+ case "pendingacceptance":
+ return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+ case _:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ message=f"Invalid state '{state}' for resource",
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2VPCEndpointProperties],
+ ) -> ProgressEvent[EC2VPCEndpointProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ec2:DescribeVpcEndpoints
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2VPCEndpointProperties],
+ ) -> ProgressEvent[EC2VPCEndpointProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ec2:DeleteVpcEndpoints
+ - ec2:DescribeVpcEndpoints
+ """
+ model = request.previous_state
+ response = request.aws_client_factory.ec2.describe_vpc_endpoints(
+ VpcEndpointIds=[model["Id"]]
+ )
+
+ if not response["VpcEndpoints"]:
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model=model,
+ message="Resource not found for deletion",
+ )
+
+ state = response["VpcEndpoints"][0]["State"].lower()
+ match state:
+ case "deleted":
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+ case "deleting":
+ return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+ case _:
+ request.aws_client_factory.ec2.delete_vpc_endpoints(VpcEndpointIds=[model["Id"]])
+ return ProgressEvent(status=OperationStatus.IN_PROGRESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[EC2VPCEndpointProperties],
+ ) -> ProgressEvent[EC2VPCEndpointProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ec2:ModifyVpcEndpoint
+ - ec2:DescribeVpcEndpoints
+ """
+ raise NotImplementedError
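The create handler above follows the framework's repeated-invocation contract: the first call issues CreateVpcEndpoint and returns IN_PROGRESS with REPEATED_INVOCATION stored in the custom context; every follow-up call only polls DescribeVpcEndpoints until the endpoint reaches a terminal state. A minimal sketch of a driver loop that honours this contract (hypothetical names, not part of this change):

    import time

    def drive_to_completion(provider, request, poll_interval=2.0, max_attempts=30):
        # Re-invoke the handler until it reports a terminal status, feeding the
        # returned custom_context and resource model back into the next call.
        for _ in range(max_attempts):
            event = provider.create(request)
            if event.status in (OperationStatus.SUCCESS, OperationStatus.FAILED):
                return event
            request.custom_context = event.custom_context or request.custom_context
            request.desired_state = event.resource_model
            time.sleep(poll_interval)
        raise TimeoutError("VPC endpoint did not stabilize in time")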
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.schema.json
new file mode 100644
index 0000000000000..c8dcc84644d4c
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint.schema.json
@@ -0,0 +1,140 @@
+{
+ "typeName": "AWS::EC2::VPCEndpoint",
+ "description": "Resource Type definition for AWS::EC2::VPCEndpoint",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "CreationTimestamp": {
+ "type": "string"
+ },
+ "DnsEntries": {
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "NetworkInterfaceIds": {
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "PolicyDocument": {
+ "type": [
+ "string",
+ "object"
+ ],
+ "description": "A policy to attach to the endpoint that controls access to the service."
+ },
+ "PrivateDnsEnabled": {
+ "type": "boolean",
+ "description": "Indicate whether to associate a private hosted zone with the specified VPC."
+ },
+ "RouteTableIds": {
+ "type": "array",
+ "description": "One or more route table IDs.",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "SecurityGroupIds": {
+ "type": "array",
+ "description": "The ID of one or more security groups to associate with the endpoint network interface.",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "ServiceName": {
+ "type": "string",
+ "description": "The service name."
+ },
+ "SubnetIds": {
+ "type": "array",
+ "description": "The ID of one or more subnets in which to create an endpoint network interface.",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "VpcEndpointType": {
+ "type": "string",
+ "enum": [
+ "Interface",
+ "Gateway",
+ "GatewayLoadBalancer"
+ ]
+ },
+ "VpcId": {
+ "type": "string",
+ "description": "The ID of the VPC in which the endpoint will be used."
+ }
+ },
+ "required": [
+ "VpcId",
+ "ServiceName"
+ ],
+ "readOnlyProperties": [
+ "/properties/NetworkInterfaceIds",
+ "/properties/CreationTimestamp",
+ "/properties/DnsEntries",
+ "/properties/Id"
+ ],
+ "createOnlyProperties": [
+ "/properties/ServiceName",
+ "/properties/VpcEndpointType",
+ "/properties/VpcId"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "tagging": {
+ "taggable": false,
+ "tagOnCreate": false,
+ "tagUpdatable": false,
+ "cloudFormationSystemTags": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ec2:CreateVpcEndpoint",
+ "ec2:DescribeVpcEndpoints"
+ ],
+ "timeoutInMinutes": 210
+ },
+ "read": {
+ "permissions": [
+ "ec2:DescribeVpcEndpoints"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ec2:ModifyVpcEndpoint",
+ "ec2:DescribeVpcEndpoints"
+ ],
+ "timeoutInMinutes": 210
+ },
+ "delete": {
+ "permissions": [
+ "ec2:DeleteVpcEndpoints",
+ "ec2:DescribeVpcEndpoints"
+ ],
+ "timeoutInMinutes": 210
+ },
+ "list": {
+ "permissions": [
+ "ec2:DescribeVpcEndpoints"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint_plugin.py
new file mode 100644
index 0000000000000..e0e1d228a95de
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcendpoint_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2VPCEndpointProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::VPCEndpoint"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_vpcendpoint import (
+ EC2VPCEndpointProvider,
+ )
+
+ self.factory = EC2VPCEndpointProvider
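The plugin keeps the provider import out of the module import path and only performs it in load(). A hedged sketch of how such a plugin is consumed, assuming the plugin manager instantiates the factory after loading (the consuming side shown here is illustrative):

    plugin = EC2VPCEndpointProviderPlugin()
    plugin.load()                 # the provider module is imported only now
    provider = plugin.factory()   # instantiate the ResourceProvider
    assert provider.TYPE == "AWS::EC2::VPCEndpoint"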
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.py
new file mode 100644
index 0000000000000..8f4656e317b7f
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.py
@@ -0,0 +1,116 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EC2VPCGatewayAttachmentProperties(TypedDict):
+ VpcId: Optional[str]
+ Id: Optional[str]
+ InternetGatewayId: Optional[str]
+ VpnGatewayId: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EC2VPCGatewayAttachmentProvider(ResourceProvider[EC2VPCGatewayAttachmentProperties]):
+ TYPE = "AWS::EC2::VPCGatewayAttachment" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EC2VPCGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2VPCGatewayAttachmentProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - VpcId
+
+
+
+ Read-only properties:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO: validations
+ if model.get("InternetGatewayId"):
+ ec2.attach_internet_gateway(
+ InternetGatewayId=model["InternetGatewayId"], VpcId=model["VpcId"]
+ )
+ else:
+ ec2.attach_vpn_gateway(VpnGatewayId=model["VpnGatewayId"], VpcId=model["VpcId"])
+
+ # TODO: idempotency
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EC2VPCGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2VPCGatewayAttachmentProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EC2VPCGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2VPCGatewayAttachmentProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ ec2 = request.aws_client_factory.ec2
+ # TODO: validations
+ try:
+ if model.get("InternetGatewayId"):
+ ec2.detach_internet_gateway(
+ InternetGatewayId=model["InternetGatewayId"], VpcId=model["VpcId"]
+ )
+ else:
+ ec2.detach_vpn_gateway(VpnGatewayId=model["VpnGatewayId"], VpcId=model["VpcId"])
+ except ec2.exceptions.ClientError:
+ pass
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EC2VPCGatewayAttachmentProperties],
+ ) -> ProgressEvent[EC2VPCGatewayAttachmentProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
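Both create and delete above branch purely on which gateway id is present in the desired state. The same dispatch could be expressed once, as in this sketch of a hypothetical helper (not part of this change):

    def _gateway_call(ec2, model: dict, attach: bool):
        # Pick the internet-gateway or VPN-gateway API depending on which id is set,
        # mirroring the branching in EC2VPCGatewayAttachmentProvider.
        if model.get("InternetGatewayId"):
            fn = ec2.attach_internet_gateway if attach else ec2.detach_internet_gateway
            return fn(InternetGatewayId=model["InternetGatewayId"], VpcId=model["VpcId"])
        fn = ec2.attach_vpn_gateway if attach else ec2.detach_vpn_gateway
        return fn(VpnGatewayId=model["VpnGatewayId"], VpcId=model["VpcId"])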
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.schema.json b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.schema.json
new file mode 100644
index 0000000000000..856548db1f173
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment.schema.json
@@ -0,0 +1,28 @@
+{
+ "typeName": "AWS::EC2::VPCGatewayAttachment",
+ "description": "Resource Type definition for AWS::EC2::VPCGatewayAttachment",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "InternetGatewayId": {
+ "type": "string"
+ },
+ "VpcId": {
+ "type": "string"
+ },
+ "VpnGatewayId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "VpcId"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment_plugin.py b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment_plugin.py
new file mode 100644
index 0000000000000..f210fa0ff8c1d
--- /dev/null
+++ b/localstack-core/localstack/services/ec2/resource_providers/aws_ec2_vpcgatewayattachment_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EC2VPCGatewayAttachmentProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::EC2::VPCGatewayAttachment"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ec2.resource_providers.aws_ec2_vpcgatewayattachment import (
+ EC2VPCGatewayAttachmentProvider,
+ )
+
+ self.factory = EC2VPCGatewayAttachmentProvider
diff --git a/localstack-core/localstack/services/ecr/__init__.py b/localstack-core/localstack/services/ecr/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/ecr/resource_providers/__init__.py b/localstack-core/localstack/services/ecr/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.py b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.py
new file mode 100644
index 0000000000000..a42735467d146
--- /dev/null
+++ b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.py
@@ -0,0 +1,169 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.constants import AWS_REGION_US_EAST_1, DEFAULT_AWS_ACCOUNT_ID
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.aws import arns
+
+LOG = logging.getLogger(__name__)
+
+# simple mock state
+default_repos_per_stack = {}
+
+
+class ECRRepositoryProperties(TypedDict):
+ Arn: Optional[str]
+ EncryptionConfiguration: Optional[EncryptionConfiguration]
+ ImageScanningConfiguration: Optional[ImageScanningConfiguration]
+ ImageTagMutability: Optional[str]
+ LifecyclePolicy: Optional[LifecyclePolicy]
+ RepositoryName: Optional[str]
+ RepositoryPolicyText: Optional[dict | str]
+ RepositoryUri: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class LifecyclePolicy(TypedDict):
+ LifecyclePolicyText: Optional[str]
+ RegistryId: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class ImageScanningConfiguration(TypedDict):
+ ScanOnPush: Optional[bool]
+
+
+class EncryptionConfiguration(TypedDict):
+ EncryptionType: Optional[str]
+ KmsKey: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class ECRRepositoryProvider(ResourceProvider[ECRRepositoryProperties]):
+ TYPE = "AWS::ECR::Repository" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[ECRRepositoryProperties],
+ ) -> ProgressEvent[ECRRepositoryProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RepositoryName
+
+ Create-only properties:
+ - /properties/RepositoryName
+ - /properties/EncryptionConfiguration
+ - /properties/EncryptionConfiguration/EncryptionType
+ - /properties/EncryptionConfiguration/KmsKey
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/RepositoryUri
+
+ IAM permissions required:
+ - ecr:CreateRepository
+ - ecr:PutLifecyclePolicy
+ - ecr:SetRepositoryPolicy
+ - ecr:TagResource
+ - kms:DescribeKey
+ - kms:CreateGrant
+ - kms:RetireGrant
+
+ """
+ model = request.desired_state
+
+ default_repos_per_stack[request.stack_name] = model["RepositoryName"]
+ LOG.warning(
+ "Creating a Mock ECR Repository for CloudFormation. This is only intended to be used for allowing a successful CDK bootstrap and does not provision any underlying ECR repository."
+ )
+ model.update(
+ {
+ "Arn": arns.ecr_repository_arn(
+ model["RepositoryName"], DEFAULT_AWS_ACCOUNT_ID, AWS_REGION_US_EAST_1
+ ),
+ "RepositoryUri": "http://localhost:4566",
+ "ImageTagMutability": "MUTABLE",
+ "ImageScanningConfiguration": {"scanOnPush": True},
+ }
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[ECRRepositoryProperties],
+ ) -> ProgressEvent[ECRRepositoryProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - ecr:DescribeRepositories
+ - ecr:GetLifecyclePolicy
+ - ecr:GetRepositoryPolicy
+ - ecr:ListTagsForResource
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[ECRRepositoryProperties],
+ ) -> ProgressEvent[ECRRepositoryProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - ecr:DeleteRepository
+ - kms:RetireGrant
+ """
+ if default_repos_per_stack.get(request.stack_name):
+ del default_repos_per_stack[request.stack_name]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=request.desired_state,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[ECRRepositoryProperties],
+ ) -> ProgressEvent[ECRRepositoryProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - ecr:PutLifecyclePolicy
+ - ecr:SetRepositoryPolicy
+ - ecr:TagResource
+ - ecr:UntagResource
+ - ecr:DeleteLifecyclePolicy
+ - ecr:DeleteRepositoryPolicy
+ - ecr:PutImageScanningConfiguration
+ - ecr:PutImageTagMutability
+ - kms:DescribeKey
+ - kms:CreateGrant
+ - kms:RetireGrant
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.schema.json b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.schema.json
new file mode 100644
index 0000000000000..ef4f7c01e3a74
--- /dev/null
+++ b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository.schema.json
@@ -0,0 +1,210 @@
+{
+ "typeName": "AWS::ECR::Repository",
+ "description": "The AWS::ECR::Repository resource specifies an Amazon Elastic Container Registry (Amazon ECR) repository, where users can push and pull Docker images. For more information, see https://docs.aws.amazon.com/AmazonECR/latest/userguide/Repositories.html",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-ecr.git",
+ "definitions": {
+ "LifecyclePolicy": {
+ "type": "object",
+ "description": "The LifecyclePolicy property type specifies a lifecycle policy. For information about lifecycle policy syntax, see https://docs.aws.amazon.com/AmazonECR/latest/userguide/LifecyclePolicies.html",
+ "properties": {
+ "LifecyclePolicyText": {
+ "$ref": "#/definitions/LifecyclePolicyText"
+ },
+ "RegistryId": {
+ "$ref": "#/definitions/RegistryId"
+ }
+ },
+ "additionalProperties": false
+ },
+ "LifecyclePolicyText": {
+ "type": "string",
+ "description": "The JSON repository policy text to apply to the repository.",
+ "minLength": 100,
+ "maxLength": 30720
+ },
+ "RegistryId": {
+ "type": "string",
+ "description": "The AWS account ID associated with the registry that contains the repository. If you do not specify a registry, the default registry is assumed. ",
+ "minLength": 12,
+ "maxLength": 12,
+ "pattern": "^[0-9]{12}$"
+ },
+ "Tag": {
+ "description": "A key-value pair to associate with a resource.",
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "description": "The key name of the tag. You can specify a value that is 1 to 127 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -. ",
+ "minLength": 1,
+ "maxLength": 127
+ },
+ "Value": {
+ "type": "string",
+ "description": "The value for the tag. You can specify a value that is 1 to 255 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -. ",
+ "minLength": 1,
+ "maxLength": 255
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ],
+ "additionalProperties": false
+ },
+ "ImageScanningConfiguration": {
+ "type": "object",
+ "description": "The image scanning configuration for the repository. This setting determines whether images are scanned for known vulnerabilities after being pushed to the repository.",
+ "properties": {
+ "ScanOnPush": {
+ "$ref": "#/definitions/ScanOnPush"
+ }
+ },
+ "additionalProperties": false
+ },
+ "ScanOnPush": {
+ "type": "boolean",
+ "description": "The setting that determines whether images are scanned after being pushed to a repository."
+ },
+ "EncryptionConfiguration": {
+ "type": "object",
+ "description": "The encryption configuration for the repository. This determines how the contents of your repository are encrypted at rest.\n\nBy default, when no encryption configuration is set or the AES256 encryption type is used, Amazon ECR uses server-side encryption with Amazon S3-managed encryption keys which encrypts your data at rest using an AES-256 encryption algorithm. This does not require any action on your part.\n\nFor more information, see https://docs.aws.amazon.com/AmazonECR/latest/userguide/encryption-at-rest.html",
+ "properties": {
+ "EncryptionType": {
+ "$ref": "#/definitions/EncryptionType"
+ },
+ "KmsKey": {
+ "$ref": "#/definitions/KmsKey"
+ }
+ },
+ "required": [
+ "EncryptionType"
+ ],
+ "additionalProperties": false
+ },
+ "EncryptionType": {
+ "type": "string",
+ "description": "The encryption type to use.",
+ "enum": [
+ "AES256",
+ "KMS"
+ ]
+ },
+ "KmsKey": {
+ "type": "string",
+ "description": "If you use the KMS encryption type, specify the CMK to use for encryption. The alias, key ID, or full ARN of the CMK can be specified. The key must exist in the same Region as the repository. If no key is specified, the default AWS managed CMK for Amazon ECR will be used.",
+ "minLength": 1,
+ "maxLength": 2048
+ }
+ },
+ "properties": {
+ "LifecyclePolicy": {
+ "$ref": "#/definitions/LifecyclePolicy"
+ },
+ "RepositoryName": {
+ "type": "string",
+ "description": "The name to use for the repository. The repository name may be specified on its own (such as nginx-web-app) or it can be prepended with a namespace to group the repository into a category (such as project-a/nginx-web-app). If you don't specify a name, AWS CloudFormation generates a unique physical ID and uses that ID for the repository name. For more information, see https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-name.html.",
+ "minLength": 2,
+ "maxLength": 256,
+ "pattern": "^(?=.{2,256}$)((?:[a-z0-9]+(?:[._-][a-z0-9]+)*/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)$"
+ },
+ "RepositoryPolicyText": {
+ "type": [
+ "object",
+ "string"
+ ],
+ "description": "The JSON repository policy text to apply to the repository. For more information, see https://docs.aws.amazon.com/AmazonECR/latest/userguide/RepositoryPolicyExamples.html in the Amazon Elastic Container Registry User Guide. "
+ },
+ "Tags": {
+ "type": "array",
+ "maxItems": 50,
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "description": "An array of key-value pairs to apply to this resource.",
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "RepositoryUri": {
+ "type": "string"
+ },
+ "ImageTagMutability": {
+ "type": "string",
+ "description": "The image tag mutability setting for the repository.",
+ "enum": [
+ "MUTABLE",
+ "IMMUTABLE"
+ ]
+ },
+ "ImageScanningConfiguration": {
+ "$ref": "#/definitions/ImageScanningConfiguration"
+ },
+ "EncryptionConfiguration": {
+ "$ref": "#/definitions/EncryptionConfiguration"
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/RepositoryName",
+ "/properties/EncryptionConfiguration",
+ "/properties/EncryptionConfiguration/EncryptionType",
+ "/properties/EncryptionConfiguration/KmsKey"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/RepositoryUri"
+ ],
+ "primaryIdentifier": [
+ "/properties/RepositoryName"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "ecr:CreateRepository",
+ "ecr:PutLifecyclePolicy",
+ "ecr:SetRepositoryPolicy",
+ "ecr:TagResource",
+ "kms:DescribeKey",
+ "kms:CreateGrant",
+ "kms:RetireGrant"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "ecr:DescribeRepositories",
+ "ecr:GetLifecyclePolicy",
+ "ecr:GetRepositoryPolicy",
+ "ecr:ListTagsForResource"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "ecr:PutLifecyclePolicy",
+ "ecr:SetRepositoryPolicy",
+ "ecr:TagResource",
+ "ecr:UntagResource",
+ "ecr:DeleteLifecyclePolicy",
+ "ecr:DeleteRepositoryPolicy",
+ "ecr:PutImageScanningConfiguration",
+ "ecr:PutImageTagMutability",
+ "kms:DescribeKey",
+ "kms:CreateGrant",
+ "kms:RetireGrant"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "ecr:DeleteRepository",
+ "kms:RetireGrant"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "ecr:DescribeRepositories"
+ ]
+ }
+ },
+ "additionalProperties": false
+}
diff --git a/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository_plugin.py b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository_plugin.py
new file mode 100644
index 0000000000000..7d7ba440a668d
--- /dev/null
+++ b/localstack-core/localstack/services/ecr/resource_providers/aws_ecr_repository_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class ECRRepositoryProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::ECR::Repository"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.ecr.resource_providers.aws_ecr_repository import (
+ ECRRepositoryProvider,
+ )
+
+ self.factory = ECRRepositoryProvider
diff --git a/localstack-core/localstack/services/edge.py b/localstack-core/localstack/services/edge.py
new file mode 100644
index 0000000000000..5c3ede66b65e5
--- /dev/null
+++ b/localstack-core/localstack/services/edge.py
@@ -0,0 +1,213 @@
+import argparse
+import logging
+import shlex
+import subprocess
+import sys
+from typing import List, Optional, TypeVar
+
+from localstack import config, constants
+from localstack.config import HostAndPort
+from localstack.constants import (
+ LOCALSTACK_ROOT_FOLDER,
+)
+from localstack.http import Router
+from localstack.http.dispatcher import Handler, handler_dispatcher
+from localstack.http.router import GreedyPathConverter
+from localstack.utils.collections import split_list_by
+from localstack.utils.net import get_free_tcp_port
+from localstack.utils.run import is_root, run
+from localstack.utils.server.tcp_proxy import TCPProxy
+from localstack.utils.threads import start_thread
+
+T = TypeVar("T")
+
+LOG = logging.getLogger(__name__)
+
+
+ROUTER: Router[Handler] = Router(
+ dispatcher=handler_dispatcher(), converters={"greedy_path": GreedyPathConverter}
+)
+"""This special Router is part of the edge proxy. Use the router to inject custom handlers that are handled before
+the actual AWS service call is made."""
+
+
+def do_start_edge(
+ listen: HostAndPort | List[HostAndPort], use_ssl: bool, asynchronous: bool = False
+):
+ from localstack.aws.serving.edge import serve_gateway
+
+ return serve_gateway(listen, use_ssl, asynchronous)
+
+
+def can_use_sudo():
+ try:
+ run("sudo -n -v", print_error=False)
+ return True
+ except Exception:
+ return False
+
+
+def ensure_can_use_sudo():
+ if not is_root() and not can_use_sudo():
+ if not sys.stdin.isatty():
+ raise IOError("cannot get sudo password from non-tty input")
+ print("Please enter your sudo password (required to configure local network):")
+ run("sudo -v", stdin=True)
+
+
+def start_component(
+ component: str, listen_str: str | None = None, target_address: str | None = None
+):
+ if component == "edge":
+ return start_edge(listen_str=listen_str)
+ if component == "proxy":
+ if target_address is None:
+ raise ValueError("no target address specified")
+
+ return start_proxy(
+ listen_str=listen_str,
+ target_address=HostAndPort.parse(
+ target_address,
+ default_host=config.default_ip,
+ default_port=constants.DEFAULT_PORT_EDGE,
+ ),
+ )
+ raise Exception("Unexpected component name '%s' received during start up" % component)
+
+
+def start_proxy(
+ listen_str: str, target_address: HostAndPort, asynchronous: bool = False
+) -> TCPProxy:
+ """
+ Starts a TCP proxy to perform a low-level forwarding of incoming requests.
+
+ :param listen_str: address to listen on
+ :param target_address: target address to proxy requests to
+ :param asynchronous: False if the function should join the proxy thread and block until it terminates.
+ :return: created thread executing the proxy
+ """
+ listen_hosts = parse_gateway_listen(
+ listen_str,
+ default_host=constants.LOCALHOST_IP,
+ default_port=constants.DEFAULT_PORT_EDGE,
+ )
+ listen = listen_hosts[0]
+ return do_start_tcp_proxy(listen, target_address, asynchronous)
+
+
+def do_start_tcp_proxy(
+ listen: HostAndPort, target_address: HostAndPort, asynchronous: bool = False
+) -> TCPProxy:
+ src = str(listen)
+ dst = str(target_address)
+
+ LOG.debug("Starting Local TCP Proxy: %s -> %s", src, dst)
+ proxy = TCPProxy(
+ target_address=target_address.host,
+ target_port=target_address.port,
+ host=listen.host,
+ port=listen.port,
+ )
+ proxy.start()
+ if not asynchronous:
+ proxy.join()
+ return proxy
+
+
+def start_edge(listen_str: str, use_ssl: bool = True, asynchronous: bool = False):
+ if listen_str:
+ listen = parse_gateway_listen(
+ listen_str, default_host=config.default_ip, default_port=constants.DEFAULT_PORT_EDGE
+ )
+ else:
+ listen = config.GATEWAY_LISTEN
+
+ if len(listen) == 0:
+ raise ValueError("no listen addresses provided")
+
+ # separate privileged and unprivileged addresses
+ unprivileged, privileged = split_list_by(listen, lambda addr: addr.is_unprivileged() or False)
+
+ # if we are root, we can directly bind to privileged ports as well
+ if is_root():
+ unprivileged = unprivileged + privileged
+ privileged = []
+
+ # make sure we actually start the gateway server
+ if not unprivileged:
+ unprivileged = parse_gateway_listen(
+ f":{get_free_tcp_port()}",
+ default_host=config.default_ip,
+ default_port=constants.DEFAULT_PORT_EDGE,
+ )
+
+ # bind the gateway server to unprivileged addresses
+ edge_thread = do_start_edge(unprivileged, use_ssl=use_ssl, asynchronous=True)
+
+ # start TCP proxies for the remaining addresses
+ proxy_destination = unprivileged[0]
+ for address in privileged:
+ # escalate to root
+ args = [
+ "proxy",
+ "--gateway-listen",
+ str(address),
+ "--target-address",
+ str(proxy_destination),
+ ]
+ run_module_as_sudo(
+ module="localstack.services.edge",
+ arguments=args,
+ asynchronous=True,
+ )
+
+ if edge_thread is not None:
+ edge_thread.join()
+
+
+def run_module_as_sudo(
+ module: str, arguments: Optional[List[str]] = None, asynchronous=False, env_vars=None
+):
+ # prepare environment
+ env_vars = env_vars or {}
+ env_vars["PYTHONPATH"] = f".:{LOCALSTACK_ROOT_FOLDER}"
+
+ # start the process as sudo
+ python_cmd = sys.executable
+ cmd = ["sudo", "-n", "--preserve-env", python_cmd, "-m", module]
+ arguments = arguments or []
+ shell_cmd = shlex.join(cmd + arguments)
+
+ # make sure we can run sudo commands
+ try:
+ ensure_can_use_sudo()
+ except Exception as e:
+ LOG.error("cannot run command as root (%s): %s ", str(e), shell_cmd)
+ return
+
+ def run_command(*_):
+ run(shell_cmd, outfile=subprocess.PIPE, print_error=False, env_vars=env_vars)
+
+ LOG.debug("Running command as sudo: %s", shell_cmd)
+ result = (
+ start_thread(run_command, quiet=True, name="sudo-edge") if asynchronous else run_command()
+ )
+ return result
+
+
+def parse_gateway_listen(listen: str, default_host: str, default_port: int) -> List[HostAndPort]:
+ addresses = []
+ for address in listen.split(","):
+ addresses.append(HostAndPort.parse(address, default_host, default_port))
+ return addresses
+
+
+if __name__ == "__main__":
+ logging.basicConfig()
+ parser = argparse.ArgumentParser()
+ parser.add_argument("component")
+ parser.add_argument("-l", "--gateway-listen", required=False, type=str)
+ parser.add_argument("-t", "--target-address", required=False, type=str)
+ args = parser.parse_args()
+
+ start_component(args.component, args.gateway_listen, args.target_address)
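start_edge() splits the configured listen addresses into unprivileged ones (bound directly by the gateway) and privileged ones (served via sudo-spawned TCP proxies that forward to the first unprivileged address). A quick illustration of how listen strings are parsed, assuming the LocalStack defaults of 0.0.0.0 and port 4566:

    addresses = parse_gateway_listen("0.0.0.0:4566,:443", default_host="0.0.0.0", default_port=4566)
    # -> two HostAndPort values: 0.0.0.0:4566 and 0.0.0.0:443
    # With these addresses, the gateway binds to :4566 and a root-owned TCP proxy
    # is started for :443, forwarding to :4566.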
diff --git a/localstack-core/localstack/services/es/__init__.py b/localstack-core/localstack/services/es/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/es/plugins.py b/localstack-core/localstack/services/es/plugins.py
new file mode 100644
index 0000000000000..d14c5f15bd72f
--- /dev/null
+++ b/localstack-core/localstack/services/es/plugins.py
@@ -0,0 +1,8 @@
+from localstack.packages import Package, package
+
+
+@package(name="elasticsearch")
+def elasticsearch_package() -> Package:
+ from localstack.services.opensearch.packages import elasticsearch_package
+
+ return elasticsearch_package
diff --git a/localstack-core/localstack/services/es/provider.py b/localstack-core/localstack/services/es/provider.py
new file mode 100644
index 0000000000000..4519e417bceaa
--- /dev/null
+++ b/localstack-core/localstack/services/es/provider.py
@@ -0,0 +1,441 @@
+from contextlib import contextmanager
+from typing import Dict, Optional, cast
+
+from botocore.exceptions import ClientError
+
+from localstack import constants
+from localstack.aws.api import RequestContext, handler
+from localstack.aws.api.es import (
+ ARN,
+ AccessDeniedException,
+ AdvancedOptions,
+ AdvancedSecurityOptionsInput,
+ AutoTuneOptionsInput,
+ CognitoOptions,
+ CompatibleElasticsearchVersionsList,
+ CompatibleVersionsMap,
+ ConflictException,
+ CreateElasticsearchDomainResponse,
+ DeleteElasticsearchDomainResponse,
+ DescribeElasticsearchDomainConfigResponse,
+ DescribeElasticsearchDomainResponse,
+ DescribeElasticsearchDomainsResponse,
+ DisabledOperationException,
+ DomainEndpointOptions,
+ DomainInfoList,
+ DomainName,
+ DomainNameList,
+ EBSOptions,
+ ElasticsearchClusterConfig,
+ ElasticsearchClusterConfigStatus,
+ ElasticsearchDomainConfig,
+ ElasticsearchDomainStatus,
+ ElasticsearchVersionStatus,
+ ElasticsearchVersionString,
+ EncryptionAtRestOptions,
+ EngineType,
+ EsApi,
+ GetCompatibleElasticsearchVersionsResponse,
+ InternalException,
+ InvalidPaginationTokenException,
+ InvalidTypeException,
+ LimitExceededException,
+ ListDomainNamesResponse,
+ ListElasticsearchVersionsResponse,
+ ListTagsResponse,
+ LogPublishingOptions,
+ MaxResults,
+ NextToken,
+ NodeToNodeEncryptionOptions,
+ OptionStatus,
+ PolicyDocument,
+ ResourceAlreadyExistsException,
+ ResourceNotFoundException,
+ SnapshotOptions,
+ StringList,
+ TagList,
+ UpdateElasticsearchDomainConfigRequest,
+ UpdateElasticsearchDomainConfigResponse,
+ ValidationException,
+ VPCOptions,
+)
+from localstack.aws.api.es import BaseException as EsBaseException
+from localstack.aws.api.opensearch import (
+ ClusterConfig,
+ CompatibleVersionsList,
+ DomainConfig,
+ DomainStatus,
+ VersionString,
+)
+from localstack.aws.connect import connect_to
+
+
+def _version_to_opensearch(
+ version: Optional[ElasticsearchVersionString],
+) -> Optional[VersionString]:
+ if version is not None:
+ if version.startswith("OpenSearch_"):
+ return version
+ else:
+ return f"Elasticsearch_{version}"
+
+
+def _version_from_opensearch(
+ version: Optional[VersionString],
+) -> Optional[ElasticsearchVersionString]:
+ if version is not None:
+ if version.startswith("Elasticsearch_"):
+ return version.split("_")[1]
+ else:
+ return version
+
+
+def _instancetype_to_opensearch(instance_type: Optional[str]) -> Optional[str]:
+ if instance_type is not None:
+ return instance_type.replace("elasticsearch", "search")
+
+
+def _instancetype_from_opensearch(instance_type: Optional[str]) -> Optional[str]:
+ if instance_type is not None:
+ return instance_type.replace("search", "elasticsearch")
+
+
+def _clusterconfig_from_opensearch(
+ cluster_config: Optional[ClusterConfig],
+) -> Optional[ElasticsearchClusterConfig]:
+ if cluster_config is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(ElasticsearchClusterConfig, cluster_config)
+
+ # Adjust the instance type names
+ result["InstanceType"] = _instancetype_from_opensearch(cluster_config.get("InstanceType"))
+ result["DedicatedMasterType"] = _instancetype_from_opensearch(
+ cluster_config.get("DedicatedMasterType")
+ )
+ result["WarmType"] = _instancetype_from_opensearch(cluster_config.get("WarmType"))
+ return result
+
+
+def _domainstatus_from_opensearch(
+ domain_status: Optional[DomainStatus],
+) -> Optional[ElasticsearchDomainStatus]:
+ if domain_status is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(ElasticsearchDomainStatus, domain_status)
+ # Only specifically handle keys which are named differently or their values differ (version and clusterconfig)
+ result["ElasticsearchVersion"] = _version_from_opensearch(
+ domain_status.get("EngineVersion")
+ )
+ result["ElasticsearchClusterConfig"] = _clusterconfig_from_opensearch(
+ domain_status.get("ClusterConfig")
+ )
+ result.pop("EngineVersion", None)
+ result.pop("ClusterConfig", None)
+ return result
+
+
+def _clusterconfig_to_opensearch(
+ elasticsearch_cluster_config: Optional[ElasticsearchClusterConfig],
+) -> Optional[ClusterConfig]:
+ if elasticsearch_cluster_config is not None:
+ result = cast(ClusterConfig, elasticsearch_cluster_config)
+ if instance_type := result.get("InstanceType"):
+ result["InstanceType"] = _instancetype_to_opensearch(instance_type)
+ if dedicated_master_type := result.get("DedicatedMasterType"):
+ result["DedicatedMasterType"] = _instancetype_to_opensearch(dedicated_master_type)
+ if warm_type := result.get("WarmType"):
+ result["WarmType"] = _instancetype_to_opensearch(warm_type)
+ return result
+
+
+def _domainconfig_from_opensearch(
+ domain_config: Optional[DomainConfig],
+) -> Optional[ElasticsearchDomainConfig]:
+ if domain_config is not None:
+ result = cast(ElasticsearchDomainConfig, domain_config)
+ engine_version = domain_config.get("EngineVersion", {})
+ result["ElasticsearchVersion"] = ElasticsearchVersionStatus(
+ Options=_version_from_opensearch(engine_version.get("Options")),
+ Status=cast(OptionStatus, engine_version.get("Status")),
+ )
+ cluster_config = domain_config.get("ClusterConfig", {})
+ result["ElasticsearchClusterConfig"] = ElasticsearchClusterConfigStatus(
+ Options=_clusterconfig_from_opensearch(cluster_config.get("Options")),
+ Status=cluster_config.get("Status"),
+ )
+ result.pop("EngineVersion", None)
+ result.pop("ClusterConfig", None)
+ return result
+
+
+def _compatible_version_list_from_opensearch(
+ compatible_version_list: Optional[CompatibleVersionsList],
+) -> Optional[CompatibleElasticsearchVersionsList]:
+ if compatible_version_list is not None:
+ return [
+ CompatibleVersionsMap(
+ SourceVersion=_version_from_opensearch(version_map["SourceVersion"]),
+ TargetVersions=[
+ _version_from_opensearch(target_version)
+ for target_version in version_map["TargetVersions"]
+ ],
+ )
+ for version_map in compatible_version_list
+ ]
+
+
+@contextmanager
+def exception_mapper():
+ """Maps an exception thrown by the OpenSearch client to an exception thrown by the ElasticSearch API."""
+ try:
+ yield
+ except ClientError as err:
+ exception_types = {
+ "AccessDeniedException": AccessDeniedException,
+ "BaseException": EsBaseException,
+ "ConflictException": ConflictException,
+ "DisabledOperationException": DisabledOperationException,
+ "InternalException": InternalException,
+ "InvalidPaginationTokenException": InvalidPaginationTokenException,
+ "InvalidTypeException": InvalidTypeException,
+ "LimitExceededException": LimitExceededException,
+ "ResourceAlreadyExistsException": ResourceAlreadyExistsException,
+ "ResourceNotFoundException": ResourceNotFoundException,
+ "ValidationException": ValidationException,
+ }
+ mapped_exception_type = exception_types.get(err.response["Error"]["Code"], EsBaseException)
+ raise mapped_exception_type(err.response["Error"]["Message"])
+
+
+class EsProvider(EsApi):
+ def create_elasticsearch_domain(
+ self,
+ context: RequestContext,
+ domain_name: DomainName,
+ elasticsearch_version: ElasticsearchVersionString = None,
+ elasticsearch_cluster_config: ElasticsearchClusterConfig = None,
+ ebs_options: EBSOptions = None,
+ access_policies: PolicyDocument = None,
+ snapshot_options: SnapshotOptions = None,
+ vpc_options: VPCOptions = None,
+ cognito_options: CognitoOptions = None,
+ encryption_at_rest_options: EncryptionAtRestOptions = None,
+ node_to_node_encryption_options: NodeToNodeEncryptionOptions = None,
+ advanced_options: AdvancedOptions = None,
+ log_publishing_options: LogPublishingOptions = None,
+ domain_endpoint_options: DomainEndpointOptions = None,
+ advanced_security_options: AdvancedSecurityOptionsInput = None,
+ auto_tune_options: AutoTuneOptionsInput = None,
+ tag_list: TagList = None,
+ **kwargs,
+ ) -> CreateElasticsearchDomainResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+ # If no version is given, we set our default elasticsearch version
+ engine_version = (
+ _version_to_opensearch(elasticsearch_version)
+ if elasticsearch_version
+ else constants.ELASTICSEARCH_DEFAULT_VERSION
+ )
+ kwargs = {
+ "DomainName": domain_name,
+ "EngineVersion": engine_version,
+ "ClusterConfig": _clusterconfig_to_opensearch(elasticsearch_cluster_config),
+ "EBSOptions": ebs_options,
+ "AccessPolicies": access_policies,
+ "SnapshotOptions": snapshot_options,
+ "VPCOptions": vpc_options,
+ "CognitoOptions": cognito_options,
+ "EncryptionAtRestOptions": encryption_at_rest_options,
+ "NodeToNodeEncryptionOptions": node_to_node_encryption_options,
+ "AdvancedOptions": advanced_options,
+ "LogPublishingOptions": log_publishing_options,
+ "DomainEndpointOptions": domain_endpoint_options,
+ "AdvancedSecurityOptions": advanced_security_options,
+ "AutoTuneOptions": auto_tune_options,
+ "TagList": tag_list,
+ }
+
+ # Filter the kwargs to not set None values at all (boto doesn't like that)
+ kwargs = {key: value for key, value in kwargs.items() if value is not None}
+
+ with exception_mapper():
+ domain_status = opensearch_client.create_domain(**kwargs)["DomainStatus"]
+
+ status = _domainstatus_from_opensearch(domain_status)
+ return CreateElasticsearchDomainResponse(DomainStatus=status)
+
+ def delete_elasticsearch_domain(
+ self, context: RequestContext, domain_name: DomainName, **kwargs
+ ) -> DeleteElasticsearchDomainResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ domain_status = opensearch_client.delete_domain(
+ DomainName=domain_name,
+ )["DomainStatus"]
+
+ status = _domainstatus_from_opensearch(domain_status)
+ return DeleteElasticsearchDomainResponse(DomainStatus=status)
+
+ def describe_elasticsearch_domain(
+ self, context: RequestContext, domain_name: DomainName, **kwargs
+ ) -> DescribeElasticsearchDomainResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ opensearch_status = opensearch_client.describe_domain(
+ DomainName=domain_name,
+ )["DomainStatus"]
+
+ status = _domainstatus_from_opensearch(opensearch_status)
+ return DescribeElasticsearchDomainResponse(DomainStatus=status)
+
+ @handler("UpdateElasticsearchDomainConfig", expand=False)
+ def update_elasticsearch_domain_config(
+ self, context: RequestContext, payload: UpdateElasticsearchDomainConfigRequest
+ ) -> UpdateElasticsearchDomainConfigResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ payload: Dict
+ if "ElasticsearchClusterConfig" in payload:
+ payload["ClusterConfig"] = payload["ElasticsearchClusterConfig"]
+ payload["ClusterConfig"]["InstanceType"] = _instancetype_to_opensearch(
+ payload["ClusterConfig"]["InstanceType"]
+ )
+ payload.pop("ElasticsearchClusterConfig")
+
+ with exception_mapper():
+ opensearch_config = opensearch_client.update_domain_config(**payload)["DomainConfig"]
+
+ config = _domainconfig_from_opensearch(opensearch_config)
+ return UpdateElasticsearchDomainConfigResponse(DomainConfig=config)
+
+ def describe_elasticsearch_domains(
+ self, context: RequestContext, domain_names: DomainNameList, **kwargs
+ ) -> DescribeElasticsearchDomainsResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ opensearch_status_list = opensearch_client.describe_domains(
+ DomainNames=domain_names,
+ )["DomainStatusList"]
+
+ status_list = [_domainstatus_from_opensearch(s) for s in opensearch_status_list]
+ return DescribeElasticsearchDomainsResponse(DomainStatusList=status_list)
+
+ def list_domain_names(
+ self, context: RequestContext, engine_type: EngineType = None, **kwargs
+ ) -> ListDomainNamesResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+ # Only hand the EngineType param to boto if it's set
+ kwargs = {}
+ if engine_type:
+ kwargs["EngineType"] = engine_type
+
+ with exception_mapper():
+ domain_names = opensearch_client.list_domain_names(**kwargs)["DomainNames"]
+
+ return ListDomainNamesResponse(DomainNames=cast(Optional[DomainInfoList], domain_names))
+
+ def list_elasticsearch_versions(
+ self,
+ context: RequestContext,
+ max_results: MaxResults = None,
+ next_token: NextToken = None,
+ **kwargs,
+ ) -> ListElasticsearchVersionsResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+ # Construct the arguments as kwargs to not set None values at all (boto doesn't like that)
+ kwargs = {
+ key: value
+ for key, value in {"MaxResults": max_results, "NextToken": next_token}.items()
+ if value is not None
+ }
+ with exception_mapper():
+ versions = opensearch_client.list_versions(**kwargs)
+
+ return ListElasticsearchVersionsResponse(
+ ElasticsearchVersions=[
+ _version_from_opensearch(version) for version in versions["Versions"]
+ ],
+ NextToken=versions.get("NextToken"),
+ )
+
+ def get_compatible_elasticsearch_versions(
+ self, context: RequestContext, domain_name: DomainName = None, **kwargs
+ ) -> GetCompatibleElasticsearchVersionsResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+ # Only hand the DomainName param to boto if it's set
+ kwargs = {}
+ if domain_name:
+ kwargs["DomainName"] = domain_name
+
+ with exception_mapper():
+ compatible_versions_response = opensearch_client.get_compatible_versions(**kwargs)
+
+ compatible_versions = compatible_versions_response.get("CompatibleVersions")
+ return GetCompatibleElasticsearchVersionsResponse(
+ CompatibleElasticsearchVersions=_compatible_version_list_from_opensearch(
+ compatible_versions
+ )
+ )
+
+ def describe_elasticsearch_domain_config(
+ self, context: RequestContext, domain_name: DomainName, **kwargs
+ ) -> DescribeElasticsearchDomainConfigResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ domain_config = opensearch_client.describe_domain_config(DomainName=domain_name).get(
+ "DomainConfig"
+ )
+
+ return DescribeElasticsearchDomainConfigResponse(
+ DomainConfig=_domainconfig_from_opensearch(domain_config)
+ )
+
+ def add_tags(self, context: RequestContext, arn: ARN, tag_list: TagList, **kwargs) -> None:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ opensearch_client.add_tags(ARN=arn, TagList=tag_list)
+
+ def list_tags(self, context: RequestContext, arn: ARN, **kwargs) -> ListTagsResponse:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ response = opensearch_client.list_tags(ARN=arn)
+
+ return ListTagsResponse(TagList=response.get("TagList"))
+
+ def remove_tags(
+ self, context: RequestContext, arn: ARN, tag_keys: StringList, **kwargs
+ ) -> None:
+ opensearch_client = connect_to(
+ region_name=context.region, aws_access_key_id=context.account_id
+ ).opensearch
+
+ with exception_mapper():
+ opensearch_client.remove_tags(ARN=arn, TagKeys=tag_keys)
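The shim relies on a handful of pure translation helpers between the legacy Elasticsearch names and the OpenSearch ones. Their behaviour can be read directly off the definitions above; a few illustrative round-trips:

    assert _version_to_opensearch("7.10") == "Elasticsearch_7.10"
    assert _version_to_opensearch("OpenSearch_1.1") == "OpenSearch_1.1"
    assert _version_from_opensearch("Elasticsearch_7.10") == "7.10"
    assert _instancetype_to_opensearch("m3.medium.elasticsearch") == "m3.medium.search"
    assert _instancetype_from_opensearch("m3.medium.search") == "m3.medium.elasticsearch"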
diff --git a/localstack-core/localstack/services/events/__init__.py b/localstack-core/localstack/services/events/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/events/api_destination.py b/localstack-core/localstack/services/events/api_destination.py
new file mode 100644
index 0000000000000..a7fe116eaed21
--- /dev/null
+++ b/localstack-core/localstack/services/events/api_destination.py
@@ -0,0 +1,291 @@
+import base64
+import json
+import logging
+import re
+
+import requests
+
+from localstack.aws.api.events import (
+ ApiDestinationDescription,
+ ApiDestinationHttpMethod,
+ ApiDestinationInvocationRateLimitPerSecond,
+ ApiDestinationName,
+ ApiDestinationState,
+ Arn,
+ ConnectionArn,
+ ConnectionAuthorizationType,
+ ConnectionState,
+ HttpsEndpoint,
+ Timestamp,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.events.models import ApiDestination, Connection, ValidationException
+from localstack.utils.aws.arns import (
+ extract_account_id_from_arn,
+ extract_region_from_arn,
+ parse_arn,
+)
+from localstack.utils.aws.message_forwarding import (
+ list_of_parameters_to_object,
+)
+from localstack.utils.http import add_query_params_to_url
+from localstack.utils.strings import to_str
+
+VALID_AUTH_TYPES = [t.value for t in ConnectionAuthorizationType]
+LOG = logging.getLogger(__name__)
+
+
+class APIDestinationService:
+ def __init__(
+ self,
+ name: ApiDestinationName,
+ region: str,
+ account_id: str,
+ connection_arn: ConnectionArn,
+ connection: Connection,
+ invocation_endpoint: HttpsEndpoint,
+ http_method: ApiDestinationHttpMethod,
+ invocation_rate_limit_per_second: ApiDestinationInvocationRateLimitPerSecond | None,
+ description: ApiDestinationDescription | None = None,
+ ):
+ self.validate_input(name, connection_arn, http_method, invocation_endpoint)
+ self.connection = connection
+ state = self._get_state()
+
+ self.api_destination = ApiDestination(
+ name,
+ region,
+ account_id,
+ connection_arn,
+ invocation_endpoint,
+ http_method,
+ state,
+ invocation_rate_limit_per_second,
+ description,
+ )
+
+ @property
+ def arn(self) -> Arn:
+ return self.api_destination.arn
+
+ @property
+ def state(self) -> ApiDestinationState:
+ return self.api_destination.state
+
+ @property
+ def creation_time(self) -> Timestamp:
+ return self.api_destination.creation_time
+
+ @property
+ def last_modified_time(self) -> Timestamp:
+ return self.api_destination.last_modified_time
+
+ def set_state(self, state: ApiDestinationState) -> None:
+ if hasattr(self, "api_destination"):
+ if state == ApiDestinationState.ACTIVE:
+ state = self._get_state()
+ self.api_destination.state = state
+
+ def update(
+ self,
+ connection,
+ invocation_endpoint,
+ http_method,
+ invocation_rate_limit_per_second,
+ description,
+ ):
+ self.set_state(ApiDestinationState.INACTIVE)
+ self.connection = connection
+ self.api_destination.connection_arn = connection.arn
+ if invocation_endpoint:
+ self.api_destination.invocation_endpoint = invocation_endpoint
+ if http_method:
+ self.api_destination.http_method = http_method
+ if invocation_rate_limit_per_second:
+ self.api_destination.invocation_rate_limit_per_second = invocation_rate_limit_per_second
+ if description:
+ self.api_destination.description = description
+ self.api_destination.last_modified_time = Timestamp.now()
+ self.set_state(ApiDestinationState.ACTIVE)
+
+ def _get_state(self) -> ApiDestinationState:
+ """Determine ApiDestinationState based on ConnectionState."""
+ return (
+ ApiDestinationState.ACTIVE
+ if self.connection.state == ConnectionState.AUTHORIZED
+ else ApiDestinationState.INACTIVE
+ )
+
+ @classmethod
+ def validate_input(
+ cls,
+ name: ApiDestinationName,
+ connection_arn: ConnectionArn,
+ http_method: ApiDestinationHttpMethod,
+ invocation_endpoint: HttpsEndpoint,
+ ) -> None:
+ errors = []
+ errors.extend(cls._validate_api_destination_name(name))
+ errors.extend(cls._validate_connection_arn(connection_arn))
+ errors.extend(cls._validate_http_method(http_method))
+ errors.extend(cls._validate_invocation_endpoint(invocation_endpoint))
+
+ if errors:
+ error_message = (
+ f"{len(errors)} validation error{'s' if len(errors) > 1 else ''} detected: "
+ )
+ error_message += "; ".join(errors)
+ raise ValidationException(error_message)
+
+ @staticmethod
+ def _validate_api_destination_name(name: str) -> list[str]:
+ """Validate the API destination name according to AWS rules. Returns a list of validation errors."""
+ errors = []
+ if not re.match(r"^[\.\-_A-Za-z0-9]+$", name):
+ errors.append(
+ f"Value '{name}' at 'name' failed to satisfy constraint: "
+ "Member must satisfy regular expression pattern: [\\.\\-_A-Za-z0-9]+"
+ )
+ if not (1 <= len(name) <= 64):
+ errors.append(
+ f"Value '{name}' at 'name' failed to satisfy constraint: "
+ "Member must have length less than or equal to 64"
+ )
+ return errors
+
+ @staticmethod
+ def _validate_connection_arn(connection_arn: ConnectionArn) -> list[str]:
+ errors = []
+ if not re.match(
+ r"^arn:aws([a-z]|\-)*:events:[a-z0-9\-]+:\d{12}:connection/[\.\-_A-Za-z0-9]+/[\-A-Za-z0-9]+$",
+ connection_arn,
+ ):
+ errors.append(
+ f"Value '{connection_arn}' at 'connectionArn' failed to satisfy constraint: "
+ "Member must satisfy regular expression pattern: "
+ "^arn:aws([a-z]|\\-)*:events:([a-z]|\\d|\\-)*:([0-9]{12})?:connection\\/[\\.\\-_A-Za-z0-9]+\\/[\\-A-Za-z0-9]+$"
+ )
+ return errors
+
+ @staticmethod
+ def _validate_http_method(http_method: ApiDestinationHttpMethod) -> list[str]:
+ errors = []
+ allowed_methods = ["HEAD", "POST", "PATCH", "DELETE", "PUT", "GET", "OPTIONS"]
+ if http_method not in allowed_methods:
+ errors.append(
+ f"Value '{http_method}' at 'httpMethod' failed to satisfy constraint: "
+ f"Member must satisfy enum value set: [{', '.join(allowed_methods)}]"
+ )
+ return errors
+
+ @staticmethod
+ def _validate_invocation_endpoint(invocation_endpoint: HttpsEndpoint) -> list[str]:
+ errors = []
+ endpoint_pattern = r"^((%[0-9A-Fa-f]{2}|[-()_.!~*';/?:@&=+$,A-Za-z0-9])+)([).!';/?:,])?$"
+ if not re.match(endpoint_pattern, invocation_endpoint):
+ errors.append(
+ f"Value '{invocation_endpoint}' at 'invocationEndpoint' failed to satisfy constraint: "
+ "Member must satisfy regular expression pattern: "
+ "^((%[0-9A-Fa-f]{2}|[-()_.!~*';/?:@&=+$,A-Za-z0-9])+)([).!';/?:,])?$"
+ )
+ return errors
+
+
+ApiDestinationServiceDict = dict[Arn, APIDestinationService]
+
+
+def add_api_destination_authorization(destination, headers, event):
+ connection_arn = destination.get("ConnectionArn", "")
+ connection_name = re.search(r"connection\/([a-zA-Z0-9-_]+)\/", connection_arn).group(1)
+
+ account_id = extract_account_id_from_arn(connection_arn)
+ region = extract_region_from_arn(connection_arn)
+
+ events_client = connect_to(aws_access_key_id=account_id, region_name=region).events
+ connection_details = events_client.describe_connection(Name=connection_name)
+ secret_arn = connection_details["SecretArn"]
+ parsed_arn = parse_arn(secret_arn)
+ secretsmanager_client = connect_to(
+ aws_access_key_id=parsed_arn["account"], region_name=parsed_arn["region"]
+ ).secretsmanager
+ auth_secret = json.loads(
+ secretsmanager_client.get_secret_value(SecretId=secret_arn)["SecretString"]
+ )
+
+ headers.update(_auth_keys_from_connection(connection_details, auth_secret))
+
+ auth_parameters = connection_details.get("AuthParameters", {})
+ invocation_parameters = auth_parameters.get("InvocationHttpParameters")
+
+ endpoint = destination.get("InvocationEndpoint")
+ if invocation_parameters:
+ header_parameters = list_of_parameters_to_object(
+ invocation_parameters.get("HeaderParameters", [])
+ )
+ headers.update(header_parameters)
+
+ body_parameters = list_of_parameters_to_object(
+ invocation_parameters.get("BodyParameters", [])
+ )
+ event.update(body_parameters)
+
+ query_parameters = invocation_parameters.get("QueryStringParameters", [])
+ query_object = list_of_parameters_to_object(query_parameters)
+ endpoint = add_query_params_to_url(endpoint, query_object)
+
+ return endpoint
+
+
+def _auth_keys_from_connection(connection_details, auth_secret):
+ headers = {}
+
+ auth_type = connection_details.get("AuthorizationType").upper()
+ auth_parameters = connection_details.get("AuthParameters")
+ match auth_type:
+ case ConnectionAuthorizationType.BASIC:
+ username = auth_secret.get("username", "")
+ password = auth_secret.get("password", "")
+ auth = "Basic " + to_str(base64.b64encode(f"{username}:{password}".encode("ascii")))
+ headers.update({"authorization": auth})
+
+ case ConnectionAuthorizationType.API_KEY:
+ api_key_name = auth_secret.get("api_key_name", "")
+ api_key_value = auth_secret.get("api_key_value", "")
+ headers.update({api_key_name: api_key_value})
+
+ case ConnectionAuthorizationType.OAUTH_CLIENT_CREDENTIALS:
+ oauth_parameters = auth_parameters.get("OAuthParameters", {})
+ oauth_method = auth_secret.get("http_method")
+
+ oauth_http_parameters = oauth_parameters.get("OAuthHttpParameters", {})
+ oauth_endpoint = auth_secret.get("authorization_endpoint", "")
+ query_object = list_of_parameters_to_object(
+ oauth_http_parameters.get("QueryStringParameters", [])
+ )
+ oauth_endpoint = add_query_params_to_url(oauth_endpoint, query_object)
+
+ client_id = auth_secret.get("client_id", "")
+ client_secret = auth_secret.get("client_secret", "")
+
+ oauth_body = list_of_parameters_to_object(
+ oauth_http_parameters.get("BodyParameters", [])
+ )
+ oauth_body.update({"client_id": client_id, "client_secret": client_secret})
+
+ oauth_header = list_of_parameters_to_object(
+ oauth_http_parameters.get("HeaderParameters", [])
+ )
+ oauth_result = requests.request(
+ method=oauth_method,
+ url=oauth_endpoint,
+ data=json.dumps(oauth_body),
+ headers=oauth_header,
+ )
+ oauth_data = json.loads(oauth_result.text)
+
+ token_type = oauth_data.get("token_type", "")
+ access_token = oauth_data.get("access_token", "")
+ auth_header = f"{token_type} {access_token}"
+ headers.update({"authorization": auth_header})
+
+ return headers
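For the BASIC branch of _auth_keys_from_connection, the secret's username/password pair is turned into a standard HTTP Basic Authorization header. A minimal sketch with simplified inputs (the real describe_connection response carries more fields than shown here):

    connection_details = {"AuthorizationType": "BASIC", "AuthParameters": {}}
    auth_secret = {"username": "user", "password": "pass"}
    headers = _auth_keys_from_connection(connection_details, auth_secret)
    # headers == {"authorization": "Basic dXNlcjpwYXNz"}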
diff --git a/localstack-core/localstack/services/events/archive.py b/localstack-core/localstack/services/events/archive.py
new file mode 100644
index 0000000000000..12d7e4601747f
--- /dev/null
+++ b/localstack-core/localstack/services/events/archive.py
@@ -0,0 +1,189 @@
+import json
+import logging
+from datetime import datetime, timezone
+from typing import Self
+
+from botocore.client import BaseClient
+
+from localstack.aws.api.events import (
+ ArchiveState,
+ Arn,
+ EventBusName,
+ TargetId,
+ Timestamp,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.events.models import (
+ Archive,
+ ArchiveDescription,
+ ArchiveName,
+ EventPattern,
+ FormattedEvent,
+ FormattedEventList,
+ RetentionDays,
+ RuleName,
+)
+from localstack.services.events.utils import extract_event_bus_name
+from localstack.utils.aws.client_types import ServicePrincipal
+
+LOG = logging.getLogger(__name__)
+
+
+class ArchiveService:
+ archive_name: ArchiveName
+ region: str
+ account_id: str
+ event_source_arn: Arn
+ description: ArchiveDescription
+ event_pattern: EventPattern
+ retention_days: RetentionDays
+ archive: Archive
+ client: BaseClient
+ event_bus_name: EventBusName
+ rule_name: RuleName
+ target_id: TargetId
+
+ def __init__(self, archive: Archive):
+ self.archive = archive
+ self.set_state(ArchiveState.CREATING)
+ self.set_creation_time()
+ self.client: BaseClient = self._initialize_client()
+ self.event_bus_name: EventBusName = extract_event_bus_name(archive.event_source_arn)
+ self.set_state(ArchiveState.ENABLED)
+ self.rule_name = f"Events-Archive-{self.archive_name}"
+ self.target_id = f"Events-Archive-{self.archive_name}"
+
+ @classmethod
+ def create_archive_service(
+ cls,
+ archive_name: ArchiveName,
+ region: str,
+ account_id: str,
+ event_source_arn: Arn,
+ description: ArchiveDescription,
+ event_pattern: EventPattern,
+ retention_days: RetentionDays,
+ ) -> Self:
+ return cls(
+ Archive(
+ archive_name,
+ region,
+ account_id,
+ event_source_arn,
+ description,
+ event_pattern,
+ retention_days,
+ )
+ )
+
+ def register_archive_rule_and_targets(self):
+ self._create_archive_rule()
+ self._create_archive_target()
+
+ def __getattr__(self, name):
+ return getattr(self.archive, name)
+
+ @property
+ def archive_name(self) -> ArchiveName:
+ return self.archive.name
+
+ @property
+ def archive_arn(self) -> Arn:
+ return self.archive.arn
+
+ def set_state(self, state: ArchiveState) -> None:
+ self.archive.state = state
+
+ def set_creation_time(self) -> None:
+ self.archive.creation_time = datetime.now(timezone.utc)
+
+ def update(
+ self,
+ description: ArchiveDescription,
+ event_pattern: EventPattern,
+ retention_days: RetentionDays,
+ ) -> None:
+ self.set_state(ArchiveState.UPDATING)
+ if description is not None:
+ self.archive.description = description
+ if event_pattern is not None:
+ self.archive.event_pattern = event_pattern
+ if retention_days is not None:
+ self.archive.retention_days = retention_days
+ self.set_state(ArchiveState.ENABLED)
+
+ def delete(self) -> None:
+ self.set_state(ArchiveState.DISABLED)
+ try:
+ self.client.remove_targets(
+ Rule=self.rule_name, EventBusName=self.event_bus_name, Ids=[self.target_id]
+ )
+ except Exception as e:
+ LOG.debug("Target %s could not be removed, %s", self.target_id, e)
+ try:
+ self.client.delete_rule(Name=self.rule_name, EventBusName=self.event_bus_name)
+ except Exception as e:
+ LOG.debug("Rule %s could not be deleted, %s", self.rule_name, e)
+
+ def put_events(self, events: FormattedEventList) -> None:
+ for event in events:
+ self.archive.events[event["id"]] = event
+
+ def get_events(self, start_time: Timestamp, end_time: Timestamp) -> FormattedEventList:
+ events_to_replay = self._filter_events_start_end_time(start_time, end_time)
+ return events_to_replay
+
+ def _initialize_client(self) -> BaseClient:
+ client_factory = connect_to(aws_access_key_id=self.account_id, region_name=self.region)
+ client = client_factory.get_client("events")
+
+ service_principal = ServicePrincipal.events
+ client = client.request_metadata(service_principal=service_principal, source_arn=self.arn)
+ return client
+
+ def _create_archive_rule(
+ self,
+ ):
+ default_event_pattern = {
+ "replay-name": [{"exists": False}],
+ }
+ if self.event_pattern:
+ updated_event_pattern = json.loads(self.event_pattern)
+ updated_event_pattern.update(default_event_pattern)
+ else:
+ updated_event_pattern = default_event_pattern
+ self.client.put_rule(
+ Name=self.rule_name,
+ EventBusName=self.event_bus_name,
+ EventPattern=json.dumps(updated_event_pattern),
+ )
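+ # Example (assuming an archive event pattern of {"source": ["my.app"]}): the rule pattern
+ # sent to put_rule above would be {"source": ["my.app"], "replay-name": [{"exists": false}]},
+ # so replayed events (which carry a replay-name) are never re-archived.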
+
+ def _create_archive_target(
+ self,
+ ):
+ """Creates a target for the archive rule. The target is required for accessing parameters
+ from the provider during sending of events to the target but it is not invoked
+ because events are put to the archive directly to not overload the gateway"""
+ self.client.put_targets(
+ Rule=self.rule_name,
+ EventBusName=self.event_bus_name,
+ Targets=[{"Id": self.target_id, "Arn": self.arn}],
+ )
+
+ def _normalize_datetime(self, dt: datetime) -> datetime:
+ return dt.replace(second=0, microsecond=0)
+
+ def _filter_events_start_end_time(
+ self, event_start_time: Timestamp, event_end_time: Timestamp
+ ) -> list[FormattedEvent]:
+ events = self.archive.events
+ event_start_time = self._normalize_datetime(event_start_time)
+ event_end_time = self._normalize_datetime(event_end_time)
+ return [
+ event
+ for event in events.values()
+ if event_start_time <= self._normalize_datetime(event["time"]) <= event_end_time
+ ]
+
+
+ArchiveServiceDict = dict[Arn, ArchiveService]
diff --git a/localstack-core/localstack/services/events/connection.py b/localstack-core/localstack/services/events/connection.py
new file mode 100644
index 0000000000000..bb855c9203e0c
--- /dev/null
+++ b/localstack-core/localstack/services/events/connection.py
@@ -0,0 +1,327 @@
+import json
+import logging
+import re
+import uuid
+from datetime import datetime, timezone
+
+from localstack.aws.api.events import (
+ Arn,
+ ConnectionAuthorizationType,
+ ConnectionDescription,
+ ConnectionName,
+ ConnectionState,
+ ConnectivityResourceParameters,
+ CreateConnectionAuthRequestParameters,
+ Timestamp,
+ UpdateConnectionAuthRequestParameters,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.events.models import Connection, ValidationException
+
+VALID_AUTH_TYPES = [t.value for t in ConnectionAuthorizationType]
+LOG = logging.getLogger(__name__)
+
+
+class ConnectionService:
+ def __init__(
+ self,
+ name: ConnectionName,
+ region: str,
+ account_id: str,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters,
+ description: ConnectionDescription | None = None,
+ invocation_connectivity_parameters: ConnectivityResourceParameters | None = None,
+ ):
+ self._validate_input(name, authorization_type)
+ state = self._get_initial_state(authorization_type)
+ secret_arn = self.create_connection_secret(
+ region, account_id, name, authorization_type, auth_parameters
+ )
+ public_auth_parameters = self._get_public_parameters(authorization_type, auth_parameters)
+
+ self.connection = Connection(
+ name,
+ region,
+ account_id,
+ authorization_type,
+ public_auth_parameters,
+ state,
+ secret_arn,
+ description,
+ invocation_connectivity_parameters,
+ )
+
+ @property
+ def arn(self) -> Arn:
+ return self.connection.arn
+
+ @property
+ def state(self) -> ConnectionState:
+ return self.connection.state
+
+ @property
+ def creation_time(self) -> Timestamp:
+ return self.connection.creation_time
+
+ @property
+ def last_modified_time(self) -> Timestamp:
+ return self.connection.last_modified_time
+
+ @property
+ def last_authorized_time(self) -> Timestamp:
+ return self.connection.last_authorized_time
+
+ @property
+ def secret_arn(self) -> Arn:
+ return self.connection.secret_arn
+
+ @property
+ def auth_parameters(self) -> CreateConnectionAuthRequestParameters:
+ return self.connection.auth_parameters
+
+ def set_state(self, state: ConnectionState) -> None:
+ if hasattr(self, "connection"):
+ self.connection.state = state
+
+ def update(
+ self,
+ description: ConnectionDescription,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: UpdateConnectionAuthRequestParameters,
+ invocation_connectivity_parameters: ConnectivityResourceParameters | None = None,
+ ) -> None:
+ self.set_state(ConnectionState.UPDATING)
+ if description:
+ self.connection.description = description
+ if invocation_connectivity_parameters:
+ self.connection.invocation_connectivity_parameters = invocation_connectivity_parameters
+ # Use existing values if not provided in update
+ if authorization_type:
+ auth_type = (
+ authorization_type.value
+ if hasattr(authorization_type, "value")
+ else authorization_type
+ )
+ self._validate_auth_type(auth_type)
+ else:
+ auth_type = self.connection.authorization_type
+
+ try:
+ if self.connection.secret_arn:
+ self.update_connection_secret(
+ self.connection.secret_arn, auth_type, auth_parameters
+ )
+ else:
+ secret_arn = self.create_connection_secret(
+ self.connection.region,
+ self.connection.account_id,
+ self.connection.name,
+ auth_type,
+ auth_parameters,
+ )
+ self.connection.secret_arn = secret_arn
+ self.connection.last_authorized_time = datetime.now(timezone.utc)
+
+ # Set new values
+ self.connection.authorization_type = auth_type
+ public_auth_parameters = (
+ self._get_public_parameters(authorization_type, auth_parameters)
+ if auth_parameters
+ else self.connection.auth_parameters
+ )
+ self.connection.auth_parameters = public_auth_parameters
+ self.set_state(ConnectionState.AUTHORIZED)
+ self.connection.last_modified_time = datetime.now(timezone.utc)
+
+ except Exception as error:
+ LOG.warning(
+ "Connection with name %s updating failed with errors: %s.",
+ self.connection.name,
+ error,
+ )
+
+ def delete(self) -> None:
+ self.set_state(ConnectionState.DELETING)
+ self.delete_connection_secret(self.connection.secret_arn)
+ self.set_state(ConnectionState.DELETING) # required for AWS parity
+ self.connection.last_modified_time = datetime.now(timezone.utc)
+
+ def create_connection_secret(
+ self,
+ region: str,
+ account_id: str,
+ name: str,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters
+ | UpdateConnectionAuthRequestParameters,
+ ) -> Arn | None:
+ self.set_state(ConnectionState.AUTHORIZING)
+ secretsmanager_client = connect_to(
+ aws_access_key_id=account_id, region_name=region
+ ).secretsmanager
+ secret_value = self._get_secret_value(authorization_type, auth_parameters)
+ secret_name = f"events!connection/{name}/{str(uuid.uuid4())}"
+ try:
+ secret_arn = secretsmanager_client.create_secret(
+ Name=secret_name,
+ SecretString=secret_value,
+ Tags=[{"Key": "BYPASS_SECRET_ID_VALIDATION", "Value": "1"}],
+ )["ARN"]
+ self.set_state(ConnectionState.AUTHORIZED)
+ return secret_arn
+ except Exception as error:
+ LOG.warning("Secret with name %s creation failed with errors: %s.", secret_name, error)
+
+ def update_connection_secret(
+ self,
+ secret_arn: str,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: UpdateConnectionAuthRequestParameters,
+ ) -> None:
+ self.set_state(ConnectionState.AUTHORIZING)
+ secretsmanager_client = connect_to(
+ aws_access_key_id=self.connection.account_id, region_name=self.connection.region
+ ).secretsmanager
+ secret_value = self._get_secret_value(authorization_type, auth_parameters)
+ try:
+ secretsmanager_client.update_secret(SecretId=secret_arn, SecretString=secret_value)
+ self.set_state(ConnectionState.AUTHORIZED)
+ self.connection.last_authorized_time = datetime.now(timezone.utc)
+ except Exception as error:
+ LOG.warning("Secret with id %s updating failed with errors: %s.", secret_arn, error)
+
+ def delete_connection_secret(self, secret_arn: str) -> None:
+ self.set_state(ConnectionState.DEAUTHORIZING)
+ secretsmanager_client = connect_to(
+ aws_access_key_id=self.connection.account_id, region_name=self.connection.region
+ ).secretsmanager
+ try:
+ secretsmanager_client.delete_secret(
+ SecretId=secret_arn, ForceDeleteWithoutRecovery=True
+ )
+ self.set_state(ConnectionState.DEAUTHORIZED)
+ except Exception as error:
+ LOG.warning("Secret with id %s deleting failed with errors: %s.", secret_arn, error)
+
+ def _get_initial_state(self, auth_type: str) -> ConnectionState:
+ if auth_type == "OAUTH_CLIENT_CREDENTIALS":
+ return ConnectionState.AUTHORIZING
+ return ConnectionState.AUTHORIZED
+
+ def _get_secret_value(
+ self,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters
+ | UpdateConnectionAuthRequestParameters,
+ ) -> str:
+ result = {}
+ match authorization_type:
+ case ConnectionAuthorizationType.BASIC:
+ params = auth_parameters.get("BasicAuthParameters", {})
+ result = {"username": params.get("Username"), "password": params.get("Password")}
+ case ConnectionAuthorizationType.API_KEY:
+ params = auth_parameters.get("ApiKeyAuthParameters", {})
+ result = {
+ "api_key_name": params.get("ApiKeyName"),
+ "api_key_value": params.get("ApiKeyValue"),
+ }
+ case ConnectionAuthorizationType.OAUTH_CLIENT_CREDENTIALS:
+ params = auth_parameters.get("OAuthParameters", {})
+ client_params = params.get("ClientParameters", {})
+ result = {
+ "client_id": client_params.get("ClientID"),
+ "client_secret": client_params.get("ClientSecret"),
+ "authorization_endpoint": params.get("AuthorizationEndpoint"),
+ "http_method": params.get("HttpMethod"),
+ }
+
+ if "InvocationHttpParameters" in auth_parameters:
+ result["invocation_http_parameters"] = auth_parameters["InvocationHttpParameters"]
+
+ return json.dumps(result)
+
+ def _get_public_parameters(
+ self,
+ auth_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters
+ | UpdateConnectionAuthRequestParameters,
+ ) -> CreateConnectionAuthRequestParameters:
+ """Extract public parameters (without secrets) based on auth type."""
+ public_params = {}
+
+ if (
+ auth_type == ConnectionAuthorizationType.BASIC
+ and "BasicAuthParameters" in auth_parameters
+ ):
+ public_params["BasicAuthParameters"] = {
+ "Username": auth_parameters["BasicAuthParameters"]["Username"]
+ }
+
+ elif (
+ auth_type == ConnectionAuthorizationType.API_KEY
+ and "ApiKeyAuthParameters" in auth_parameters
+ ):
+ public_params["ApiKeyAuthParameters"] = {
+ "ApiKeyName": auth_parameters["ApiKeyAuthParameters"]["ApiKeyName"]
+ }
+
+ elif (
+ auth_type == ConnectionAuthorizationType.OAUTH_CLIENT_CREDENTIALS
+ and "OAuthParameters" in auth_parameters
+ ):
+ oauth_params = auth_parameters["OAuthParameters"]
+ public_params["OAuthParameters"] = {
+ "AuthorizationEndpoint": oauth_params["AuthorizationEndpoint"],
+ "HttpMethod": oauth_params["HttpMethod"],
+ "ClientParameters": {"ClientID": oauth_params["ClientParameters"]["ClientID"]},
+ }
+ if "OAuthHttpParameters" in oauth_params:
+ public_params["OAuthParameters"]["OAuthHttpParameters"] = oauth_params.get(
+ "OAuthHttpParameters"
+ )
+
+ if "InvocationHttpParameters" in auth_parameters:
+ public_params["InvocationHttpParameters"] = auth_parameters["InvocationHttpParameters"]
+
+ return public_params
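+ # Example (illustrative values): for BASIC auth parameters
+ # {"BasicAuthParameters": {"Username": "user", "Password": "pass"}}, only
+ # {"BasicAuthParameters": {"Username": "user"}} is returned here; the password is stored
+ # in Secrets Manager via create_connection_secret.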
+
+ def _validate_input(
+ self,
+ name: ConnectionName,
+ authorization_type: ConnectionAuthorizationType,
+ ) -> None:
+ errors = []
+ errors.extend(self._validate_connection_name(name))
+ errors.extend(self._validate_auth_type(authorization_type))
+ if errors:
+ error_message = (
+ f"{len(errors)} validation error{'s' if len(errors) > 1 else ''} detected: "
+ )
+ error_message += "; ".join(errors)
+ raise ValidationException(error_message)
+
+ def _validate_connection_name(self, name: str) -> list[str]:
+ errors = []
+ if not re.match("^[\\.\\-_A-Za-z0-9]+$", name):
+ errors.append(
+ f"Value '{name}' at 'name' failed to satisfy constraint: "
+ "Member must satisfy regular expression pattern: [\\.\\-_A-Za-z0-9]+"
+ )
+ if not (1 <= len(name) <= 64):
+ errors.append(
+ f"Value '{name}' at 'name' failed to satisfy constraint: "
+ "Member must have length less than or equal to 64"
+ )
+ return errors
+
+ def _validate_auth_type(self, auth_type: str) -> list[str]:
+ if auth_type not in VALID_AUTH_TYPES:
+ return [
+ f"Value '{auth_type}' at 'authorizationType' failed to satisfy constraint: "
+ f"Member must satisfy enum value set: [{', '.join(VALID_AUTH_TYPES)}]"
+ ]
+ return []
+
+
+ConnectionServiceDict = dict[Arn, ConnectionService]
diff --git a/localstack-core/localstack/services/events/event_bus.py b/localstack-core/localstack/services/events/event_bus.py
new file mode 100644
index 0000000000000..1ea6f332a493b
--- /dev/null
+++ b/localstack-core/localstack/services/events/event_bus.py
@@ -0,0 +1,131 @@
+import json
+from datetime import datetime, timezone
+from typing import Optional, Self
+
+from localstack.aws.api.events import (
+ Action,
+ Arn,
+ Condition,
+ EventBusName,
+ Principal,
+ ResourceNotFoundException,
+ StatementId,
+ TagList,
+)
+from localstack.services.events.models import EventBus, ResourcePolicy, RuleDict, Statement
+from localstack.utils.aws.arns import get_partition
+
+
+class EventBusService:
+ name: EventBusName
+ region: str
+ account_id: str
+ event_source_name: str | None
+ tags: TagList | None
+ policy: str | None
+ event_bus: EventBus
+
+ def __init__(self, event_bus: EventBus):
+ self.event_bus = event_bus
+
+ @classmethod
+ def create_event_bus_service(
+ cls,
+ name: EventBusName,
+ region: str,
+ account_id: str,
+ event_source_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[TagList] = None,
+ policy: Optional[str] = None,
+ rules: Optional[RuleDict] = None,
+ ) -> Self:
+ return cls(
+ EventBus(
+ name,
+ region,
+ account_id,
+ event_source_name,
+ description,
+ tags,
+ policy,
+ rules,
+ )
+ )
+
+ @property
+ def arn(self) -> Arn:
+ return self.event_bus.arn
+
+ def put_permission(
+ self,
+ action: Action,
+ principal: Principal,
+ statement_id: StatementId,
+ condition: Condition,
+ policy: str,
+ ):
+ # TODO: cover via test
+ # if policy and any([action, principal, statement_id, condition]):
+ # raise ValueError("Combination of policy with other arguments is not allowed")
+ self.event_bus.last_modified_time = datetime.now(timezone.utc)
+ if policy: # policy document replaces all existing permissions
+ policy = json.loads(policy)
+ parsed_policy = ResourcePolicy(**policy)
+ self.event_bus.policy = parsed_policy
+ else:
+ permission_statement = self._parse_statement(
+ statement_id, action, principal, self.arn, condition
+ )
+
+ if existing_policy := self.event_bus.policy:
+ if permission_statement["Principal"] == "*":
+ for statement in existing_policy["Statement"]:
+ if "*" == statement["Principal"]:
+ return
+ existing_policy["Statement"].append(permission_statement)
+ else:
+ parsed_policy = ResourcePolicy(
+ Version="2012-10-17", Statement=[permission_statement]
+ )
+ self.event_bus.policy = parsed_policy
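+ # Example (hypothetical values): put_permission(action="events:PutEvents", principal="111111111111",
+ # statement_id="allow-account", ...) on a bus without an existing policy results in roughly:
+ # {"Version": "2012-10-17", "Statement": [{"Sid": "allow-account", "Effect": "Allow",
+ #  "Principal": {"AWS": "arn:aws:iam::111111111111:root"}, "Action": "events:PutEvents",
+ #  "Resource": <event bus ARN>, "Condition": None}]}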
+
+ def revoke_put_events_permission(self, statement_id: str):
+ policy = self.event_bus.policy
+ if not policy or not any(
+ statement.get("Sid") == statement_id for statement in policy["Statement"]
+ ):
+ raise ResourceNotFoundException("Statement with the provided id does not exist.")
+ if policy:
+ policy["Statement"] = [
+ statement
+ for statement in policy["Statement"]
+ if statement.get("Sid") != statement_id
+ ]
+ self.event_bus.last_modified_time = datetime.now(timezone.utc)
+
+ def _parse_statement(
+ self,
+ statement_id: StatementId,
+ action: Action,
+ principal: Principal,
+ resource_arn: Arn,
+ condition: Condition,
+ ) -> Statement:
+ # TODO: cover via test
+ # if condition and principal != "*":
+ # raise ValueError("Condition can only be set when principal is '*'")
+ if principal != "*":
+ principal = {"AWS": f"arn:{get_partition(self.event_bus.region)}:iam::{principal}:root"}
+ statement = Statement(
+ Sid=statement_id,
+ Effect="Allow",
+ Principal=principal,
+ Action=action,
+ Resource=resource_arn,
+ Condition=condition,
+ )
+ return statement
+
+
+EventBusServiceDict = dict[Arn, EventBusService]
diff --git a/localstack-core/localstack/services/events/event_rule_engine.py b/localstack-core/localstack/services/events/event_rule_engine.py
new file mode 100644
index 0000000000000..157bd6e95c367
--- /dev/null
+++ b/localstack-core/localstack/services/events/event_rule_engine.py
@@ -0,0 +1,609 @@
+import ipaddress
+import json
+import re
+import typing as t
+
+from localstack.aws.api.events import InvalidEventPatternException
+
+
+class EventRuleEngine:
+ def evaluate_pattern_on_event(self, compiled_event_pattern: dict, event: str | dict):
+ if isinstance(event, str):
+ try:
+ body = json.loads(event)
+ if not isinstance(body, dict):
+ return False
+ except json.JSONDecodeError:
+ # Event patterns assume that the event payload is a well-formed JSON object.
+ return False
+ else:
+ body = event
+
+ return self._evaluate_nested_event_pattern_on_dict(compiled_event_pattern, payload=body)
+
+ def _evaluate_nested_event_pattern_on_dict(self, event_pattern, payload: dict) -> bool:
+ """
+ This method evaluates the event pattern against the JSON decoded payload.
+ Although it's not documented anywhere, AWS allows `.` in field names in the event pattern and the payload,
+ and will evaluate them. However, it's not JSONPath compatible.
+ See:
+ https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-create-pattern.html#eb-create-pattern-considerations
+ Example:
+ Pattern: `{"field1.field2": "value1"}`
+ This pattern will match both `{"field1.field2": "value1"}` and `{"field1": {"field2": "value1"}}`, unlike
+ JSONPath, for which `.` points to a child node.
+ This suggests that both dictionaries are flattened to a single level for easier matching without
+ recursion.
+ :param event_pattern: a dict, starting at the Event Pattern
+ :param payload: a dict, starting at the MessageBody
+ :return: True if the payload respects the event pattern, otherwise False
+ """
+ if not event_pattern:
+ return True
+
+ # TODO: maybe save/cache the flattened/expanded pattern?
+ flat_pattern_conditions = self.flatten_pattern(event_pattern)
+ flat_payloads = self.flatten_payload(payload)
+
+ return any(
+ all(
+ any(
+ self._evaluate_condition(
+ flat_payload.get(key), condition, field_exists=key in flat_payload
+ )
+ for condition in values
+ for flat_payload in flat_payloads
+ )
+ for key, values in flat_pattern.items()
+ )
+ for flat_pattern in flat_pattern_conditions
+ )
+
+ def _evaluate_condition(self, value, condition, field_exists: bool):
+ if not isinstance(condition, dict):
+ return field_exists and value == condition
+ elif (must_exist := condition.get("exists")) is not None:
+ # if must_exist is True then field_exists must be True
+ # if must_exist is False then field_exists must be False
+ return must_exist == field_exists
+ elif (anything_but := condition.get("anything-but")) is not None:
+ if isinstance(anything_but, dict):
+ if (not_condition := anything_but.get("prefix")) is not None:
+ predicate = self._evaluate_prefix
+ elif (not_condition := anything_but.get("suffix")) is not None:
+ predicate = self._evaluate_suffix
+ elif (not_condition := anything_but.get("equals-ignore-case")) is not None:
+ predicate = self._evaluate_equal_ignore_case
+ elif (not_condition := anything_but.get("wildcard")) is not None:
+ predicate = self._evaluate_wildcard
+ else:
+ # this should not happen as we validate the EventPattern before
+ return False
+
+ if isinstance(not_condition, str):
+ return not predicate(not_condition, value)
+ elif isinstance(not_condition, list):
+ return all(
+ not predicate(sub_condition, value) for sub_condition in not_condition
+ )
+
+ elif isinstance(anything_but, list):
+ return value not in anything_but
+ else:
+ return value != anything_but
+
+ elif value is None:
+ # the remaining conditions require the value to not be None
+ return False
+ elif (prefix := condition.get("prefix")) is not None:
+ if isinstance(prefix, dict):
+ if (prefix_equal_ignore_case := prefix.get("equals-ignore-case")) is not None:
+ return self._evaluate_prefix(prefix_equal_ignore_case.lower(), value.lower())
+ else:
+ return self._evaluate_prefix(prefix, value)
+
+ elif (suffix := condition.get("suffix")) is not None:
+ if isinstance(suffix, dict):
+ if suffix_equal_ignore_case := suffix.get("equals-ignore-case"):
+ return self._evaluate_suffix(suffix_equal_ignore_case.lower(), value.lower())
+ else:
+ return self._evaluate_suffix(suffix, value)
+
+ elif (equal_ignore_case := condition.get("equals-ignore-case")) is not None:
+ return self._evaluate_equal_ignore_case(equal_ignore_case, value)
+
+ # we validated that `numeric` should be a non-empty list when creating the rule, so we don't need the None check
+ elif numeric_condition := condition.get("numeric"):
+ return self._evaluate_numeric_condition(numeric_condition, value)
+
+ # we also validated the `cidr` that it cannot be empty
+ elif cidr := condition.get("cidr"):
+ return self._evaluate_cidr(cidr, value)
+
+ elif (wildcard := condition.get("wildcard")) is not None:
+ return self._evaluate_wildcard(wildcard, value)
+
+ return False
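+ # Condition shapes handled above (values illustrative): a plain value ("running"),
+ # {"exists": true}, {"anything-but": ["stopped"]}, {"prefix": "run"}, {"suffix": "ing"},
+ # {"equals-ignore-case": "RUNNING"}, {"numeric": [">", 0, "<=", 5]},
+ # {"cidr": "10.0.0.0/24"}, and {"wildcard": "run*"}.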
+
+ @staticmethod
+ def _evaluate_prefix(condition: str | list, value: str) -> bool:
+ return value.startswith(condition)
+
+ @staticmethod
+ def _evaluate_suffix(condition: str | list, value: str) -> bool:
+ return value.endswith(condition)
+
+ @staticmethod
+ def _evaluate_equal_ignore_case(condition: str, value: str) -> bool:
+ return condition.lower() == value.lower()
+
+ @staticmethod
+ def _evaluate_cidr(condition: str, value: str) -> bool:
+ try:
+ ip = ipaddress.ip_address(value)
+ return ip in ipaddress.ip_network(condition)
+ except ValueError:
+ return False
+
+ @staticmethod
+ def _evaluate_wildcard(condition: str, value: str) -> bool:
+ return bool(re.match(re.escape(condition).replace("\\*", ".+") + "$", value))
+
+ @staticmethod
+ def _evaluate_numeric_condition(conditions: list, value: t.Any) -> bool:
+ if not isinstance(value, (int, float)):
+ return False
+ try:
+ # try if the value is numeric
+ value = float(value)
+ except ValueError:
+ # the value is not numeric, the condition is False
+ return False
+
+ for i in range(0, len(conditions), 2):
+ operator = conditions[i]
+ operand = float(conditions[i + 1])
+
+ if operator == "=":
+ if value != operand:
+ return False
+ elif operator == ">":
+ if value <= operand:
+ return False
+ elif operator == "<":
+ if value >= operand:
+ return False
+ elif operator == ">=":
+ if value < operand:
+ return False
+ elif operator == "<=":
+ if value > operand:
+ return False
+
+ return True
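+ # Example (assumed): conditions [">", 0, "<=", 5] with value 3 checks 3 > 0 and 3 <= 5 and
+ # returns True; with value 7 the "<=" check fails and the method returns False.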
+
+ @staticmethod
+ def flatten_pattern(nested_dict: dict) -> list[dict]:
+ """
+ Takes a dictionary as input and will output the dictionary on a single level.
+ Input:
+ `{"field1": {"field2": {"field3": "val1", "field4": "val2"}}}`
+ Output:
+ `[
+ {
+ "field1.field2.field3": "val1",
+ "field1.field2.field4": "val2"
+ }
+ ]`
+ Input with $or will create multiple outputs:
+ `{"$or": [{"field1": "val1"}, {"field2": "val2"}], "field3": "val3"}`
+ Output:
+ `[
+ {"field1": "val1", "field3": "val3"},
+ {"field2": "val2", "field3": "val3"}
+ ]`
+ :param nested_dict: a (nested) dictionary
+ :return: a list of flattened dictionaries with no nested dict or list inside, flattened to a
+ single level, one list item for every list item encountered
+ """
+
+ def _traverse_event_pattern(obj, array=None, parent_key=None) -> list:
+ if array is None:
+ array = [{}]
+
+ for key, values in obj.items():
+ if key == "$or" and isinstance(values, list) and len(values) > 1:
+ # $or will create multiple new branches in the array.
+ # Each current branch will traverse with each choice in $or
+ array = [
+ i
+ for value in values
+ for i in _traverse_event_pattern(value, array, parent_key)
+ ]
+ else:
+ # We update the parent key so that {"key1": {"key2": ""}} becomes "key1.key2"
+ _parent_key = f"{parent_key}.{key}" if parent_key else key
+ if isinstance(values, dict):
+ # If the current key has a child dict -- key: "key1", child: {"key2": ["val1", "val2"]}
+ # We only update the parent_key and traverse its children with the current branches
+ array = _traverse_event_pattern(values, array, _parent_key)
+ else:
+ # If the current key has no child, this means we found the values to match -- child: ["val1", "val2"]
+ # we update the branches with the parent chain and the values -- {"key1.key2": ["val1", "val2"]}
+ array = [{**item, _parent_key: values} for item in array]
+
+ return array
+
+ return _traverse_event_pattern(nested_dict)
+
+ @staticmethod
+ def flatten_payload(nested_dict: dict) -> list[dict]:
+ """
+ Takes a dictionary as input and will output the dictionary on a single level.
+ The dictionary can have lists containing other dictionaries, and one root level entry will be created for every
+ item in a list.
+ Input:
+ `{"field1": {
+ "field2: [
+ {"field3: "val1", "field4": "val2"},
+ {"field3: "val3", "field4": "val4"},
+ }
+ ]}`
+ Output:
+ `[
+ {
+ "field1.field2.field3": "val1",
+ "field1.field2.field4": "val2"
+ },
+ {
+ "field1.field2.field3": "val3",
+ "field1.field2.field4": "val4"
+ },
+ ]`
+ :param nested_dict: a (nested) dictionary
+ :return: a list of flattened dictionaries with no nested dict or list inside, one list item for every list item encountered
+ """
+
+ def _traverse(_object: dict, array=None, parent_key=None) -> list:
+ if isinstance(_object, dict):
+ for key, values in _object.items():
+ # We update the parent key so that {"key1": {"key2": ""}} becomes "key1.key2"
+ _parent_key = f"{parent_key}.{key}" if parent_key else key
+ array = _traverse(values, array, _parent_key)
+
+ elif isinstance(_object, list):
+ if not _object:
+ return array
+ array = [i for value in _object for i in _traverse(value, array, parent_key)]
+ else:
+ array = [{**item, parent_key: _object} for item in array]
+ return array
+
+ return _traverse(nested_dict, array=[{}], parent_key=None)
+
+
+class EventPatternCompiler:
+ def __init__(self):
+ self.error_prefix = "Event pattern is not valid. Reason: "
+
+ def compile_event_pattern(self, event_pattern: str | dict) -> dict[str, t.Any]:
+ if isinstance(event_pattern, str):
+ try:
+ event_pattern = json.loads(event_pattern)
+ if not isinstance(event_pattern, dict):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Filter is not an object"
+ )
+ except json.JSONDecodeError:
+ # this error message is not in parity, as it is tightly coupled to the AWS parsing engine
+ raise InvalidEventPatternException(f"{self.error_prefix}Filter is not valid JSON")
+
+ aggregated_rules, combinations = self.aggregate_rules(event_pattern)
+
+ for rules in aggregated_rules:
+ for rule in rules:
+ self._validate_rule(rule)
+
+ return event_pattern
+
+ def aggregate_rules(self, event_pattern: dict[str, t.Any]) -> tuple[list[list[t.Any]], int]:
+ """
+ This method evaluates the event pattern recursively and returns a list of lists of rules.
+ It also calculates the number of rule combinations, which depends on the nesting of the rules.
+ Example:
+ nested_event_pattern = {
+ "key_a": {
+ "key_b": {
+ "key_c": ["value_one", "value_two", "value_three", "value_four"]
+ }
+ },
+ "key_d": {
+ "key_e": ["value_one", "value_two", "value_three"]
+ }
+ }
+ This function then iterates over the values of the top-level keys of the event pattern: ("key_a", "key_d")
+ If the iterated value is not a list, it means it is a nested property: we call this method on the value,
+ adding a level to the depth to keep track of how deep the key is.
+ If the value is a list, it means it contains rules: we append this list of rules to _rules, and
+ calculate the combinations it adds.
+ For the example event pattern containing nested properties, we calculate it this way:
+ the first array has four values in a three-level nested key, and the second has three values in a two-level
+ nested key: 3 x 4 x 2 x 3 = 72
+ The return value would be:
+ [["value_one", "value_two", "value_three", "value_four"], ["value_one", "value_two", "value_three"]]
+ This allows us to later iterate over the list of rules easily, to verify only their conditions.
+
+ :param event_pattern: a dict, starting at the Event Pattern
+ :return: a tuple with a list of lists of rules and the calculated number of combinations
+ """
+
+ def _inner(
+ pattern_elements: dict[str, t.Any], depth: int = 1, combinations: int = 1
+ ) -> tuple[list[list[t.Any]], int]:
+ _rules = []
+ for key, _value in pattern_elements.items():
+ if isinstance(_value, dict):
+ # From AWS docs: "unlike attribute-based policies, payload-based policies support property nesting."
+ sub_rules, combinations = _inner(
+ _value, depth=depth + 1, combinations=combinations
+ )
+ _rules.extend(sub_rules)
+ elif isinstance(_value, list):
+ if not _value:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Empty arrays are not allowed"
+ )
+
+ current_combination = 0
+ if key == "$or":
+ for val in _value:
+ sub_rules, or_combinations = _inner(
+ val, depth=depth, combinations=combinations
+ )
+ _rules.extend(sub_rules)
+ current_combination += or_combinations
+
+ combinations = current_combination
+ else:
+ _rules.append(_value)
+ combinations = combinations * len(_value) * depth
+ else:
+ raise InvalidEventPatternException(
+ f'{self.error_prefix}"{key}" must be an object or an array'
+ )
+
+ return _rules, combinations
+
+ return _inner(event_pattern)
+
+ def _validate_rule(self, rule: t.Any, from_: str | None = None) -> None:
+ match rule:
+ case None | str() | bool():
+ return
+
+ case int() | float():
+ # TODO: AWS says they only support values from -10^9 to 10^9, but seems to accept larger ones, so we just return
+ # if rule <= -1000000000 or rule >= 1000000000:
+ # raise ""
+ return
+
+ case {**kwargs}:
+ if len(kwargs) != 1:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Only one key allowed in match expression"
+ )
+
+ operator, value = None, None
+ for k, v in kwargs.items():
+ operator, value = k, v
+
+ if operator in (
+ "prefix",
+ "suffix",
+ ):
+ if from_ == "anything-but":
+ if isinstance(value, dict):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Value of {from_} must be an array or single string/number value."
+ )
+
+ if not self._is_str_or_list_of_str(value):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}prefix/suffix match pattern must be a string"
+ )
+ elif not value:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Null prefix/suffix not allowed"
+ )
+
+ elif isinstance(value, dict):
+ for inner_operator in value.keys():
+ if inner_operator != "equals-ignore-case":
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Unsupported anything-but pattern: {inner_operator}"
+ )
+
+ elif not isinstance(value, str):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}{operator} match pattern must be a string"
+ )
+ return
+
+ elif operator == "equals-ignore-case":
+ if from_ == "anything-but":
+ if not self._is_str_or_list_of_str(value):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Inside {from_}/{operator} list, number|start|null|boolean is not supported."
+ )
+ elif not isinstance(value, str):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}{operator} match pattern must be a string"
+ )
+ return
+
+ elif operator == "anything-but":
+ # anything-but can actually contain any kind of simple rule (str, number, and list)
+ if isinstance(value, list):
+ for v in value:
+ self._validate_rule(v)
+
+ return
+
+ # or have a nested `prefix`, `suffix` or `equals-ignore-case` pattern
+ elif isinstance(value, dict):
+ for inner_operator in value.keys():
+ if inner_operator not in (
+ "prefix",
+ "equals-ignore-case",
+ "suffix",
+ "wildcard",
+ ):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Unsupported anything-but pattern: {inner_operator}"
+ )
+
+ self._validate_rule(value, from_="anything-but")
+ return
+
+ elif operator == "exists":
+ if not isinstance(value, bool):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}exists match pattern must be either true or false."
+ )
+ return
+
+ elif operator == "numeric":
+ self._validate_numeric_condition(value)
+
+ elif operator == "cidr":
+ self._validate_cidr_condition(value)
+
+ elif operator == "wildcard":
+ if from_ == "anything-but" and isinstance(value, list):
+ for v in value:
+ self._validate_wildcard(v)
+ else:
+ self._validate_wildcard(value)
+
+ else:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Unrecognized match type {operator}"
+ )
+
+ case _:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Match value must be String, number, true, false, or null"
+ )
+
+ def _validate_numeric_condition(self, value):
+ if not isinstance(value, list):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Value of numeric must be an array."
+ )
+ if not value:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Invalid member in numeric match: ]"
+ )
+ num_values = value[::-1]
+
+ operator = num_values.pop()
+ if not isinstance(operator, str):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Invalid member in numeric match: {operator}"
+ )
+ elif operator not in ("<", "<=", "=", ">", ">="):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Unrecognized numeric range operator: {operator}"
+ )
+
+ value = num_values.pop() if num_values else None
+ if not isinstance(value, (int, float)):
+ exc_operator = "equals" if operator == "=" else operator
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Value of {exc_operator} must be numeric"
+ )
+
+ if not num_values:
+ return
+
+ if operator not in (">", ">="):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Too many elements in numeric expression"
+ )
+
+ second_operator = num_values.pop()
+ if not isinstance(second_operator, str):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Bad value in numeric range: {second_operator}"
+ )
+ elif second_operator not in ("<", "<="):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Bad numeric range operator: {second_operator}"
+ )
+
+ second_value = num_values.pop() if num_values else None
+ if not isinstance(second_value, (int, float)):
+ exc_operator = "equals" if second_operator == "=" else second_operator
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Value of {exc_operator} must be numeric"
+ )
+
+ elif second_value <= value:
+ raise InvalidEventPatternException(f"{self.error_prefix}Bottom must be less than top")
+
+ elif num_values:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Too many terms in numeric range expression"
+ )
+
+ def _validate_wildcard(self, value: t.Any):
+ if not isinstance(value, str):
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}wildcard match pattern must be a string"
+ )
+ # TODO: properly calculate complexity of wildcard
+ # https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-create-pattern-operators.html#eb-filtering-wildcard-matching-complexity
+ # > calculate complexity of repeating character sequences that occur after a wildcard character
+ if "**" in value:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Consecutive wildcard characters at pos {value.index('**') + 1}"
+ )
+
+ if value.count("*") > 5:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Rule is too complex - try using fewer wildcard characters or fewer repeating character sequences after a wildcard character"
+ )
+
+ def _validate_cidr_condition(self, value):
+ if not isinstance(value, str):
+ # `cidr` returns the prefix error
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}prefix match pattern must be a string"
+ )
+ ip_and_mask = value.split("/")
+ if len(ip_and_mask) != 2:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Malformed CIDR, one '/' required"
+ )
+ ip_addr, mask = value.split("/")
+ try:
+ int(mask)
+ except ValueError:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Malformed CIDR, mask bits must be an integer"
+ )
+ try:
+ ipaddress.ip_network(value)
+ except ValueError:
+ raise InvalidEventPatternException(
+ f"{self.error_prefix}Nonstandard IP address: {ip_addr}"
+ )
+
+ @staticmethod
+ def _is_str_or_list_of_str(value: t.Any) -> bool:
+ if not isinstance(value, (str, list)):
+ return False
+ if isinstance(value, list) and not all(isinstance(v, str) for v in value):
+ return False
+
+ return True
diff --git a/localstack-core/localstack/services/events/models.py b/localstack-core/localstack/services/events/models.py
new file mode 100644
index 0000000000000..95e64ece83711
--- /dev/null
+++ b/localstack-core/localstack/services/events/models.py
@@ -0,0 +1,340 @@
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Literal, Optional, TypeAlias, TypedDict
+
+from localstack.aws.api.core import ServiceException
+from localstack.aws.api.events import (
+ ApiDestinationDescription,
+ ApiDestinationHttpMethod,
+ ApiDestinationInvocationRateLimitPerSecond,
+ ApiDestinationName,
+ ApiDestinationState,
+ ArchiveDescription,
+ ArchiveName,
+ ArchiveState,
+ Arn,
+ ConnectionArn,
+ ConnectionAuthorizationType,
+ ConnectionDescription,
+ ConnectionName,
+ ConnectionState,
+ ConnectivityResourceParameters,
+ CreateConnectionAuthRequestParameters,
+ CreatedBy,
+ EventBusName,
+ EventPattern,
+ EventResourceList,
+ EventSourceName,
+ EventTime,
+ HttpsEndpoint,
+ ManagedBy,
+ ReplayDescription,
+ ReplayDestination,
+ ReplayName,
+ ReplayState,
+ ReplayStateReason,
+ RetentionDays,
+ RoleArn,
+ RuleDescription,
+ RuleName,
+ RuleState,
+ ScheduleExpression,
+ TagList,
+ Target,
+ TargetId,
+ Timestamp,
+)
+from localstack.services.stores import (
+ AccountRegionBundle,
+ BaseStore,
+ CrossRegionAttribute,
+ LocalAttribute,
+)
+from localstack.utils.aws.arns import (
+ event_bus_arn,
+ events_api_destination_arn,
+ events_archive_arn,
+ events_connection_arn,
+ events_replay_arn,
+ events_rule_arn,
+)
+from localstack.utils.strings import short_uid
+from localstack.utils.tagging import TaggingService
+
+TargetDict = dict[TargetId, Target]
+
+
+class ValidationException(ServiceException):
+ code: str = "ValidationException"
+ sender_fault: bool = True
+ status_code: int = 400
+
+
+class InvalidEventPatternException(Exception):
+ reason: str
+
+ def __init__(self, reason=None, message=None) -> None:
+ self.reason = reason
+ self.message = message or f"Event pattern is not valid. Reason: {reason}"
+
+
+FormattedEvent = TypedDict(  # functional syntax required because some keys contain hyphens (e.g. "detail-type")
+ "FormattedEvent",
+ {
+ "version": str,
+ "id": str,
+ "detail-type": Optional[str],
+ "source": Optional[EventSourceName],
+ "account": str,
+ "time": EventTime,
+ "region": str,
+ "resources": Optional[EventResourceList],
+ "detail": dict[str, str | dict],
+ "replay-name": Optional[ReplayName],
+ "event-bus-name": EventBusName,
+ },
+)
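+# Illustrative shape only (values assumed): a FormattedEvent roughly looks like
+# {"version": "0", "id": "<uuid>", "detail-type": "myDetailType", "source": "my.app",
+#  "account": "000000000000", "time": "2024-01-01T00:00:00Z", "region": "us-east-1",
+#  "resources": [], "detail": {"key": "value"}, "replay-name": None, "event-bus-name": "default"}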
+
+
+FormattedEventDict = dict[str, FormattedEvent]
+FormattedEventList = list[FormattedEvent]
+
+TransformedEvent: TypeAlias = FormattedEvent | dict | str
+
+
+class ResourceType(Enum):
+ EVENT_BUS = "event_bus"
+ RULE = "rule"
+
+
+class Condition(TypedDict):
+ Type: Literal["StringEquals"]
+ Key: Literal["aws:PrincipalOrgID"]
+ Value: str
+
+
+class Statement(TypedDict):
+ Sid: str
+ Effect: str
+ Principal: str | dict[str, str]
+ Action: str
+ Resource: str
+ Condition: Condition
+
+
+class ResourcePolicy(TypedDict):
+ Version: str
+ Statement: list[Statement]
+
+
+@dataclass
+class Rule:
+ name: RuleName
+ region: str
+ account_id: str
+ schedule_expression: Optional[ScheduleExpression] = None
+ event_pattern: Optional[EventPattern] = None
+ state: Optional[RuleState] = None
+ description: Optional[RuleDescription] = None
+ role_arn: Optional[RoleArn] = None
+ tags: TagList = field(default_factory=list)
+ event_bus_name: EventBusName = "default"
+ targets: TargetDict = field(default_factory=dict)
+ managed_by: Optional[ManagedBy] = None # can only be set by AWS services
+ created_by: CreatedBy = field(init=False)
+
+ def __post_init__(self):
+ self.created_by = self.account_id
+ if self.tags is None:
+ self.tags = []
+ if self.targets is None:
+ self.targets = {}
+ if self.state is None:
+ self.state = RuleState.ENABLED
+
+ @property
+ def arn(self) -> Arn:
+ return events_rule_arn(self.name, self.account_id, self.region, self.event_bus_name)
+
+
+RuleDict = dict[RuleName, Rule]
+
+
+@dataclass
+class Replay:
+ name: str
+ region: str
+ account_id: str
+ event_source_arn: Arn
+ destination: ReplayDestination # Event Bus Arn or Rule Arns
+ event_start_time: Timestamp
+ event_end_time: Timestamp
+ description: Optional[ReplayDescription] = None
+ state: Optional[ReplayState] = None
+ state_reason: Optional[ReplayStateReason] = None
+ event_last_replayed_time: Optional[Timestamp] = None
+ replay_start_time: Optional[Timestamp] = None
+ replay_end_time: Optional[Timestamp] = None
+
+ @property
+ def arn(self) -> Arn:
+ return events_replay_arn(self.name, self.account_id, self.region)
+
+
+ReplayDict = dict[ReplayName, Replay]
+
+
+@dataclass
+class Archive:
+ name: ArchiveName
+ region: str
+ account_id: str
+ event_source_arn: Arn
+ description: ArchiveDescription = None
+ event_pattern: EventPattern = None
+ retention_days: RetentionDays = None
+ state: ArchiveState = ArchiveState.DISABLED
+ creation_time: Timestamp = None
+ size_bytes: int = 0 # TODO how to deal with updating this value?
+ events: FormattedEventDict = field(default_factory=dict)
+
+ @property
+ def arn(self) -> Arn:
+ return events_archive_arn(self.name, self.account_id, self.region)
+
+ @property
+ def event_count(self) -> int:
+ return len(self.events)
+
+
+ArchiveDict = dict[ArchiveName, Archive]
+
+
+@dataclass
+class EventBus:
+ name: EventBusName
+ region: str
+ account_id: str
+ event_source_name: Optional[str] = None
+ description: Optional[str] = None
+ tags: TagList = field(default_factory=list)
+ policy: Optional[ResourcePolicy] = None
+ rules: RuleDict = field(default_factory=dict)
+ creation_time: Timestamp = field(init=False)
+ last_modified_time: Timestamp = field(init=False)
+
+ def __post_init__(self):
+ self.creation_time = datetime.now(timezone.utc)
+ self.last_modified_time = datetime.now(timezone.utc)
+ if self.rules is None:
+ self.rules = {}
+ if self.tags is None:
+ self.tags = []
+
+ @property
+ def arn(self) -> Arn:
+ return event_bus_arn(self.name, self.account_id, self.region)
+
+
+EventBusDict = dict[EventBusName, EventBus]
+
+
+@dataclass
+class Connection:
+ name: ConnectionName
+ region: str
+ account_id: str
+ authorization_type: ConnectionAuthorizationType
+ auth_parameters: CreateConnectionAuthRequestParameters
+ state: ConnectionState
+ secret_arn: Arn
+ description: ConnectionDescription | None = None
+ invocation_connectivity_parameters: ConnectivityResourceParameters | None = None
+ creation_time: Timestamp = field(init=False)
+ last_modified_time: Timestamp = field(init=False)
+ last_authorized_time: Timestamp = field(init=False)
+ tags: TagList = field(default_factory=list)
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))  # generate a fresh id per instance
+
+ def __post_init__(self):
+ timestamp_now = datetime.now(timezone.utc)
+ self.creation_time = timestamp_now
+ self.last_modified_time = timestamp_now
+ self.last_authorized_time = timestamp_now
+ if self.tags is None:
+ self.tags = []
+
+ @property
+ def arn(self) -> Arn:
+ return events_connection_arn(self.name, self.id, self.account_id, self.region)
+
+
+ConnectionDict = dict[ConnectionName, Connection]
+
+
+@dataclass
+class ApiDestination:
+ name: ApiDestinationName
+ region: str
+ account_id: str
+ connection_arn: ConnectionArn
+ invocation_endpoint: HttpsEndpoint
+ http_method: ApiDestinationHttpMethod
+ state: ApiDestinationState
+ _invocation_rate_limit_per_second: ApiDestinationInvocationRateLimitPerSecond | None = None
+ description: ApiDestinationDescription | None = None
+ creation_time: Timestamp = field(init=False)
+ last_modified_time: Timestamp = field(init=False)
+ last_authorized_time: Timestamp = field(init=False)
+ tags: TagList = field(default_factory=list)
+ id: str = field(default_factory=short_uid)  # generate a fresh id per instance
+
+ def __post_init__(self):
+ timestamp_now = datetime.now(timezone.utc)
+ self.creation_time = timestamp_now
+ self.last_modified_time = timestamp_now
+ self.last_authorized_time = timestamp_now
+ if self.tags is None:
+ self.tags = []
+
+ @property
+ def arn(self) -> Arn:
+ return events_api_destination_arn(self.name, self.id, self.account_id, self.region)
+
+ @property
+ def invocation_rate_limit_per_second(self) -> int:
+ return self._invocation_rate_limit_per_second or 300 # Default value
+
+ @invocation_rate_limit_per_second.setter
+ def invocation_rate_limit_per_second(
+ self, value: ApiDestinationInvocationRateLimitPerSecond | None
+ ):
+ self._invocation_rate_limit_per_second = value
+
+
+ApiDestinationDict = dict[ApiDestinationName, ApiDestination]
+
+
+class EventsStore(BaseStore):
+ # Map of eventbus names to eventbus objects. The name MUST be unique per account and region (works with AccountRegionBundle)
+ event_buses: EventBusDict = LocalAttribute(default=dict)
+
+ # Map of archive names to archive objects. The name MUST be unique per account and region (works with AccountRegionBundle)
+ archives: ArchiveDict = LocalAttribute(default=dict)
+
+ # Map of replay names to replay objects. The name MUST be unique per account and region (works with AccountRegionBundle)
+ replays: ReplayDict = LocalAttribute(default=dict)
+
+ # Map of connection names to connection objects.
+ connections: ConnectionDict = LocalAttribute(default=dict)
+
+ # Map of api destination names to api destination objects
+ api_destinations: ApiDestinationDict = LocalAttribute(default=dict)
+
+ # Maps resource ARN to tags
+ TAGS: TaggingService = CrossRegionAttribute(default=TaggingService)
+
+
+events_stores = AccountRegionBundle("events", EventsStore)
diff --git a/localstack-core/localstack/services/events/provider.py b/localstack-core/localstack/services/events/provider.py
new file mode 100644
index 0000000000000..a0a79e24cc1a0
--- /dev/null
+++ b/localstack-core/localstack/services/events/provider.py
@@ -0,0 +1,1912 @@
+import base64
+import json
+import logging
+import re
+from typing import Callable, Optional
+
+from localstack.aws.api import RequestContext, handler
+from localstack.aws.api.config import TagsList
+from localstack.aws.api.events import (
+ Action,
+ ApiDestinationDescription,
+ ApiDestinationHttpMethod,
+ ApiDestinationInvocationRateLimitPerSecond,
+ ApiDestinationName,
+ ApiDestinationResponseList,
+ ArchiveDescription,
+ ArchiveName,
+ ArchiveResponseList,
+ ArchiveState,
+ Arn,
+ Boolean,
+ CancelReplayResponse,
+ Condition,
+ ConnectionArn,
+ ConnectionAuthorizationType,
+ ConnectionDescription,
+ ConnectionName,
+ ConnectionResponseList,
+ ConnectionState,
+ ConnectivityResourceParameters,
+ CreateApiDestinationResponse,
+ CreateArchiveResponse,
+ CreateConnectionAuthRequestParameters,
+ CreateConnectionResponse,
+ CreateEventBusResponse,
+ DeadLetterConfig,
+ DeleteApiDestinationResponse,
+ DeleteArchiveResponse,
+ DeleteConnectionResponse,
+ DescribeApiDestinationResponse,
+ DescribeArchiveResponse,
+ DescribeConnectionResponse,
+ DescribeEventBusResponse,
+ DescribeReplayResponse,
+ DescribeRuleResponse,
+ EndpointId,
+ EventBusDescription,
+ EventBusList,
+ EventBusName,
+ EventBusNameOrArn,
+ EventPattern,
+ EventsApi,
+ EventSourceName,
+ HttpsEndpoint,
+ InternalException,
+ KmsKeyIdentifier,
+ LimitMax100,
+ ListApiDestinationsResponse,
+ ListArchivesResponse,
+ ListConnectionsResponse,
+ ListEventBusesResponse,
+ ListReplaysResponse,
+ ListRuleNamesByTargetResponse,
+ ListRulesResponse,
+ ListTagsForResourceResponse,
+ ListTargetsByRuleResponse,
+ NextToken,
+ NonPartnerEventBusName,
+ Principal,
+ PutEventsRequestEntry,
+ PutEventsRequestEntryList,
+ PutEventsResponse,
+ PutEventsResultEntry,
+ PutEventsResultEntryList,
+ PutPartnerEventsRequestEntryList,
+ PutPartnerEventsResponse,
+ PutRuleResponse,
+ PutTargetsResponse,
+ RemoveTargetsResponse,
+ ReplayDescription,
+ ReplayDestination,
+ ReplayList,
+ ReplayName,
+ ReplayState,
+ ResourceAlreadyExistsException,
+ ResourceNotFoundException,
+ RetentionDays,
+ RoleArn,
+ RuleDescription,
+ RuleName,
+ RuleResponseList,
+ RuleState,
+ ScheduleExpression,
+ StartReplayResponse,
+ StatementId,
+ String,
+ TagKeyList,
+ TagList,
+ TagResourceResponse,
+ Target,
+ TargetArn,
+ TargetId,
+ TargetIdList,
+ TargetList,
+ TestEventPatternResponse,
+ Timestamp,
+ UntagResourceResponse,
+ UpdateApiDestinationResponse,
+ UpdateArchiveResponse,
+ UpdateConnectionAuthRequestParameters,
+ UpdateConnectionResponse,
+)
+from localstack.aws.api.events import ApiDestination as ApiTypeApiDestination
+from localstack.aws.api.events import Archive as ApiTypeArchive
+from localstack.aws.api.events import Connection as ApiTypeConnection
+from localstack.aws.api.events import EventBus as ApiTypeEventBus
+from localstack.aws.api.events import Replay as ApiTypeReplay
+from localstack.aws.api.events import Rule as ApiTypeRule
+from localstack.services.events.api_destination import (
+ APIDestinationService,
+ ApiDestinationServiceDict,
+)
+from localstack.services.events.archive import ArchiveService, ArchiveServiceDict
+from localstack.services.events.connection import (
+ ConnectionService,
+ ConnectionServiceDict,
+)
+from localstack.services.events.event_bus import EventBusService, EventBusServiceDict
+from localstack.services.events.models import (
+ ApiDestination,
+ ApiDestinationDict,
+ Archive,
+ ArchiveDict,
+ Connection,
+ ConnectionDict,
+ EventBus,
+ EventBusDict,
+ EventsStore,
+ FormattedEvent,
+ Replay,
+ ReplayDict,
+ ResourceType,
+ Rule,
+ RuleDict,
+ TargetDict,
+ ValidationException,
+ events_stores,
+)
+from localstack.services.events.replay import ReplayService, ReplayServiceDict
+from localstack.services.events.rule import RuleService, RuleServiceDict
+from localstack.services.events.scheduler import JobScheduler
+from localstack.services.events.target import (
+ TargetSender,
+ TargetSenderDict,
+ TargetSenderFactory,
+)
+from localstack.services.events.usage import rule_error, rule_invocation
+from localstack.services.events.utils import (
+ TARGET_ID_PATTERN,
+ extract_connection_name,
+ extract_event_bus_name,
+ extract_region_and_account_id,
+ format_event,
+ get_resource_type,
+ get_trace_header_encoded_region_account,
+ is_archive_arn,
+ recursive_remove_none_values_from_dict,
+)
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.common import truncate
+from localstack.utils.event_matcher import matches_event
+from localstack.utils.strings import long_uid
+from localstack.utils.time import TIMESTAMP_FORMAT_TZ, timestamp
+
+LOG = logging.getLogger(__name__)
+
+ARCHIVE_TARGET_ID_NAME_PATTERN = re.compile(r"^Events-Archive-(?P<name>[a-zA-Z0-9_-]+)$")
+
+
+def decode_next_token(token: NextToken) -> int:
+ """Decode a pagination token from base64 to integer."""
+ return int.from_bytes(base64.b64decode(token), "big")
+
+
+def encode_next_token(token: int) -> NextToken:
+ """Encode a pagination token to base64 from integer."""
+ return base64.b64encode(token.to_bytes(128, "big")).decode("utf-8")
+
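+# Round-trip sketch (illustrative): encode_next_token(3) base64-encodes the integer as a
+# 128-byte big-endian value, and decode_next_token(encode_next_token(3)) == 3.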
+
+def get_filtered_dict(name_prefix: str, input_dict: dict) -> dict:
+ """Filter dictionary by prefix."""
+ return {name: value for name, value in input_dict.items() if name.startswith(name_prefix)}
+
+
+def validate_event(event: PutEventsRequestEntry) -> None | PutEventsResultEntry:
+ if not event.get("Source"):
+ return {
+ "ErrorCode": "InvalidArgument",
+ "ErrorMessage": "Parameter Source is not valid. Reason: Source is a required argument.",
+ }
+ elif not event.get("DetailType"):
+ return {
+ "ErrorCode": "InvalidArgument",
+ "ErrorMessage": "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
+ }
+ elif not event.get("Detail"):
+ return {
+ "ErrorCode": "InvalidArgument",
+ "ErrorMessage": "Parameter Detail is not valid. Reason: Detail is a required argument.",
+ }
+ elif event.get("Detail") and len(event["Detail"]) >= 262144:
+ raise ValidationException("Total size of the entries in the request is over the limit.")
+ elif event.get("Detail"):
+ try:
+ json_detail = json.loads(event.get("Detail"))
+ if isinstance(json_detail, dict):
+ return
+ except json.JSONDecodeError:
+ pass
+
+ return {
+ "ErrorCode": "MalformedDetail",
+ "ErrorMessage": "Detail is malformed.",
+ }
+
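+# Example (illustrative): an entry missing "Source" yields {"ErrorCode": "InvalidArgument",
+# "ErrorMessage": "Parameter Source is not valid. ..."}; an entry whose Detail is not a JSON
+# object yields the MalformedDetail error above.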
+
+def check_unique_tags(tags: TagsList) -> None:
+ unique_tag_keys = {tag["Key"] for tag in tags}
+ if len(unique_tag_keys) < len(tags):
+ raise ValidationException("Invalid parameter: Duplicated keys are not allowed.")
+
+
+class EventsProvider(EventsApi, ServiceLifecycleHook):
+ # api methods are grouped by resource type and sorted in alphabetical order
+ # functions in each group are sorted alphabetically
+ def __init__(self):
+ self._event_bus_services_store: EventBusServiceDict = {}
+ self._rule_services_store: RuleServiceDict = {}
+ self._target_sender_store: TargetSenderDict = {}
+ self._archive_service_store: ArchiveServiceDict = {}
+ self._replay_service_store: ReplayServiceDict = {}
+ self._connection_service_store: ConnectionServiceDict = {}
+ self._api_destination_service_store: ApiDestinationServiceDict = {}
+
+ def on_before_start(self):
+ JobScheduler.start()
+
+ def on_before_stop(self):
+ JobScheduler.shutdown()
+
+ ##################
+ # API Destinations
+ ##################
+ @handler("CreateApiDestination")
+ def create_api_destination(
+ self,
+ context: RequestContext,
+ name: ApiDestinationName,
+ connection_arn: ConnectionArn,
+ invocation_endpoint: HttpsEndpoint,
+ http_method: ApiDestinationHttpMethod,
+ description: ApiDestinationDescription = None,
+ invocation_rate_limit_per_second: ApiDestinationInvocationRateLimitPerSecond = None,
+ **kwargs,
+ ) -> CreateApiDestinationResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if name in store.api_destinations:
+ raise ResourceAlreadyExistsException(f"An api-destination '{name}' already exists.")
+ APIDestinationService.validate_input(name, connection_arn, http_method, invocation_endpoint)
+ connection_name = extract_connection_name(connection_arn)
+ connection = self.get_connection(connection_name, store)
+ api_destination_service = self.create_api_destinations_service(
+ name,
+ region,
+ account_id,
+ connection_arn,
+ connection,
+ invocation_endpoint,
+ http_method,
+ invocation_rate_limit_per_second,
+ description,
+ )
+ store.api_destinations[api_destination_service.api_destination.name] = (
+ api_destination_service.api_destination
+ )
+
+ response = CreateApiDestinationResponse(
+ ApiDestinationArn=api_destination_service.arn,
+ ApiDestinationState=api_destination_service.state,
+ CreationTime=api_destination_service.creation_time,
+ LastModifiedTime=api_destination_service.last_modified_time,
+ )
+ return response
+
+ @handler("DescribeApiDestination")
+ def describe_api_destination(
+ self, context: RequestContext, name: ApiDestinationName, **kwargs
+ ) -> DescribeApiDestinationResponse:
+ store = self.get_store(context.region, context.account_id)
+ api_destination = self.get_api_destination(name, store)
+
+ response = self._api_destination_to_api_type_api_destination(api_destination)
+ return response
+
+ @handler("DeleteApiDestination")
+ def delete_api_destination(
+ self, context: RequestContext, name: ApiDestinationName, **kwargs
+ ) -> DeleteApiDestinationResponse:
+ store = self.get_store(context.region, context.account_id)
+ if api_destination := self.get_api_destination(name, store):
+ del self._api_destination_service_store[api_destination.arn]
+ del store.api_destinations[name]
+ del store.TAGS[api_destination.arn]
+
+ return DeleteApiDestinationResponse()
+
+ @handler("ListApiDestinations")
+ def list_api_destinations(
+ self,
+ context: RequestContext,
+ name_prefix: ApiDestinationName = None,
+ connection_arn: ConnectionArn = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListApiDestinationsResponse:
+ store = self.get_store(context.region, context.account_id)
+ api_destinations = (
+ get_filtered_dict(name_prefix, store.api_destinations)
+ if name_prefix
+ else store.api_destinations
+ )
+ limited_rules, next_token = self._get_limited_dict_and_next_token(
+ api_destinations, next_token, limit
+ )
+
+ response = ListApiDestinationsResponse(
+ ApiDestinations=list(
+ self._api_destination_dict_to_api_destination_response_list(limited_rules)
+ )
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("UpdateApiDestination")
+ def update_api_destination(
+ self,
+ context: RequestContext,
+ name: ApiDestinationName,
+ description: ApiDestinationDescription = None,
+ connection_arn: ConnectionArn = None,
+ invocation_endpoint: HttpsEndpoint = None,
+ http_method: ApiDestinationHttpMethod = None,
+ invocation_rate_limit_per_second: ApiDestinationInvocationRateLimitPerSecond = None,
+ **kwargs,
+ ) -> UpdateApiDestinationResponse:
+ store = self.get_store(context.region, context.account_id)
+ api_destination = self.get_api_destination(name, store)
+ api_destination_service = self._api_destination_service_store[api_destination.arn]
+ if connection_arn:
+ connection_name = extract_connection_name(connection_arn)
+ connection = self.get_connection(connection_name, store)
+ else:
+ connection = api_destination_service.connection
+ api_destination_service.update(
+ connection,
+ invocation_endpoint,
+ http_method,
+ invocation_rate_limit_per_second,
+ description,
+ )
+
+ response = UpdateApiDestinationResponse(
+ ApiDestinationArn=api_destination_service.arn,
+ ApiDestinationState=api_destination_service.state,
+ CreationTime=api_destination_service.creation_time,
+ LastModifiedTime=api_destination_service.last_modified_time,
+ )
+ return response
+
+ #############
+ # Connections
+ #############
+ @handler("CreateConnection")
+ def create_connection(
+ self,
+ context: RequestContext,
+ name: ConnectionName,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters,
+ description: ConnectionDescription = None,
+ invocation_connectivity_parameters: ConnectivityResourceParameters = None,
+ **kwargs,
+ ) -> CreateConnectionResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if name in store.connections:
+ raise ResourceAlreadyExistsException(f"Connection {name} already exists.")
+ connection_service = self.create_connection_service(
+ name,
+ region,
+ account_id,
+ authorization_type,
+ auth_parameters,
+ description,
+ invocation_connectivity_parameters,
+ )
+ store.connections[connection_service.connection.name] = connection_service.connection
+
+ response = CreateConnectionResponse(
+ ConnectionArn=connection_service.arn,
+ ConnectionState=connection_service.state,
+ CreationTime=connection_service.creation_time,
+ LastModifiedTime=connection_service.last_modified_time,
+ )
+ return response
+
+ @handler("DescribeConnection")
+ def describe_connection(
+ self, context: RequestContext, name: ConnectionName, **kwargs
+ ) -> DescribeConnectionResponse:
+ store = self.get_store(context.region, context.account_id)
+ connection = self.get_connection(name, store)
+
+ response = self._connection_to_api_type_connection(connection)
+ return response
+
+ @handler("DeleteConnection")
+ def delete_connection(
+ self, context: RequestContext, name: ConnectionName, **kwargs
+ ) -> DeleteConnectionResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if connection := self.get_connection(name, store):
+ connection_service = self._connection_service_store.pop(connection.arn)
+ connection_service.delete()
+ del store.connections[name]
+ del store.TAGS[connection.arn]
+
+ response = DeleteConnectionResponse(
+ ConnectionArn=connection.arn,
+ ConnectionState=connection.state,
+ CreationTime=connection.creation_time,
+ LastModifiedTime=connection.last_modified_time,
+ LastAuthorizedTime=connection.last_authorized_time,
+ )
+ return response
+
+ @handler("ListConnections")
+ def list_connections(
+ self,
+ context: RequestContext,
+ name_prefix: ConnectionName = None,
+ connection_state: ConnectionState = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListConnectionsResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ connections = (
+ get_filtered_dict(name_prefix, store.connections) if name_prefix else store.connections
+ )
+ limited_rules, next_token = self._get_limited_dict_and_next_token(
+ connections, next_token, limit
+ )
+
+ response = ListConnectionsResponse(
+ Connections=list(self._connection_dict_to_connection_response_list(limited_rules))
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("UpdateConnection")
+ def update_connection(
+ self,
+ context: RequestContext,
+ name: ConnectionName,
+ description: ConnectionDescription = None,
+ authorization_type: ConnectionAuthorizationType = None,
+ auth_parameters: UpdateConnectionAuthRequestParameters = None,
+ invocation_connectivity_parameters: ConnectivityResourceParameters = None,
+ **kwargs,
+ ) -> UpdateConnectionResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ connection = self.get_connection(name, store)
+ connection_service = self._connection_service_store[connection.arn]
+ connection_service.update(
+ description, authorization_type, auth_parameters, invocation_connectivity_parameters
+ )
+
+ response = UpdateConnectionResponse(
+ ConnectionArn=connection_service.arn,
+ ConnectionState=connection_service.state,
+ CreationTime=connection_service.creation_time,
+ LastModifiedTime=connection_service.last_modified_time,
+ LastAuthorizedTime=connection_service.last_authorized_time,
+ )
+ return response
+
+ ##########
+ # EventBus
+ ##########
+
+ @handler("CreateEventBus")
+ def create_event_bus(
+ self,
+ context: RequestContext,
+ name: EventBusName,
+ event_source_name: EventSourceName = None,
+ description: EventBusDescription = None,
+ kms_key_identifier: KmsKeyIdentifier = None,
+ dead_letter_config: DeadLetterConfig = None,
+ tags: TagList = None,
+ **kwargs,
+ ) -> CreateEventBusResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if name in store.event_buses:
+ raise ResourceAlreadyExistsException(f"Event bus {name} already exists.")
+ event_bus_service = self.create_event_bus_service(
+ name, region, account_id, event_source_name, description, tags
+ )
+ store.event_buses[event_bus_service.event_bus.name] = event_bus_service.event_bus
+
+ if tags:
+ self.tag_resource(context, event_bus_service.arn, tags)
+
+ response = CreateEventBusResponse(
+ EventBusArn=event_bus_service.arn,
+ )
+ if description := getattr(event_bus_service.event_bus, "description", None):
+ response["Description"] = description
+ return response
+
+ @handler("DeleteEventBus")
+ def delete_event_bus(self, context: RequestContext, name: EventBusName, **kwargs) -> None:
+ if name == "default":
+ raise ValidationException("Cannot delete event bus default.")
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ try:
+ if event_bus := self.get_event_bus(name, store):
+ del self._event_bus_services_store[event_bus.arn]
+ if rules := event_bus.rules:
+ self._delete_rule_services(rules)
+ del store.event_buses[name]
+ del store.TAGS[event_bus.arn]
+ except ResourceNotFoundException as error:
+ return error
+
+ @handler("DescribeEventBus")
+ def describe_event_bus(
+ self, context: RequestContext, name: EventBusNameOrArn = None, **kwargs
+ ) -> DescribeEventBusResponse:
+ name = extract_event_bus_name(name)
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus = self.get_event_bus(name, store)
+
+ response = self._event_bus_to_api_type_event_bus(event_bus)
+ return response
+
+ @handler("ListEventBuses")
+ def list_event_buses(
+ self,
+ context: RequestContext,
+ name_prefix: EventBusName = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListEventBusesResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_buses = (
+ get_filtered_dict(name_prefix, store.event_buses) if name_prefix else store.event_buses
+ )
+ limited_event_buses, next_token = self._get_limited_dict_and_next_token(
+ event_buses, next_token, limit
+ )
+
+ response = ListEventBusesResponse(
+ EventBuses=self._event_bus_dict_to_event_bus_response_list(limited_event_buses)
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("PutPermission")
+ def put_permission(
+ self,
+ context: RequestContext,
+ event_bus_name: NonPartnerEventBusName = None,
+ action: Action = None,
+ principal: Principal = None,
+ statement_id: StatementId = None,
+ condition: Condition = None,
+ policy: String = None,
+ **kwargs,
+ ) -> None:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ event_bus_service = self._event_bus_services_store[event_bus.arn]
+ event_bus_service.put_permission(action, principal, statement_id, condition, policy)
+
+ @handler("RemovePermission")
+ def remove_permission(
+ self,
+ context: RequestContext,
+ statement_id: StatementId = None,
+ remove_all_permissions: Boolean = None,
+ event_bus_name: NonPartnerEventBusName = None,
+ **kwargs,
+ ) -> None:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ event_bus_service = self._event_bus_services_store[event_bus.arn]
+ if remove_all_permissions:
+ event_bus_service.event_bus.policy = None
+ return
+ if not statement_id:
+ raise ValidationException("Parameter StatementId is required.")
+ event_bus_service.revoke_put_events_permission(statement_id)
+
+ #######
+ # Rules
+ #######
+ @handler("EnableRule")
+ def enable_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> None:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule = self.get_rule(name, event_bus)
+ rule.state = RuleState.ENABLED
+
+ @handler("DeleteRule")
+ def delete_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ force: Boolean = None,
+ **kwargs,
+ ) -> None:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ try:
+ rule = self.get_rule(name, event_bus)
+ if rule.targets and not force:
+ raise ValidationException("Rule can't be deleted since it has targets.")
+ self._delete_rule_services(rule)
+ del event_bus.rules[name]
+ del store.TAGS[rule.arn]
+ except ResourceNotFoundException as error:
+ return error
+
+ @handler("DescribeRule")
+ def describe_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> DescribeRuleResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule = self.get_rule(name, event_bus)
+
+ response = self._rule_to_api_type_rule(rule)
+ return response
+
+ @handler("DisableRule")
+ def disable_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> None:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule = self.get_rule(name, event_bus)
+ rule.state = RuleState.DISABLED
+
+ @handler("ListRules")
+ def list_rules(
+ self,
+ context: RequestContext,
+ name_prefix: RuleName = None,
+ event_bus_name: EventBusNameOrArn = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListRulesResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rules = get_filtered_dict(name_prefix, event_bus.rules) if name_prefix else event_bus.rules
+ limited_rules, next_token = self._get_limited_dict_and_next_token(rules, next_token, limit)
+
+ response = ListRulesResponse(
+ Rules=list(self._rule_dict_to_rule_response_list(limited_rules))
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("ListRuleNamesByTarget")
+ def list_rule_names_by_target(
+ self,
+ context: RequestContext,
+ target_arn: TargetArn,
+ event_bus_name: EventBusNameOrArn = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListRuleNamesByTargetResponse:
+ raise NotImplementedError
+
+ @handler("PutRule")
+ def put_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ schedule_expression: ScheduleExpression = None,
+ event_pattern: EventPattern = None,
+ state: RuleState = None,
+ description: RuleDescription = None,
+ role_arn: RoleArn = None,
+ tags: TagList = None,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> PutRuleResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ existing_rule = event_bus.rules.get(name)
+ targets = existing_rule.targets if existing_rule else None
+ rule_service = self.create_rule_service(
+ name,
+ region,
+ account_id,
+ schedule_expression,
+ event_pattern,
+ state,
+ description,
+ role_arn,
+ tags,
+ event_bus_name,
+ targets,
+ )
+ event_bus.rules[name] = rule_service.rule
+
+ if tags:
+ self.tag_resource(context, rule_service.arn, tags)
+
+ response = PutRuleResponse(RuleArn=rule_service.arn)
+ return response
+
+ @handler("TestEventPattern")
+ def test_event_pattern(
+ self, context: RequestContext, event_pattern: EventPattern, event: str, **kwargs
+ ) -> TestEventPatternResponse:
+ """Test event pattern uses EventBridge event pattern matching:
+ https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns.html
+ """
+ try:
+ json_event = json.loads(event)
+ except json.JSONDecodeError:
+ raise ValidationException("Parameter Event is not valid.")
+
+ mandatory_fields = {
+ "id",
+ "account",
+ "source",
+ "time",
+ "region",
+ "detail-type",
+ }
+ # https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_TestEventPattern.html
+ # the documentation lists `resources` as mandatory, but in practice it is not required
+
+ if not isinstance(json_event, dict) or not mandatory_fields.issubset(json_event):
+ raise ValidationException("Parameter Event is not valid.")
+
+ result = matches_event(event_pattern, event)
+ return TestEventPatternResponse(Result=result)
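+
+ # Illustrative minimal event accepted by the mandatory-field check above (placeholder values):
+ # {"id": "1", "account": "111122223333", "source": "my.app", "time": "2024-01-01T00:00:00Z",
+ # "region": "us-east-1", "detail-type": "myDetailType"} -- "resources" may be omitted.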
+
+ #########
+ # Targets
+ #########
+
+ @handler("ListTargetsByRule")
+ def list_targets_by_rule(
+ self,
+ context: RequestContext,
+ rule: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListTargetsByRuleResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule = self.get_rule(rule, event_bus)
+ targets = rule.targets
+ limited_targets, next_token = self._get_limited_dict_and_next_token(
+ targets, next_token, limit
+ )
+
+ response = ListTargetsByRuleResponse(Targets=list(limited_targets.values()))
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("PutTargets")
+ def put_targets(
+ self,
+ context: RequestContext,
+ rule: RuleName,
+ targets: TargetList,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> PutTargetsResponse:
+ region = context.region
+ account_id = context.account_id
+ rule_service = self.get_rule_service(region, account_id, rule, event_bus_name)
+ failed_entries = rule_service.add_targets(targets)
+ rule_arn = rule_service.arn
+ rule_name = rule_service.rule.name
+ for index, target in enumerate(targets): # TODO only add successful targets
+ target_id = target["Id"]
+ if len(target_id) > 64:
+ raise ValidationException(
+ rf"1 validation error detected: Value '{target_id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must have length less than or equal to 64"
+ )
+ if not bool(TARGET_ID_PATTERN.match(target_id)):
+ raise ValidationException(
+ rf"1 validation error detected: Value '{target_id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must satisfy regular expression pattern: [\.\-_A-Za-z0-9]+"
+ )
+ self.create_target_sender(target, rule_arn, rule_name, region, account_id)
+
+ if rule_service.schedule_cron:
+ schedule_job_function = self._get_scheduled_rule_job_function(
+ account_id, region, rule_service.rule
+ )
+ rule_service.create_schedule_job(schedule_job_function)
+ response = PutTargetsResponse(
+ FailedEntryCount=len(failed_entries), FailedEntries=failed_entries
+ )
+ return response
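+
+ # Illustrative target Id validation (assumed values): an Id such as "my-target_1" satisfies the
+ # [\.\-_A-Za-z0-9]+ pattern, Ids longer than 64 characters raise the length ValidationException
+ # above, and Ids containing other characters (e.g. spaces) are meant to fail the pattern check,
+ # assuming TARGET_ID_PATTERN is anchored to the full string.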
+
+ @handler("RemoveTargets")
+ def remove_targets(
+ self,
+ context: RequestContext,
+ rule: RuleName,
+ ids: TargetIdList,
+ event_bus_name: EventBusNameOrArn = None,
+ force: Boolean = None,
+ **kwargs,
+ ) -> RemoveTargetsResponse:
+ region = context.region
+ account_id = context.account_id
+ rule_service = self.get_rule_service(region, account_id, rule, event_bus_name)
+ failed_entries = rule_service.remove_targets(ids)
+ self._delete_target_sender(ids, rule_service.rule)
+
+ response = RemoveTargetsResponse(
+ FailedEntryCount=len(failed_entries), FailedEntries=failed_entries
+ )
+ return response
+
+ #########
+ # Archive
+ #########
+ @handler("CreateArchive")
+ def create_archive(
+ self,
+ context: RequestContext,
+ archive_name: ArchiveName,
+ event_source_arn: Arn,
+ description: ArchiveDescription = None,
+ event_pattern: EventPattern = None,
+ retention_days: RetentionDays = None,
+ **kwargs,
+ ) -> CreateArchiveResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if archive_name in store.archives:
+ raise ResourceAlreadyExistsException(f"Archive {archive_name} already exists.")
+ self._check_event_bus_exists(event_source_arn, store)
+ archive_service = self.create_archive_service(
+ archive_name,
+ region,
+ account_id,
+ event_source_arn,
+ description,
+ event_pattern,
+ retention_days,
+ )
+ store.archives[archive_service.archive.name] = archive_service.archive
+
+ response = CreateArchiveResponse(
+ ArchiveArn=archive_service.arn,
+ State=archive_service.state,
+ CreationTime=archive_service.creation_time,
+ )
+ return response
+
+ @handler("DeleteArchive")
+ def delete_archive(
+ self, context: RequestContext, archive_name: ArchiveName, **kwargs
+ ) -> DeleteArchiveResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if archive := self.get_archive(archive_name, store):
+ try:
+ archive_service = self._archive_service_store.pop(archive.arn)
+ archive_service.delete()
+ del store.archives[archive_name]
+ except ResourceNotFoundException as error:
+ return error
+
+ @handler("DescribeArchive")
+ def describe_archive(
+ self, context: RequestContext, archive_name: ArchiveName, **kwargs
+ ) -> DescribeArchiveResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ archive = self.get_archive(archive_name, store)
+
+ response = self._archive_to_describe_archive_response(archive)
+ return response
+
+ @handler("ListArchives")
+ def list_archives(
+ self,
+ context: RequestContext,
+ name_prefix: ArchiveName = None,
+ event_source_arn: Arn = None,
+ state: ArchiveState = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListArchivesResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if event_source_arn:
+ self._check_event_bus_exists(event_source_arn, store)
+ archives = {
+ key: archive
+ for key, archive in store.archives.items()
+ if archive.event_source_arn == event_source_arn
+ }
+ elif name_prefix:
+ archives = get_filtered_dict(name_prefix, store.archives)
+ else:
+ archives = store.archives
+ limited_archives, next_token = self._get_limited_dict_and_next_token(
+ archives, next_token, limit
+ )
+
+ response = ListArchivesResponse(
+ Archives=list(self._archive_dict_to_archive_response_list(limited_archives))
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("UpdateArchive")
+ def update_archive(
+ self,
+ context: RequestContext,
+ archive_name: ArchiveName,
+ description: ArchiveDescription = None,
+ event_pattern: EventPattern = None,
+ retention_days: RetentionDays = None,
+ **kwargs,
+ ) -> UpdateArchiveResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ try:
+ archive = self.get_archive(archive_name, store)
+ except ResourceNotFoundException:
+ raise InternalException("Service encountered unexpected problem. Please try again.")
+ archive_service = self._archive_service_store[archive.arn]
+ archive_service.update(description, event_pattern, retention_days)
+
+ response = UpdateArchiveResponse(
+ ArchiveArn=archive_service.arn,
+ State=archive.state,
+ # StateReason=archive.state_reason,
+ CreationTime=archive.creation_time,
+ )
+ return response
+
+ ########
+ # Events
+ ########
+
+ @handler("PutEvents")
+ def put_events(
+ self,
+ context: RequestContext,
+ entries: PutEventsRequestEntryList,
+ endpoint_id: EndpointId = None,
+ **kwargs,
+ ) -> PutEventsResponse:
+ if len(entries) > 10:
+ formatted_entries = [self._event_to_error_type_event(entry) for entry in entries]
+ formatted_entries = f"[{', '.join(formatted_entries)}]"
+ raise ValidationException(
+ f"1 validation error detected: Value '{formatted_entries}' at 'entries' failed to satisfy constraint: Member must have length less than or equal to 10"
+ )
+ entries, failed_entry_count = self._process_entries(context, entries)
+
+ response = PutEventsResponse(
+ Entries=entries,
+ FailedEntryCount=failed_entry_count,
+ )
+ return response
+
+ @handler("PutPartnerEvents")
+ def put_partner_events(
+ self,
+ context: RequestContext,
+ entries: PutPartnerEventsRequestEntryList,
+ **kwargs,
+ ) -> PutPartnerEventsResponse:
+ raise NotImplementedError
+
+ ########
+ # Replay
+ ########
+
+ @handler("CancelReplay")
+ def cancel_replay(
+ self, context: RequestContext, replay_name: ReplayName, **kwargs
+ ) -> CancelReplayResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ replay = self.get_replay(replay_name, store)
+ replay_service = self._replay_service_store[replay.arn]
+ replay_service.stop()
+ response = CancelReplayResponse(
+ ReplayArn=replay_service.arn,
+ State=replay_service.state,
+ # StateReason=replay_service.state_reason,
+ )
+ return response
+
+ @handler("DescribeReplay")
+ def describe_replay(
+ self, context: RequestContext, replay_name: ReplayName, **kwargs
+ ) -> DescribeReplayResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ replay = self.get_replay(replay_name, store)
+
+ response = self._replay_to_describe_replay_response(replay)
+ return response
+
+ @handler("ListReplays")
+ def list_replays(
+ self,
+ context: RequestContext,
+ name_prefix: ReplayName = None,
+ state: ReplayState = None,
+ event_source_arn: Arn = None,
+ next_token: NextToken = None,
+ limit: LimitMax100 = None,
+ **kwargs,
+ ) -> ListReplaysResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if event_source_arn:
+ replays = {
+ key: replay
+ for key, replay in store.replays.items()
+ if replay.event_source_arn == event_source_arn
+ }
+ elif name_prefix:
+ replays = get_filtered_dict(name_prefix, store.replays)
+ else:
+ replays = store.replays
+ limited_replays, next_token = self._get_limited_dict_and_next_token(
+ replays, next_token, limit
+ )
+
+ response = ListReplaysResponse(
+ Replays=list(self._replay_dict_to_replay_response_list(limited_replays))
+ )
+ if next_token is not None:
+ response["NextToken"] = next_token
+ return response
+
+ @handler("StartReplay")
+ def start_replay(
+ self,
+ context: RequestContext,
+ replay_name: ReplayName,
+ event_source_arn: Arn, # Archive Arn
+ event_start_time: Timestamp,
+ event_end_time: Timestamp,
+ destination: ReplayDestination,
+ description: ReplayDescription = None,
+ **kwargs,
+ ) -> StartReplayResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ if replay_name in store.replays:
+ raise ResourceAlreadyExistsException(f"Replay {replay_name} already exists.")
+ self._validate_replay_time(event_start_time, event_end_time)
+ if event_source_arn not in self._archive_service_store:
+ archive_name = event_source_arn.split("/")[-1]
+ raise ValidationException(
+ f"Parameter EventSourceArn is not valid. Reason: Archive {archive_name} does not exist."
+ )
+ self._validate_replay_destination(destination, event_source_arn)
+ replay_service = self.create_replay_service(
+ replay_name,
+ region,
+ account_id,
+ event_source_arn,
+ destination,
+ event_start_time,
+ event_end_time,
+ description,
+ )
+ store.replays[replay_service.replay.name] = replay_service.replay
+ archive_service = self._archive_service_store[event_source_arn]
+ events_to_replay = archive_service.get_events(
+ replay_service.event_start_time, replay_service.event_end_time
+ )
+ replay_service.start(events_to_replay)
+ if events_to_replay:
+ re_formatted_event_to_replay = replay_service.re_format_events_from_archive(
+ events_to_replay, replay_name
+ )
+ self._process_entries(context, re_formatted_event_to_replay)
+ replay_service.finish()
+
+ response = StartReplayResponse(
+ ReplayArn=replay_service.arn,
+ State=replay_service.state,
+ StateReason=replay_service.state_reason,
+ ReplayStartTime=replay_service.replay_start_time,
+ )
+ return response
+
+ ######
+ # Tags
+ ######
+
+ @handler("ListTagsForResource")
+ def list_tags_for_resource(
+ self, context: RequestContext, resource_arn: Arn, **kwargs
+ ) -> ListTagsForResourceResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ resource_type = get_resource_type(resource_arn)
+ self._check_resource_exists(resource_arn, resource_type, store)
+ tags = store.TAGS.list_tags_for_resource(resource_arn)
+ return ListTagsForResourceResponse(tags)
+
+ @handler("TagResource")
+ def tag_resource(
+ self, context: RequestContext, resource_arn: Arn, tags: TagList, **kwargs
+ ) -> TagResourceResponse:
+ # each tag key must be unique
+ # https://docs.aws.amazon.com/general/latest/gr/aws_tagging.html#tag-best-practices
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ resource_type = get_resource_type(resource_arn)
+ self._check_resource_exists(resource_arn, resource_type, store)
+ check_unique_tags(tags)
+ store.TAGS.tag_resource(resource_arn, tags)
+
+ @handler("UntagResource")
+ def untag_resource(
+ self, context: RequestContext, resource_arn: Arn, tag_keys: TagKeyList, **kwargs
+ ) -> UntagResourceResponse:
+ region = context.region
+ account_id = context.account_id
+ store = self.get_store(region, account_id)
+ resource_type = get_resource_type(resource_arn)
+ self._check_resource_exists(resource_arn, resource_type, store)
+ store.TAGS.untag_resource(resource_arn, tag_keys)
+
+ #########
+ # Methods
+ #########
+
+ def get_store(self, region: str, account_id: str) -> EventsStore:
+ """Returns the events store for the account and region.
+ On first call, creates the default event bus for the account region."""
+ store = events_stores[account_id][region]
+ # create default event bus for account region on first call
+ default_event_bus_name = "default"
+ if default_event_bus_name not in store.event_buses:
+ event_bus_service = self.create_event_bus_service(
+ default_event_bus_name, region, account_id, None, None, None
+ )
+ store.event_buses[event_bus_service.event_bus.name] = event_bus_service.event_bus
+ return store
+
+ def get_event_bus(self, name: EventBusName, store: EventsStore) -> EventBus:
+ if event_bus := store.event_buses.get(name):
+ return event_bus
+ raise ResourceNotFoundException(f"Event bus {name} does not exist.")
+
+ def get_rule(self, name: RuleName, event_bus: EventBus) -> Rule:
+ if rule := event_bus.rules.get(name):
+ return rule
+ raise ResourceNotFoundException(f"Rule {name} does not exist on EventBus {event_bus.name}.")
+
+ def get_target(self, target_id: TargetId, rule: Rule) -> Target:
+ if target := rule.targets.get(target_id):
+ return target
+ raise ResourceNotFoundException(f"Target {target_id} does not exist on Rule {rule.name}.")
+
+ def get_archive(self, name: ArchiveName, store: EventsStore) -> Archive:
+ if archive := store.archives.get(name):
+ return archive
+ raise ResourceNotFoundException(f"Archive {name} does not exist.")
+
+ def get_replay(self, name: ReplayName, store: EventsStore) -> Replay:
+ if replay := store.replays.get(name):
+ return replay
+ raise ResourceNotFoundException(f"Replay {name} does not exist.")
+
+ def get_connection(self, name: ConnectionName, store: EventsStore) -> Connection:
+ if connection := store.connections.get(name):
+ return connection
+ raise ResourceNotFoundException(
+ f"Failed to describe the connection(s). Connection '{name}' does not exist."
+ )
+
+ def get_api_destination(self, name: ApiDestinationName, store: EventsStore) -> ApiDestination:
+ if api_destination := store.api_destinations.get(name):
+ return api_destination
+ raise ResourceNotFoundException(
+ f"Failed to describe the api-destination(s). An api-destination '{name}' does not exist."
+ )
+
+ def get_rule_service(
+ self,
+ region: str,
+ account_id: str,
+ rule_name: RuleName,
+ event_bus_name: EventBusName,
+ ) -> RuleService:
+ store = self.get_store(region, account_id)
+ event_bus_name = extract_event_bus_name(event_bus_name)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule = self.get_rule(rule_name, event_bus)
+ return self._rule_services_store[rule.arn]
+
+ def create_event_bus_service(
+ self,
+ name: EventBusName,
+ region: str,
+ account_id: str,
+ event_source_name: Optional[EventSourceName],
+ description: Optional[EventBusDescription],
+ tags: Optional[TagList],
+ ) -> EventBusService:
+ event_bus_service = EventBusService.create_event_bus_service(
+ name,
+ region,
+ account_id,
+ event_source_name,
+ description,
+ tags,
+ )
+ self._event_bus_services_store[event_bus_service.arn] = event_bus_service
+ return event_bus_service
+
+ def create_rule_service(
+ self,
+ name: RuleName,
+ region: str,
+ account_id: str,
+ schedule_expression: Optional[ScheduleExpression],
+ event_pattern: Optional[EventPattern],
+ state: Optional[RuleState],
+ description: Optional[RuleDescription],
+ role_arn: Optional[RoleArn],
+ tags: Optional[TagList],
+ event_bus_name: Optional[EventBusName],
+ targets: Optional[TargetDict],
+ ) -> RuleService:
+ rule_service = RuleService.create_rule_service(
+ name,
+ region,
+ account_id,
+ schedule_expression,
+ event_pattern,
+ state,
+ description,
+ role_arn,
+ tags,
+ event_bus_name,
+ targets,
+ )
+ self._rule_services_store[rule_service.arn] = rule_service
+ return rule_service
+
+ def create_target_sender(
+ self, target: Target, rule_arn: Arn, rule_name: RuleName, region: str, account_id: str
+ ) -> TargetSender:
+ target_sender = TargetSenderFactory(
+ target, rule_arn, rule_name, region, account_id
+ ).get_target_sender()
+ self._target_sender_store[target_sender.unique_id] = target_sender
+ return target_sender
+
+ def create_archive_service(
+ self,
+ archive_name: ArchiveName,
+ region: str,
+ account_id: str,
+ event_source_arn: Arn,
+ description: ArchiveDescription,
+ event_pattern: EventPattern,
+ retention_days: RetentionDays,
+ ) -> ArchiveService:
+ archive_service = ArchiveService.create_archive_service(
+ archive_name,
+ region,
+ account_id,
+ event_source_arn,
+ description,
+ event_pattern,
+ retention_days,
+ )
+ archive_service.register_archive_rule_and_targets()
+ self._archive_service_store[archive_service.arn] = archive_service
+ return archive_service
+
+ def create_replay_service(
+ self,
+ name: ReplayName,
+ region: str,
+ account_id: str,
+ event_source_arn: Arn,
+ destination: ReplayDestination,
+ event_start_time: Timestamp,
+ event_end_time: Timestamp,
+ description: ReplayDescription,
+ ) -> ReplayService:
+ replay_service = ReplayService(
+ name,
+ region,
+ account_id,
+ event_source_arn,
+ destination,
+ event_start_time,
+ event_end_time,
+ description,
+ )
+ self._replay_service_store[replay_service.arn] = replay_service
+ return replay_service
+
+ def create_connection_service(
+ self,
+ name: ConnectionName,
+ region: str,
+ account_id: str,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters,
+ description: ConnectionDescription,
+ invocation_connectivity_parameters: ConnectivityResourceParameters,
+ ) -> ConnectionService:
+ connection_service = ConnectionService(
+ name,
+ region,
+ account_id,
+ authorization_type,
+ auth_parameters,
+ description,
+ invocation_connectivity_parameters,
+ )
+ self._connection_service_store[connection_service.arn] = connection_service
+ return connection_service
+
+ def create_api_destinations_service(
+ self,
+ name: ConnectionName,
+ region: str,
+ account_id: str,
+ connection_arn: ConnectionArn,
+ connection: Connection,
+ invocation_endpoint: HttpsEndpoint,
+ http_method: ApiDestinationHttpMethod,
+ invocation_rate_limit_per_second: ApiDestinationInvocationRateLimitPerSecond,
+ description: ApiDestinationDescription,
+ ) -> APIDestinationService:
+ api_destination_service = APIDestinationService(
+ name,
+ region,
+ account_id,
+ connection_arn,
+ connection,
+ invocation_endpoint,
+ http_method,
+ invocation_rate_limit_per_second,
+ description,
+ )
+ self._api_destination_service_store[api_destination_service.arn] = api_destination_service
+ return api_destination_service
+
+ def _delete_connection(self, connection_arn: Arn) -> None:
+ del self._connection_service_store[connection_arn]
+
+ def _delete_rule_services(self, rules: RuleDict | Rule) -> None:
+ """
+ Delete all rule services associated with the given rule(s) from the store.
+ Accepts a single Rule object or a dict of Rule objects as input.
+ """
+ if isinstance(rules, Rule):
+ rules = {rules.name: rules}
+ for rule in rules.values():
+ del self._rule_services_store[rule.arn]
+
+ def _delete_target_sender(self, ids: TargetIdList, rule) -> None:
+ for target_id in ids:
+ if target := rule.targets.get(target_id):
+ target_unique_id = f"{rule.arn}-{target_id}"
+ try:
+ del self._target_sender_store[target_unique_id]
+ except KeyError:
+ LOG.error("Error deleting target service %s.", target["Arn"])
+
+ def _get_limited_dict_and_next_token(
+ self, input_dict: dict, next_token: NextToken | None, limit: LimitMax100 | None
+ ) -> tuple[dict, NextToken]:
+ """Return a slice of the given dictionary, starting at the index encoded in next_token and
+ containing at most limit items, together with the next start index encoded as a new next_token."""
+ input_dict_len = len(input_dict)
+ start_index = decode_next_token(next_token) if next_token is not None else 0
+ end_index = start_index + limit if limit is not None else input_dict_len
+ limited_dict = dict(list(input_dict.items())[start_index:end_index])
+
+ next_token = (
+ encode_next_token(end_index)
+ # return a next_token (encoded integer of next starting index) if not all items are returned
+ if end_index < input_dict_len
+ else None
+ )
+ return limited_dict, next_token
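+
+ # Illustrative pagination walk-through (assumed values): with 5 items and limit=2, the first call
+ # returns items 0-1 plus a token encoding index 2; passing that token back returns items 2-3 and a
+ # token encoding index 4; the final call returns the last item and next_token=None.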
+
+ def _check_resource_exists(
+ self, resource_arn: Arn, resource_type: ResourceType, store: EventsStore
+ ) -> None:
+ if resource_type == ResourceType.EVENT_BUS:
+ event_bus_name = extract_event_bus_name(resource_arn)
+ self.get_event_bus(event_bus_name, store)
+ if resource_type == ResourceType.RULE:
+ event_bus_name = extract_event_bus_name(resource_arn)
+ event_bus = self.get_event_bus(event_bus_name, store)
+ rule_name = resource_arn.split("/")[-1]
+ self.get_rule(rule_name, event_bus)
+
+ def _get_scheduled_rule_job_function(self, account_id, region, rule: Rule) -> Callable:
+ def func(*args, **kwargs):
+ """Create a custom scheduled event and send it to all targets of the associated rule via their respective TargetSenders"""
+ for target in rule.targets.values():
+ if custom_input := target.get("Input"):
+ event = json.loads(custom_input)
+ else:
+ event = {
+ "version": "0",
+ "id": long_uid(),
+ "detail-type": "Scheduled Event",
+ "source": "aws.events",
+ "account": account_id,
+ "time": timestamp(format=TIMESTAMP_FORMAT_TZ),
+ "region": region,
+ "resources": [rule.arn],
+ "detail": {},
+ }
+ target_unique_id = f"{rule.arn}-{target['Id']}"
+ target_sender = self._target_sender_store[target_unique_id]
+ try:
+ target_sender.process_event(event.copy())
+ except Exception as e:
+ LOG.info(
+ "Unable to send event notification %s to target %s: %s",
+ truncate(event),
+ target,
+ e,
+ )
+
+ return func
+
+ def _check_event_bus_exists(
+ self, event_bus_name_or_arn: EventBusNameOrArn, store: EventsStore
+ ) -> None:
+ event_bus_name = extract_event_bus_name(event_bus_name_or_arn)
+ self.get_event_bus(event_bus_name, store)
+
+ def _validate_replay_time(self, event_start_time: Timestamp, event_end_time: Timestamp) -> None:
+ if event_end_time <= event_start_time:
+ raise ValidationException(
+ "Parameter EventEndTime is not valid. Reason: EventStartTime must be before EventEndTime."
+ )
+
+ def _validate_replay_destination(
+ self, destination: ReplayDestination, event_source_arn: Arn
+ ) -> None:
+ archive_service = self._archive_service_store[event_source_arn]
+ if destination_arn := destination.get("Arn"):
+ if destination_arn != archive_service.archive.event_source_arn:
+ if destination_arn in self._event_bus_services_store:
+ raise ValidationException(
+ "Parameter Destination.Arn is not valid. Reason: Cross event bus replay is not permitted."
+ )
+ else:
+ event_bus_name = extract_event_bus_name(destination_arn)
+ raise ResourceNotFoundException(f"Event bus {event_bus_name} does not exist.")
+
+ # Internal type to API type remappings
+
+ def _event_bus_dict_to_event_bus_response_list(
+ self, event_buses: EventBusDict
+ ) -> EventBusList:
+ """Return a converted dict of EventBus model objects as a list of event buses in API type EventBus format."""
+ event_bus_list = [
+ self._event_bus_to_api_type_event_bus(event_bus) for event_bus in event_buses.values()
+ ]
+ return event_bus_list
+
+ def _event_bus_to_api_type_event_bus(self, event_bus: EventBus) -> ApiTypeEventBus:
+ event_bus_api_type = {
+ "Name": event_bus.name,
+ "Arn": event_bus.arn,
+ }
+ if getattr(event_bus, "description", None):
+ event_bus_api_type["Description"] = event_bus.description
+ if event_bus.creation_time:
+ event_bus_api_type["CreationTime"] = event_bus.creation_time
+ if event_bus.last_modified_time:
+ event_bus_api_type["LastModifiedTime"] = event_bus.last_modified_time
+ if event_bus.policy:
+ event_bus_api_type["Policy"] = json.dumps(
+ recursive_remove_none_values_from_dict(event_bus.policy)
+ )
+
+ return event_bus_api_type
+
+ def _event_to_error_type_event(self, entry: PutEventsRequestEntry) -> str:
+ detail = (
+ json.dumps(json.loads(entry["Detail"]), separators=(", ", ": "))
+ if entry.get("Detail")
+ else "null"
+ )
+ return (
+ f"PutEventsRequestEntry("
+ f"time={entry.get('Time', 'null')}, "
+ f"source={entry.get('Source', 'null')}, "
+ f"resources={entry.get('Resources', 'null')}, "
+ f"detailType={entry.get('DetailType', 'null')}, "
+ f"detail={detail}, "
+ f"eventBusName={entry.get('EventBusName', 'null')}, "
+ f"traceHeader={entry.get('TraceHeader', 'null')}, "
+ f"kmsKeyIdentifier={entry.get('kmsKeyIdentifier', 'null')}, "
+ f"internalMetadata={entry.get('internalMetadata', 'null')}"
+ f")"
+ )
+
+ def _rule_dict_to_rule_response_list(self, rules: RuleDict) -> RuleResponseList:
+ """Return a converted dict of Rule model objects as a list of rules in API type Rule format."""
+ rule_list = [self._rule_to_api_type_rule(rule) for rule in rules.values()]
+ return rule_list
+
+ def _rule_to_api_type_rule(self, rule: Rule) -> ApiTypeRule:
+ rule = {
+ "Name": rule.name,
+ "Arn": rule.arn,
+ "EventPattern": rule.event_pattern,
+ "State": rule.state,
+ "Description": rule.description,
+ "ScheduleExpression": rule.schedule_expression,
+ "RoleArn": rule.role_arn,
+ "ManagedBy": rule.managed_by,
+ "EventBusName": rule.event_bus_name,
+ "CreatedBy": rule.created_by,
+ }
+ return {key: value for key, value in rule.items() if value is not None}
+
+ def _archive_dict_to_archive_response_list(self, archives: ArchiveDict) -> ArchiveResponseList:
+ """Return a converted dict of Archive model objects as a list of archives in API type Archive format."""
+ archive_list = [self._archive_to_api_type_archive(archive) for archive in archives.values()]
+ return archive_list
+
+ def _archive_to_api_type_archive(self, archive: Archive) -> ApiTypeArchive:
+ archive = {
+ "ArchiveName": archive.name,
+ "EventSourceArn": archive.event_source_arn,
+ "State": archive.state,
+ # TODO add "StateReason": archive.state_reason,
+ "RetentionDays": archive.retention_days,
+ "SizeBytes": archive.size_bytes,
+ "EventCount": archive.event_count,
+ "CreationTime": archive.creation_time,
+ }
+ return {key: value for key, value in archive.items() if value is not None}
+
+ def _archive_to_describe_archive_response(self, archive: Archive) -> DescribeArchiveResponse:
+ archive_dict = {
+ "ArchiveArn": archive.arn,
+ "ArchiveName": archive.name,
+ "EventSourceArn": archive.event_source_arn,
+ "State": archive.state,
+ # TODO add "StateReason": archive.state_reason,
+ "RetentionDays": archive.retention_days,
+ "SizeBytes": archive.size_bytes,
+ "EventCount": archive.event_count,
+ "CreationTime": archive.creation_time,
+ "EventPattern": archive.event_pattern,
+ "Description": archive.description,
+ }
+ return {key: value for key, value in archive_dict.items() if value is not None}
+
+ def _replay_dict_to_replay_response_list(self, replays: ReplayDict) -> ReplayList:
+ """Return a converted dict of Replay model objects as a list of replays in API type Replay format."""
+ replay_list = [self._replay_to_api_type_replay(replay) for replay in replays.values()]
+ return replay_list
+
+ def _replay_to_api_type_replay(self, replay: Replay) -> ApiTypeReplay:
+ replay = {
+ "ReplayName": replay.name,
+ "EventSourceArn": replay.event_source_arn,
+ "State": replay.state,
+ # TODO add "StateReason": replay.state_reason,
+ "EventStartTime": replay.event_start_time,
+ "EventEndTime": replay.event_end_time,
+ "EventLastReplayedTime": replay.event_last_replayed_time,
+ "ReplayStartTime": replay.replay_start_time,
+ "ReplayEndTime": replay.replay_end_time,
+ }
+ return {key: value for key, value in replay.items() if value is not None}
+
+ def _replay_to_describe_replay_response(self, replay: Replay) -> DescribeReplayResponse:
+ replay_dict = {
+ "ReplayName": replay.name,
+ "ReplayArn": replay.arn,
+ "Description": replay.description,
+ "State": replay.state,
+ # TODO add "StateReason": replay.state_reason,
+ "EventSourceArn": replay.event_source_arn,
+ "Destination": replay.destination,
+ "EventStartTime": replay.event_start_time,
+ "EventEndTime": replay.event_end_time,
+ "EventLastReplayedTime": replay.event_last_replayed_time,
+ "ReplayStartTime": replay.replay_start_time,
+ "ReplayEndTime": replay.replay_end_time,
+ }
+ return {key: value for key, value in replay_dict.items() if value is not None}
+
+ def _connection_to_api_type_connection(self, connection: Connection) -> ApiTypeConnection:
+ connection = {
+ "ConnectionArn": connection.arn,
+ "Name": connection.name,
+ "ConnectionState": connection.state,
+ # "StateReason": connection.state_reason, # TODO implement state reason
+ "AuthorizationType": connection.authorization_type,
+ "AuthParameters": connection.auth_parameters,
+ "SecretArn": connection.secret_arn,
+ "CreationTime": connection.creation_time,
+ "LastModifiedTime": connection.last_modified_time,
+ "LastAuthorizedTime": connection.last_authorized_time,
+ }
+ return {key: value for key, value in connection.items() if value is not None}
+
+ def _connection_dict_to_connection_response_list(
+ self, connections: ConnectionDict
+ ) -> ConnectionResponseList:
+ """Return a converted dict of Connection model objects as a list of connections in API type Connection format."""
+ connection_list = [
+ self._connection_to_api_type_connection(connection)
+ for connection in connections.values()
+ ]
+ return connection_list
+
+ def _api_destination_to_api_type_api_destination(
+ self, api_destination: ApiDestination
+ ) -> ApiTypeApiDestination:
+ api_destination = {
+ "ApiDestinationArn": api_destination.arn,
+ "Name": api_destination.name,
+ "ConnectionArn": api_destination.connection_arn,
+ "ApiDestinationState": api_destination.state,
+ "InvocationEndpoint": api_destination.invocation_endpoint,
+ "HttpMethod": api_destination.http_method,
+ "InvocationRateLimitPerSecond": api_destination.invocation_rate_limit_per_second,
+ "CreationTime": api_destination.creation_time,
+ "LastModifiedTime": api_destination.last_modified_time,
+ "Description": api_destination.description,
+ }
+ return {key: value for key, value in api_destination.items() if value is not None}
+
+ def _api_destination_dict_to_api_destination_response_list(
+ self, api_destinations: ApiDestinationDict
+ ) -> ApiDestinationResponseList:
+ """Return a converted dict of ApiDestination model objects as a list of api destinations in API type ApiDestination format."""
+ api_destination_list = [
+ self._api_destination_to_api_type_api_destination(api_destination)
+ for api_destination in api_destinations.values()
+ ]
+ return api_destination_list
+
+ def _put_to_archive(
+ self,
+ region: str,
+ account_id: str,
+ archive_target_id: str,
+ event: FormattedEvent,
+ ) -> None:
+ archive_name = ARCHIVE_TARGET_ID_NAME_PATTERN.match(archive_target_id).group("name")
+
+ store = self.get_store(region, account_id)
+ archive = self.get_archive(archive_name, store)
+ archive_service = self._archive_service_store[archive.arn]
+ archive_service.put_events([event])
+
+ def _process_entries(
+ self, context: RequestContext, entries: PutEventsRequestEntryList
+ ) -> tuple[PutEventsResultEntryList, int]:
+ """Main method to process events put to an event bus.
+ Events are validated to contain the required fields and are formatted.
+ Events are matched against all the rules of the respective event bus.
+ For each matching rule the event is either sent to the respective target
+ via the target sender, or put into the configured archive."""
+ processed_entries = []
+ failed_entry_count = {"count": 0}
+ for event in entries:
+ self._process_entry(event, processed_entries, failed_entry_count, context)
+ return processed_entries, failed_entry_count["count"]
+
+ def _process_entry(
+ self,
+ entry: PutEventsRequestEntry,
+ processed_entries: PutEventsResultEntryList,
+ failed_entry_count: dict[str, int],
+ context: RequestContext,
+ ) -> None:
+ event_bus_name_or_arn = entry.get("EventBusName", "default")
+ event_bus_name = extract_event_bus_name(event_bus_name_or_arn)
+ if event_failed_validation := validate_event(entry):
+ processed_entries.append(event_failed_validation)
+ failed_entry_count["count"] += 1
+ LOG.info(json.dumps(event_failed_validation))
+ return
+
+ region, account_id = extract_region_and_account_id(event_bus_name_or_arn, context)
+ if encoded_trace_header := get_trace_header_encoded_region_account(
+ entry, context.region, context.account_id, region, account_id
+ ):
+ entry["TraceHeader"] = encoded_trace_header
+
+ event_formatted = format_event(entry, region, account_id, event_bus_name)
+ store = self.get_store(region, account_id)
+
+ try:
+ event_bus = self.get_event_bus(event_bus_name, store)
+ except ResourceNotFoundException:
+ # ignore events for non-existing event buses but add processed event
+ processed_entries.append({"EventId": event_formatted["id"]})
+ LOG.info(
+ json.dumps(
+ {
+ "ErrorCode": "ResourceNotFoundException at get_event_bus",
+ "ErrorMessage": f"Event_bus {event_bus_name} does not exist",
+ }
+ )
+ )
+ return
+
+ self._proxy_capture_input_event(event_formatted)
+
+ # Always add the successful EventId entry, even if target processing might fail
+ processed_entries.append({"EventId": event_formatted["id"]})
+
+ if configured_rules := list(event_bus.rules.values()):
+ for rule in configured_rules:
+ self._process_rules(rule, region, account_id, event_formatted)
+ else:
+ LOG.info(
+ json.dumps(
+ {
+ "InfoCode": "InternalInfoEvents at process_rules",
+ "InfoMessage": f"No rules attached to event_bus: {event_bus_name}",
+ }
+ )
+ )
+
+ def _proxy_capture_input_event(self, event: FormattedEvent) -> None:
+ # only required for eventstudio to capture input event if no rule is configured
+ pass
+
+ def _process_rules(
+ self,
+ rule: Rule,
+ region: str,
+ account_id: str,
+ event_formatted: FormattedEvent,
+ ) -> None:
+ """Process rules for an event. Note that we no longer handle entries here as AWS returns success regardless of target failures."""
+ event_pattern = rule.event_pattern
+
+ if matches_event(event_pattern, event_formatted):
+ if not rule.targets:
+ LOG.info(
+ json.dumps(
+ {
+ "InfoCode": "InternalInfoEvents at iterate over targets",
+ "InfoMessage": f"No target configured for matched rule: {rule}",
+ }
+ )
+ )
+ return
+
+ for target in rule.targets.values():
+ target_id = target["Id"]
+ if is_archive_arn(target["Arn"]):
+ self._put_to_archive(
+ region,
+ account_id,
+ archive_target_id=target_id,
+ event=event_formatted,
+ )
+ else:
+ target_unique_id = f"{rule.arn}-{target_id}"
+ target_sender = self._target_sender_store[target_unique_id]
+ try:
+ target_sender.process_event(event_formatted.copy())
+ rule_invocation.record(target_sender.service)
+ except Exception as error:
+ rule_error.record(target_sender.service)
+ # Log the error but don't modify the response
+ LOG.info(
+ json.dumps(
+ {
+ "ErrorCode": "TargetDeliveryFailure",
+ "ErrorMessage": f"Failed to deliver to target {target_id}: {str(error)}",
+ }
+ )
+ )
+ else:
+ LOG.info(
+ json.dumps(
+ {
+ "InfoCode": "InternalInfoEvents at matches_rule",
+ "InfoMessage": f"No rules matched for formatted event: {event_formatted}",
+ }
+ )
+ )
diff --git a/localstack-core/localstack/services/events/replay.py b/localstack-core/localstack/services/events/replay.py
new file mode 100644
index 0000000000000..7a58fb3534d05
--- /dev/null
+++ b/localstack-core/localstack/services/events/replay.py
@@ -0,0 +1,94 @@
+from datetime import datetime, timezone
+
+from localstack.aws.api.events import (
+ Arn,
+ PutEventsRequestEntryList,
+ ReplayDescription,
+ ReplayDestination,
+ ReplayName,
+ ReplayState,
+ Timestamp,
+)
+from localstack.services.events.models import FormattedEventList, Replay
+from localstack.services.events.utils import (
+ convert_to_timezone_aware_datetime,
+ extract_event_bus_name,
+ re_format_event,
+)
+
+
+class ReplayService:
+ name: ReplayName
+ region: str
+ account_id: str
+ event_source_arn: Arn
+ destination: ReplayDestination
+ event_start_time: Timestamp
+ event_end_time: Timestamp
+ description: ReplayDescription
+ replay: Replay
+
+ def __init__(
+ self,
+ name: ReplayName,
+ region: str,
+ account_id: str,
+ event_source_arn: Arn,
+ destination: ReplayDestination,
+ event_start_time: Timestamp,
+ event_end_time: Timestamp,
+ description: ReplayDescription,
+ ):
+ event_start_time = convert_to_timezone_aware_datetime(event_start_time)
+ event_end_time = convert_to_timezone_aware_datetime(event_end_time)
+ self.replay = Replay(
+ name,
+ region,
+ account_id,
+ event_source_arn,
+ destination,
+ event_start_time,
+ event_end_time,
+ description,
+ )
+ self.set_state(ReplayState.STARTING)
+
+ def __getattr__(self, name):
+ return getattr(self.replay, name)
+
+ def set_state(self, state: ReplayState) -> None:
+ self.replay.state = state
+
+ def start(self, events: FormattedEventList | None) -> None:
+ self.set_state(ReplayState.RUNNING)
+ self.replay.replay_start_time = datetime.now(timezone.utc)
+ if events:
+ self._set_event_last_replayed_time(events)
+
+ def finish(self) -> None:
+ self.set_state(ReplayState.COMPLETED)
+ self.replay.replay_end_time = datetime.now(timezone.utc)
+
+ def stop(self) -> None:
+ self.set_state(ReplayState.CANCELLING)
+ self.replay.event_last_replayed_time = None
+ self.replay.replay_end_time = None
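+
+ # Lifecycle note (descriptive comment, not enforced by code): a replay is created in STARTING,
+ # moves to RUNNING in start(), and ends in COMPLETED via finish(); stop() only sets CANCELLING
+ # and clears the last-replayed and end timestamps, and is used by the provider's CancelReplay handler.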
+
+ def re_format_events_from_archive(
+ self, events: FormattedEventList, replay_name: ReplayName
+ ) -> PutEventsRequestEntryList:
+ event_bus_name = extract_event_bus_name(
+ self.destination["Arn"]
+ ) # TODO deal with filter arn -> defining rules to replay to
+ re_formatted_events = [re_format_event(event, event_bus_name) for event in events]
+ re_formatted_events_from_archive = [
+ {**event, "ReplayName": replay_name} for event in re_formatted_events
+ ]
+ return re_formatted_events_from_archive
+
+ def _set_event_last_replayed_time(self, events: FormattedEventList) -> None:
+ latest_event_time = max(event["time"] for event in events)
+ self.replay.event_last_replayed_time = latest_event_time
+
+
+ReplayServiceDict = dict[ReplayName, ReplayService]
diff --git a/localstack-core/localstack/services/events/resource_providers/__init__.py b/localstack-core/localstack/services/events/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.py b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.py
new file mode 100644
index 0000000000000..372d45de40dce
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.py
@@ -0,0 +1,115 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EventsApiDestinationProperties(TypedDict):
+ ConnectionArn: Optional[str]
+ HttpMethod: Optional[str]
+ InvocationEndpoint: Optional[str]
+ Arn: Optional[str]
+ Description: Optional[str]
+ InvocationRateLimitPerSecond: Optional[int]
+ Name: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EventsApiDestinationProvider(ResourceProvider[EventsApiDestinationProperties]):
+ TYPE = "AWS::Events::ApiDestination" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EventsApiDestinationProperties],
+ ) -> ProgressEvent[EventsApiDestinationProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Name
+
+ Required properties:
+ - ConnectionArn
+ - InvocationEndpoint
+ - HttpMethod
+
+ Create-only properties:
+ - /properties/Name
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - events:CreateApiDestination
+ - events:DescribeApiDestination
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ response = events.create_api_destination(**model)
+ model["Arn"] = response["ApiDestinationArn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EventsApiDestinationProperties],
+ ) -> ProgressEvent[EventsApiDestinationProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - events:DescribeApiDestination
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EventsApiDestinationProperties],
+ ) -> ProgressEvent[EventsApiDestinationProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - events:DeleteApiDestination
+ - events:DescribeApiDestination
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ events.delete_api_destination(Name=model["Name"])
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EventsApiDestinationProperties],
+ ) -> ProgressEvent[EventsApiDestinationProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - events:UpdateApiDestination
+ - events:DescribeApiDestination
+ """
+ raise NotImplementedError
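Editor's note, not part of the patch: a rough illustration of the desired state that the create handler above forwards to `create_api_destination`. All values are placeholders, not taken from the source.

# Hypothetical desired_state for AWS::Events::ApiDestination (placeholder values).
model = {
    "Name": "my-destination",
    "ConnectionArn": "arn:aws:events:us-east-1:000000000000:connection/my-connection/abc",
    "InvocationEndpoint": "https://example.com/webhook",
    "HttpMethod": "POST",
    "InvocationRateLimitPerSecond": 10,
}
# create() calls events.create_api_destination(**model) and then stores the
# returned "ApiDestinationArn" under model["Arn"].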
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.schema.json b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.schema.json
new file mode 100644
index 0000000000000..f50460b1aea17
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination.schema.json
@@ -0,0 +1,92 @@
+{
+ "typeName": "AWS::Events::ApiDestination",
+ "description": "Resource Type definition for AWS::Events::ApiDestination.",
+ "properties": {
+ "Name": {
+ "description": "Name of the apiDestination.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 64
+ },
+ "Description": {
+ "type": "string",
+ "maxLength": 512
+ },
+ "ConnectionArn": {
+ "description": "The arn of the connection.",
+ "type": "string"
+ },
+ "Arn": {
+ "description": "The arn of the api destination.",
+ "type": "string"
+ },
+ "InvocationRateLimitPerSecond": {
+ "type": "integer",
+ "minimum": 1
+ },
+ "InvocationEndpoint": {
+ "description": "Url endpoint to invoke.",
+ "type": "string"
+ },
+ "HttpMethod": {
+ "type": "string",
+ "enum": [
+ "GET",
+ "HEAD",
+ "POST",
+ "OPTIONS",
+ "PUT",
+ "DELETE",
+ "PATCH"
+ ]
+ }
+ },
+ "additionalProperties": false,
+ "createOnlyProperties": [
+ "/properties/Name"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "required": [
+ "ConnectionArn",
+ "InvocationEndpoint",
+ "HttpMethod"
+ ],
+ "primaryIdentifier": [
+ "/properties/Name"
+ ],
+ "tagging": {
+ "taggable": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "events:CreateApiDestination",
+ "events:DescribeApiDestination"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "events:DescribeApiDestination"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "events:UpdateApiDestination",
+ "events:DescribeApiDestination"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "events:DeleteApiDestination",
+ "events:DescribeApiDestination"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "events:ListApiDestinations"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination_plugin.py b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination_plugin.py
new file mode 100644
index 0000000000000..0aa7ada08cc50
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_apidestination_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EventsApiDestinationProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Events::ApiDestination"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.events.resource_providers.aws_events_apidestination import (
+ EventsApiDestinationProvider,
+ )
+
+ self.factory = EventsApiDestinationProvider
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_connection.py b/localstack-core/localstack/services/events/resource_providers/aws_events_connection.py
new file mode 100644
index 0000000000000..a99f8df743aca
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_connection.py
@@ -0,0 +1,162 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EventsConnectionProperties(TypedDict):
+ AuthParameters: Optional[AuthParameters]
+ AuthorizationType: Optional[str]
+ Arn: Optional[str]
+ Description: Optional[str]
+ Name: Optional[str]
+ SecretArn: Optional[str]
+
+
+class ApiKeyAuthParameters(TypedDict):
+ ApiKeyName: Optional[str]
+ ApiKeyValue: Optional[str]
+
+
+class BasicAuthParameters(TypedDict):
+ Password: Optional[str]
+ Username: Optional[str]
+
+
+class ClientParameters(TypedDict):
+ ClientID: Optional[str]
+ ClientSecret: Optional[str]
+
+
+class Parameter(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+ IsValueSecret: Optional[bool]
+
+
+class ConnectionHttpParameters(TypedDict):
+ BodyParameters: Optional[list[Parameter]]
+ HeaderParameters: Optional[list[Parameter]]
+ QueryStringParameters: Optional[list[Parameter]]
+
+
+class OAuthParameters(TypedDict):
+ AuthorizationEndpoint: Optional[str]
+ ClientParameters: Optional[ClientParameters]
+ HttpMethod: Optional[str]
+ OAuthHttpParameters: Optional[ConnectionHttpParameters]
+
+
+class AuthParameters(TypedDict):
+ ApiKeyAuthParameters: Optional[ApiKeyAuthParameters]
+ BasicAuthParameters: Optional[BasicAuthParameters]
+ InvocationHttpParameters: Optional[ConnectionHttpParameters]
+ OAuthParameters: Optional[OAuthParameters]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EventsConnectionProvider(ResourceProvider[EventsConnectionProperties]):
+ TYPE = "AWS::Events::Connection" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EventsConnectionProperties],
+ ) -> ProgressEvent[EventsConnectionProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Name
+
+ Required properties:
+ - AuthorizationType
+ - AuthParameters
+
+ Create-only properties:
+ - /properties/Name
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/SecretArn
+
+ IAM permissions required:
+ - events:CreateConnection
+ - secretsmanager:CreateSecret
+ - secretsmanager:GetSecretValue
+ - secretsmanager:PutSecretValue
+ - iam:CreateServiceLinkedRole
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ response = events.create_connection(**model)
+ model["Arn"] = response["ConnectionArn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EventsConnectionProperties],
+ ) -> ProgressEvent[EventsConnectionProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - events:DescribeConnection
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EventsConnectionProperties],
+ ) -> ProgressEvent[EventsConnectionProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - events:DeleteConnection
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ events.delete_connection(Name=model["Name"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EventsConnectionProperties],
+ ) -> ProgressEvent[EventsConnectionProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - events:UpdateConnection
+ - events:DescribeConnection
+ - secretsmanager:CreateSecret
+ - secretsmanager:UpdateSecret
+ - secretsmanager:GetSecretValue
+ - secretsmanager:PutSecretValue
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_connection.schema.json b/localstack-core/localstack/services/events/resource_providers/aws_events_connection.schema.json
new file mode 100644
index 0000000000000..efc8539e82273
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_connection.schema.json
@@ -0,0 +1,251 @@
+{
+ "typeName": "AWS::Events::Connection",
+ "description": "Resource Type definition for AWS::Events::Connection.",
+ "definitions": {
+ "AuthParameters": {
+ "type": "object",
+ "minProperties": 1,
+ "maxProperties": 2,
+ "properties": {
+ "ApiKeyAuthParameters": {
+ "$ref": "#/definitions/ApiKeyAuthParameters"
+ },
+ "BasicAuthParameters": {
+ "$ref": "#/definitions/BasicAuthParameters"
+ },
+ "OAuthParameters": {
+ "$ref": "#/definitions/OAuthParameters"
+ },
+ "InvocationHttpParameters": {
+ "$ref": "#/definitions/ConnectionHttpParameters"
+ }
+ },
+ "oneOf": [
+ {
+ "required": [
+ "BasicAuthParameters"
+ ]
+ },
+ {
+ "required": [
+ "OAuthParameters"
+ ]
+ },
+ {
+ "required": [
+ "ApiKeyAuthParameters"
+ ]
+ }
+ ],
+ "additionalProperties": false
+ },
+ "BasicAuthParameters": {
+ "type": "object",
+ "properties": {
+ "Username": {
+ "type": "string"
+ },
+ "Password": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Username",
+ "Password"
+ ],
+ "additionalProperties": false
+ },
+ "OAuthParameters": {
+ "type": "object",
+ "properties": {
+ "ClientParameters": {
+ "$ref": "#/definitions/ClientParameters"
+ },
+ "AuthorizationEndpoint": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 2048
+ },
+ "HttpMethod": {
+ "type": "string",
+ "enum": [
+ "GET",
+ "POST",
+ "PUT"
+ ]
+ },
+ "OAuthHttpParameters": {
+ "$ref": "#/definitions/ConnectionHttpParameters"
+ }
+ },
+ "required": [
+ "ClientParameters",
+ "AuthorizationEndpoint",
+ "HttpMethod"
+ ],
+ "additionalProperties": false
+ },
+ "ApiKeyAuthParameters": {
+ "type": "object",
+ "properties": {
+ "ApiKeyName": {
+ "type": "string"
+ },
+ "ApiKeyValue": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "ApiKeyName",
+ "ApiKeyValue"
+ ],
+ "additionalProperties": false
+ },
+ "ClientParameters": {
+ "type": "object",
+ "properties": {
+ "ClientID": {
+ "type": "string"
+ },
+ "ClientSecret": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "ClientID",
+ "ClientSecret"
+ ],
+ "additionalProperties": false
+ },
+ "ConnectionHttpParameters": {
+ "type": "object",
+ "properties": {
+ "HeaderParameters": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Parameter"
+ }
+ },
+ "QueryStringParameters": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Parameter"
+ }
+ },
+ "BodyParameters": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Parameter"
+ }
+ }
+ },
+ "additionalProperties": false
+ },
+ "Parameter": {
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string"
+ },
+ "Value": {
+ "type": "string"
+ },
+ "IsValueSecret": {
+ "type": "boolean",
+ "default": true
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ],
+ "additionalProperties": false
+ }
+ },
+ "properties": {
+ "Name": {
+ "description": "Name of the connection.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 64
+ },
+ "Arn": {
+ "description": "The arn of the connection resource.",
+ "type": "string"
+ },
+ "SecretArn": {
+ "description": "The arn of the secrets manager secret created in the customer account.",
+ "type": "string"
+ },
+ "Description": {
+ "description": "Description of the connection.",
+ "type": "string",
+ "maxLength": 512
+ },
+ "AuthorizationType": {
+ "type": "string",
+ "enum": [
+ "API_KEY",
+ "BASIC",
+ "OAUTH_CLIENT_CREDENTIALS"
+ ]
+ },
+ "AuthParameters": {
+ "$ref": "#/definitions/AuthParameters"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "AuthorizationType",
+ "AuthParameters"
+ ],
+ "createOnlyProperties": [
+ "/properties/Name"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/SecretArn"
+ ],
+ "writeOnlyProperties": [
+ "/properties/AuthParameters"
+ ],
+ "primaryIdentifier": [
+ "/properties/Name"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "events:CreateConnection",
+ "secretsmanager:CreateSecret",
+ "secretsmanager:GetSecretValue",
+ "secretsmanager:PutSecretValue",
+ "iam:CreateServiceLinkedRole"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "events:DescribeConnection"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "events:UpdateConnection",
+ "events:DescribeConnection",
+ "secretsmanager:CreateSecret",
+ "secretsmanager:UpdateSecret",
+ "secretsmanager:GetSecretValue",
+ "secretsmanager:PutSecretValue"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "events:DeleteConnection"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "events:ListConnections"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_connection_plugin.py b/localstack-core/localstack/services/events/resource_providers/aws_events_connection_plugin.py
new file mode 100644
index 0000000000000..c8b16c6c961a1
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_connection_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EventsConnectionProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Events::Connection"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.events.resource_providers.aws_events_connection import (
+ EventsConnectionProvider,
+ )
+
+ self.factory = EventsConnectionProvider
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.py b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.py
new file mode 100644
index 0000000000000..5929d42f7252b
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.py
@@ -0,0 +1,126 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class EventsEventBusProperties(TypedDict):
+ Name: Optional[str]
+ Arn: Optional[str]
+ EventSourceName: Optional[str]
+ Id: Optional[str]
+ Policy: Optional[str]
+ Tags: Optional[list[TagEntry]]
+
+
+class TagEntry(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EventsEventBusProvider(ResourceProvider[EventsEventBusProperties]):
+ TYPE = "AWS::Events::EventBus" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EventsEventBusProperties],
+ ) -> ProgressEvent[EventsEventBusProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - Name
+
+ Create-only properties:
+ - /properties/Name
+ - /properties/EventSourceName
+
+ Read-only properties:
+ - /properties/Id
+ - /properties/Policy
+ - /properties/Arn
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ response = events.create_event_bus(Name=model["Name"])
+ model["Arn"] = response["EventBusArn"]
+ model["Id"] = model["Name"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EventsEventBusProperties],
+ ) -> ProgressEvent[EventsEventBusProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EventsEventBusProperties],
+ ) -> ProgressEvent[EventsEventBusProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ events.delete_event_bus(Name=model["Name"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EventsEventBusProperties],
+ ) -> ProgressEvent[EventsEventBusProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[EventsEventBusProperties],
+ ) -> ProgressEvent[EventsEventBusProperties]:
+ resources = request.aws_client_factory.events.list_event_buses()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ EventsEventBusProperties(Name=resource["Name"])
+ for resource in resources["EventBuses"]
+ ],
+ )
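Editor's note, not part of the patch: the create handler for AWS::Events::EventBus effectively reduces to the call below; the physical Id is simply set to the bus name. `events_client` and the bus name are assumptions for illustration.

# Hypothetical equivalent of EventsEventBusProvider.create (placeholder values).
response = events_client.create_event_bus(Name="orders-bus")
model = {"Name": "orders-bus", "Arn": response["EventBusArn"], "Id": "orders-bus"}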
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.schema.json b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.schema.json
new file mode 100644
index 0000000000000..eb5d780188a5f
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus.schema.json
@@ -0,0 +1,62 @@
+{
+ "typeName": "AWS::Events::EventBus",
+ "description": "Resource Type definition for AWS::Events::EventBus",
+ "additionalProperties": false,
+ "properties": {
+ "Policy": {
+ "type": "string"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "EventSourceName": {
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/TagEntry"
+ }
+ },
+ "Name": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "TagEntry": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "required": [
+ "Name"
+ ],
+ "createOnlyProperties": [
+ "/properties/Name",
+ "/properties/EventSourceName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id",
+ "/properties/Policy",
+ "/properties/Arn"
+ ]
+}
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus_plugin.py b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus_plugin.py
new file mode 100644
index 0000000000000..25f94f1940bb2
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbus_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EventsEventBusProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Events::EventBus"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.events.resource_providers.aws_events_eventbus import (
+ EventsEventBusProvider,
+ )
+
+ self.factory = EventsEventBusProvider
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.py b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.py
new file mode 100644
index 0000000000000..9da54ceeff6bf
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.py
@@ -0,0 +1,155 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+from botocore.exceptions import ClientError
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.strings import short_uid
+
+
+class EventsEventBusPolicyProperties(TypedDict):
+ StatementId: Optional[str]
+ Action: Optional[str]
+ Condition: Optional[Condition]
+ EventBusName: Optional[str]
+ Id: Optional[str]
+ Principal: Optional[str]
+ Statement: Optional[dict]
+
+
+class Condition(TypedDict):
+ Key: Optional[str]
+ Type: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class EventsEventBusPolicyProvider(ResourceProvider[EventsEventBusPolicyProperties]):
+ TYPE = "AWS::Events::EventBusPolicy" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EventsEventBusPolicyProperties],
+ ) -> ProgressEvent[EventsEventBusPolicyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - StatementId
+
+ Create-only properties:
+ - /properties/EventBusName
+ - /properties/StatementId
+
+ Read-only properties:
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ model["Id"] = f"EventBusPolicy-{short_uid()}"
+
+ # Either the Statement field is set, or the individual fields (Action, Principal, etc.) are provided.
+ statement = model.get("Statement")
+ optional_params = {"EventBusName": model.get("EventBusName")}
+
+ if statement:
+ policy = {
+ "Version": "2012-10-17",
+ "Statement": [{"Sid": model["StatementId"], **statement}],
+ }
+ events.put_permission(Policy=json.dumps(policy), **optional_params)
+ else:
+ if model.get("Condition"):
+ optional_params.update({"Condition": model.get("Condition")})
+
+ events.put_permission(
+ StatementId=model["StatementId"],
+ Action=model["Action"],
+ Principal=model["Principal"],
+ **optional_params,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EventsEventBusPolicyProperties],
+ ) -> ProgressEvent[EventsEventBusPolicyProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EventsEventBusPolicyProperties],
+ ) -> ProgressEvent[EventsEventBusPolicyProperties]:
+ """
+ Delete a resource
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ statement_id = model["StatementId"]
+ event_bus_name = model.get("EventBusName")
+
+ params = {"StatementId": statement_id, "RemoveAllPermissions": False}
+
+ if event_bus_name:
+ params["EventBusName"] = event_bus_name
+
+ # We use try/except because CFN currently doesn't properly resolve dependencies
+ # between resources, so this resource may already be gone if its parent resource
+ # was deleted first.
+
+ try:
+ events.remove_permission(**params)
+ except ClientError as err:
+ is_resource_not_found = err.response["Error"]["Code"] == "ResourceNotFoundException"
+
+ if not is_resource_not_found:
+ raise
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EventsEventBusPolicyProperties],
+ ) -> ProgressEvent[EventsEventBusPolicyProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
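Editor's note, not part of the patch: a brief sketch of the two code paths in the create handler above. When a raw Statement is supplied it is wrapped into a full policy document for `put_permission`; otherwise the individual fields are passed through. The statement content, Sid, and account id below are placeholders.

import json

# Path 1: a raw Statement is wrapped into a policy document.
statement = {
    "Effect": "Allow",
    "Principal": {"AWS": "arn:aws:iam::111111111111:root"},
    "Action": "events:PutEvents",
    "Resource": "arn:aws:events:us-east-1:000000000000:event-bus/default",
}
policy = {"Version": "2012-10-17", "Statement": [{"Sid": "AllowPutEvents", **statement}]}
# events.put_permission(Policy=json.dumps(policy), EventBusName="default")

# Path 2: individual fields are passed through as-is.
# events.put_permission(StatementId="AllowPutEvents", Action="events:PutEvents",
#                       Principal="111111111111", EventBusName="default")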
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.schema.json b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.schema.json
new file mode 100644
index 0000000000000..99bd136ddbdcd
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy.schema.json
@@ -0,0 +1,58 @@
+{
+ "typeName": "AWS::Events::EventBusPolicy",
+ "description": "Resource Type definition for AWS::Events::EventBusPolicy",
+ "additionalProperties": false,
+ "properties": {
+ "EventBusName": {
+ "type": "string"
+ },
+ "Condition": {
+ "$ref": "#/definitions/Condition"
+ },
+ "Action": {
+ "type": "string"
+ },
+ "StatementId": {
+ "type": "string"
+ },
+ "Statement": {
+ "type": "object"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Principal": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "Condition": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Type": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "required": [
+ "StatementId"
+ ],
+ "createOnlyProperties": [
+ "/properties/EventBusName",
+ "/properties/StatementId"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy_plugin.py b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy_plugin.py
new file mode 100644
index 0000000000000..5368348690773
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_eventbuspolicy_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EventsEventBusPolicyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Events::EventBusPolicy"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.events.resource_providers.aws_events_eventbuspolicy import (
+ EventsEventBusPolicyProvider,
+ )
+
+ self.factory = EventsEventBusPolicyProvider
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_rule.py b/localstack-core/localstack/services/events/resource_providers/aws_events_rule.py
new file mode 100644
index 0000000000000..a10d23360a41c
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_rule.py
@@ -0,0 +1,323 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils import common
+
+
+class EventsRuleProperties(TypedDict):
+ Arn: Optional[str]
+ Description: Optional[str]
+ EventBusName: Optional[str]
+ EventPattern: Optional[dict]
+ Id: Optional[str]
+ Name: Optional[str]
+ RoleArn: Optional[str]
+ ScheduleExpression: Optional[str]
+ State: Optional[str]
+ Targets: Optional[list[Target]]
+
+
+class HttpParameters(TypedDict):
+ HeaderParameters: Optional[dict]
+ PathParameterValues: Optional[list[str]]
+ QueryStringParameters: Optional[dict]
+
+
+class DeadLetterConfig(TypedDict):
+ Arn: Optional[str]
+
+
+class RunCommandTarget(TypedDict):
+ Key: Optional[str]
+ Values: Optional[list[str]]
+
+
+class RunCommandParameters(TypedDict):
+ RunCommandTargets: Optional[list[RunCommandTarget]]
+
+
+class InputTransformer(TypedDict):
+ InputTemplate: Optional[str]
+ InputPathsMap: Optional[dict]
+
+
+class KinesisParameters(TypedDict):
+ PartitionKeyPath: Optional[str]
+
+
+class RedshiftDataParameters(TypedDict):
+ Database: Optional[str]
+ Sql: Optional[str]
+ DbUser: Optional[str]
+ SecretManagerArn: Optional[str]
+ StatementName: Optional[str]
+ WithEvent: Optional[bool]
+
+
+class SqsParameters(TypedDict):
+ MessageGroupId: Optional[str]
+
+
+class PlacementConstraint(TypedDict):
+ Expression: Optional[str]
+ Type: Optional[str]
+
+
+class PlacementStrategy(TypedDict):
+ Field: Optional[str]
+ Type: Optional[str]
+
+
+class CapacityProviderStrategyItem(TypedDict):
+ CapacityProvider: Optional[str]
+ Base: Optional[int]
+ Weight: Optional[int]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+class AwsVpcConfiguration(TypedDict):
+ Subnets: Optional[list[str]]
+ AssignPublicIp: Optional[str]
+ SecurityGroups: Optional[list[str]]
+
+
+class NetworkConfiguration(TypedDict):
+ AwsVpcConfiguration: Optional[AwsVpcConfiguration]
+
+
+class EcsParameters(TypedDict):
+ TaskDefinitionArn: Optional[str]
+ CapacityProviderStrategy: Optional[list[CapacityProviderStrategyItem]]
+ EnableECSManagedTags: Optional[bool]
+ EnableExecuteCommand: Optional[bool]
+ Group: Optional[str]
+ LaunchType: Optional[str]
+ NetworkConfiguration: Optional[NetworkConfiguration]
+ PlacementConstraints: Optional[list[PlacementConstraint]]
+ PlacementStrategies: Optional[list[PlacementStrategy]]
+ PlatformVersion: Optional[str]
+ PropagateTags: Optional[str]
+ ReferenceId: Optional[str]
+ TagList: Optional[list[Tag]]
+ TaskCount: Optional[int]
+
+
+class BatchRetryStrategy(TypedDict):
+ Attempts: Optional[int]
+
+
+class BatchArrayProperties(TypedDict):
+ Size: Optional[int]
+
+
+class BatchParameters(TypedDict):
+ JobDefinition: Optional[str]
+ JobName: Optional[str]
+ ArrayProperties: Optional[BatchArrayProperties]
+ RetryStrategy: Optional[BatchRetryStrategy]
+
+
+class SageMakerPipelineParameter(TypedDict):
+ Name: Optional[str]
+ Value: Optional[str]
+
+
+class SageMakerPipelineParameters(TypedDict):
+ PipelineParameterList: Optional[list[SageMakerPipelineParameter]]
+
+
+class RetryPolicy(TypedDict):
+ MaximumEventAgeInSeconds: Optional[int]
+ MaximumRetryAttempts: Optional[int]
+
+
+class Target(TypedDict):
+ Arn: Optional[str]
+ Id: Optional[str]
+ BatchParameters: Optional[BatchParameters]
+ DeadLetterConfig: Optional[DeadLetterConfig]
+ EcsParameters: Optional[EcsParameters]
+ HttpParameters: Optional[HttpParameters]
+ Input: Optional[str]
+ InputPath: Optional[str]
+ InputTransformer: Optional[InputTransformer]
+ KinesisParameters: Optional[KinesisParameters]
+ RedshiftDataParameters: Optional[RedshiftDataParameters]
+ RetryPolicy: Optional[RetryPolicy]
+ RoleArn: Optional[str]
+ RunCommandParameters: Optional[RunCommandParameters]
+ SageMakerPipelineParameters: Optional[SageMakerPipelineParameters]
+ SqsParameters: Optional[SqsParameters]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+MATCHING_OPERATIONS = [
+ "prefix",
+ "cidr",
+ "exists",
+ "suffix",
+ "anything-but",
+ "numeric",
+ "equals-ignore-case",
+ "wildcard",
+]
+
+
+def extract_rule_name(rule_id: str) -> str:
+ return rule_id.rsplit("|", maxsplit=1)[-1]
+
+
+class EventsRuleProvider(ResourceProvider[EventsRuleProperties]):
+ TYPE = "AWS::Events::Rule" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[EventsRuleProperties],
+ ) -> ProgressEvent[EventsRuleProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Create-only properties:
+ - /properties/Name
+ - /properties/EventBusName
+
+ Read-only properties:
+ - /properties/Id
+ - /properties/Arn
+
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ name = model.get("Name")
+ if not name:
+ name = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+
+ if event_bus_name := model.get("EventBusName"):
+ model["Id"] = "|".join(
+ [
+ event_bus_name,
+ name,
+ ]
+ )
+ else:
+ model["Id"] = name
+
+ attrs = [
+ "ScheduleExpression",
+ "EventPattern",
+ "State",
+ "Description",
+ "Name",
+ "EventBusName",
+ ]
+
+ params = util.select_attributes(model, attrs)
+
+ def wrap_in_lists(o, **kwargs):
+ if isinstance(o, dict):
+ for k, v in o.items():
+ if not isinstance(v, (dict, list)) and k not in MATCHING_OPERATIONS:
+ o[k] = [v]
+ return o
+
+ pattern = params.get("EventPattern")
+ if isinstance(pattern, dict):
+ wrapped = common.recurse_object(pattern, wrap_in_lists)
+ params["EventPattern"] = json.dumps(wrapped)
+
+ params["Name"] = name
+ result = events.put_rule(**params)
+ model["Arn"] = result["RuleArn"]
+
+ # put targets
+ event_bus_name = model.get("EventBusName")
+ targets = model.get("Targets") or []
+
+ if targets:
+ put_targets_kwargs = {"Rule": extract_rule_name(model["Id"]), "Targets": targets}
+ if event_bus_name:
+ put_targets_kwargs["EventBusName"] = event_bus_name
+
+ put_targets_kwargs = util.convert_request_kwargs(
+ put_targets_kwargs,
+ events.meta.service_model.operation_model("PutTargets").input_shape,
+ )
+
+ events.put_targets(**put_targets_kwargs)
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[EventsRuleProperties],
+ ) -> ProgressEvent[EventsRuleProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[EventsRuleProperties],
+ ) -> ProgressEvent[EventsRuleProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ events = request.aws_client_factory.events
+
+ rule_name = extract_rule_name(model["Id"])
+ targets = events.list_targets_by_rule(Rule=rule_name)["Targets"]
+ target_ids = [tgt["Id"] for tgt in targets]
+ if targets:
+ events.remove_targets(Rule=rule_name, Ids=target_ids, Force=True)
+ events.delete_rule(Name=rule_name)
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[EventsRuleProperties],
+ ) -> ProgressEvent[EventsRuleProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
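Editor's note, not part of the patch: the EventPattern normalization in the create handler wraps scalar values in lists (as EventBridge patterns require) while leaving keys that are matching operators such as "prefix" or "exists" untouched. The snippet below is a standalone re-implementation of that idea for illustration only; the patch itself uses `common.recurse_object` with `wrap_in_lists`.

# Standalone illustration of the wrap_in_lists behavior.
MATCHING_OPERATIONS = [
    "prefix", "cidr", "exists", "suffix", "anything-but",
    "numeric", "equals-ignore-case", "wildcard",
]

def wrap_scalars(obj):
    """Recursively wrap scalar pattern values in lists, except under matching operators."""
    if not isinstance(obj, dict):
        return obj
    wrapped = {}
    for key, value in obj.items():
        value = wrap_scalars(value)
        if isinstance(value, (dict, list)) or key in MATCHING_OPERATIONS:
            wrapped[key] = value
        else:
            wrapped[key] = [value]
    return wrapped

pattern = {"source": "aws.ec2", "detail": {"state": "running"}}
assert wrap_scalars(pattern) == {"source": ["aws.ec2"], "detail": {"state": ["running"]}}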
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_rule.schema.json b/localstack-core/localstack/services/events/resource_providers/aws_events_rule.schema.json
new file mode 100644
index 0000000000000..c3a3601ff7b49
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_rule.schema.json
@@ -0,0 +1,495 @@
+{
+ "typeName": "AWS::Events::Rule",
+ "description": "Resource Type definition for AWS::Events::Rule",
+ "additionalProperties": false,
+ "properties": {
+ "EventBusName": {
+ "type": "string"
+ },
+ "EventPattern": {
+ "type": "object"
+ },
+ "ScheduleExpression": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "State": {
+ "type": "string"
+ },
+ "Targets": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/Target"
+ }
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "RoleArn": {
+ "type": "string"
+ },
+ "Name": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "CapacityProviderStrategyItem": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Base": {
+ "type": "integer"
+ },
+ "Weight": {
+ "type": "integer"
+ },
+ "CapacityProvider": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "CapacityProvider"
+ ]
+ },
+ "HttpParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PathParameterValues": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "HeaderParameters": {
+ "type": "object",
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "QueryStringParameters": {
+ "type": "object",
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ },
+ "DeadLetterConfig": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Arn": {
+ "type": "string"
+ }
+ }
+ },
+ "RunCommandParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "RunCommandTargets": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/RunCommandTarget"
+ }
+ }
+ },
+ "required": [
+ "RunCommandTargets"
+ ]
+ },
+ "PlacementStrategy": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Field": {
+ "type": "string"
+ },
+ "Type": {
+ "type": "string"
+ }
+ }
+ },
+ "InputTransformer": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "InputTemplate": {
+ "type": "string"
+ },
+ "InputPathsMap": {
+ "type": "object",
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "required": [
+ "InputTemplate"
+ ]
+ },
+ "KinesisParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PartitionKeyPath": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "PartitionKeyPath"
+ ]
+ },
+ "BatchRetryStrategy": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Attempts": {
+ "type": "integer"
+ }
+ }
+ },
+ "RedshiftDataParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "StatementName": {
+ "type": "string"
+ },
+ "Database": {
+ "type": "string"
+ },
+ "SecretManagerArn": {
+ "type": "string"
+ },
+ "DbUser": {
+ "type": "string"
+ },
+ "Sql": {
+ "type": "string"
+ },
+ "WithEvent": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "Database",
+ "Sql"
+ ]
+ },
+ "Target": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "InputPath": {
+ "type": "string"
+ },
+ "HttpParameters": {
+ "$ref": "#/definitions/HttpParameters"
+ },
+ "DeadLetterConfig": {
+ "$ref": "#/definitions/DeadLetterConfig"
+ },
+ "RunCommandParameters": {
+ "$ref": "#/definitions/RunCommandParameters"
+ },
+ "InputTransformer": {
+ "$ref": "#/definitions/InputTransformer"
+ },
+ "KinesisParameters": {
+ "$ref": "#/definitions/KinesisParameters"
+ },
+ "RoleArn": {
+ "type": "string"
+ },
+ "RedshiftDataParameters": {
+ "$ref": "#/definitions/RedshiftDataParameters"
+ },
+ "Input": {
+ "type": "string"
+ },
+ "SqsParameters": {
+ "$ref": "#/definitions/SqsParameters"
+ },
+ "EcsParameters": {
+ "$ref": "#/definitions/EcsParameters"
+ },
+ "BatchParameters": {
+ "$ref": "#/definitions/BatchParameters"
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "SageMakerPipelineParameters": {
+ "$ref": "#/definitions/SageMakerPipelineParameters"
+ },
+ "RetryPolicy": {
+ "$ref": "#/definitions/RetryPolicy"
+ }
+ },
+ "required": [
+ "Id",
+ "Arn"
+ ]
+ },
+ "PlacementConstraint": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Expression": {
+ "type": "string"
+ },
+ "Type": {
+ "type": "string"
+ }
+ }
+ },
+ "AwsVpcConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "SecurityGroups": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Subnets": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "AssignPublicIp": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Subnets"
+ ]
+ },
+ "SqsParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "MessageGroupId": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "MessageGroupId"
+ ]
+ },
+ "RunCommandTarget": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Values": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Values",
+ "Key"
+ ]
+ },
+ "EcsParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PlatformVersion": {
+ "type": "string"
+ },
+ "Group": {
+ "type": "string"
+ },
+ "EnableECSManagedTags": {
+ "type": "boolean"
+ },
+ "EnableExecuteCommand": {
+ "type": "boolean"
+ },
+ "PlacementConstraints": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/PlacementConstraint"
+ }
+ },
+ "PropagateTags": {
+ "type": "string"
+ },
+ "TaskCount": {
+ "type": "integer"
+ },
+ "PlacementStrategies": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/PlacementStrategy"
+ }
+ },
+ "CapacityProviderStrategy": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/CapacityProviderStrategyItem"
+ }
+ },
+ "LaunchType": {
+ "type": "string"
+ },
+ "ReferenceId": {
+ "type": "string"
+ },
+ "TagList": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "NetworkConfiguration": {
+ "$ref": "#/definitions/NetworkConfiguration"
+ },
+ "TaskDefinitionArn": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "TaskDefinitionArn"
+ ]
+ },
+ "BatchParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "JobName": {
+ "type": "string"
+ },
+ "RetryStrategy": {
+ "$ref": "#/definitions/BatchRetryStrategy"
+ },
+ "ArrayProperties": {
+ "$ref": "#/definitions/BatchArrayProperties"
+ },
+ "JobDefinition": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "JobName",
+ "JobDefinition"
+ ]
+ },
+ "NetworkConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AwsVpcConfiguration": {
+ "$ref": "#/definitions/AwsVpcConfiguration"
+ }
+ }
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ }
+ },
+ "SageMakerPipelineParameters": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PipelineParameterList": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/SageMakerPipelineParameter"
+ }
+ }
+ }
+ },
+ "RetryPolicy": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "MaximumEventAgeInSeconds": {
+ "type": "integer"
+ },
+ "MaximumRetryAttempts": {
+ "type": "integer"
+ }
+ }
+ },
+ "BatchArrayProperties": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Size": {
+ "type": "integer"
+ }
+ }
+ },
+ "SageMakerPipelineParameter": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Name": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Name"
+ ]
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/Name",
+ "/properties/EventBusName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id",
+ "/properties/Arn"
+ ]
+}
diff --git a/localstack-core/localstack/services/events/resource_providers/aws_events_rule_plugin.py b/localstack-core/localstack/services/events/resource_providers/aws_events_rule_plugin.py
new file mode 100644
index 0000000000000..3fa01b6717fdc
--- /dev/null
+++ b/localstack-core/localstack/services/events/resource_providers/aws_events_rule_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class EventsRuleProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Events::Rule"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.events.resource_providers.aws_events_rule import EventsRuleProvider
+
+ self.factory = EventsRuleProvider
diff --git a/localstack-core/localstack/services/events/rule.py b/localstack-core/localstack/services/events/rule.py
new file mode 100644
index 0000000000000..576cfc36e781c
--- /dev/null
+++ b/localstack-core/localstack/services/events/rule.py
@@ -0,0 +1,245 @@
+import re
+from typing import Callable, Optional
+
+from localstack.aws.api.events import (
+ Arn,
+ EventBusName,
+ EventPattern,
+ LimitExceededException,
+ ManagedBy,
+ PutTargetsResultEntryList,
+ RemoveTargetsResultEntryList,
+ RoleArn,
+ RuleDescription,
+ RuleName,
+ RuleState,
+ ScheduleExpression,
+ TagList,
+ Target,
+ TargetIdList,
+ TargetList,
+)
+from localstack.services.events.models import Rule, TargetDict, ValidationException
+from localstack.services.events.scheduler import JobScheduler, convert_schedule_to_cron
+
+TARGET_ID_REGEX = re.compile(r"^[\.\-_A-Za-z0-9]+$")
+TARGET_ARN_REGEX = re.compile(r"arn:[\d\w:\-/]*")
+CRON_REGEX = ( # borrowed from https://regex101.com/r/I80Eu0/1
+ r"^(?:cron[(](?:(?:(?:[0-5]?[0-9])|[*])(?:(?:[-](?:(?:[0-5]?[0-9])|[*]))|(?:[/][0-9]+))?"
+ r"(?:[,](?:(?:[0-5]?[0-9])|[*])(?:(?:[-](?:(?:[0-5]?[0-9])|[*]))|(?:[/][0-9]+))?)*)[ ]+"
+ r"(?:(?:(?:[0-2]?[0-9])|[*])(?:(?:[-](?:(?:[0-2]?[0-9])|[*]))|(?:[/][0-9]+))?"
+ r"(?:[,](?:(?:[0-2]?[0-9])|[*])(?:(?:[-](?:(?:[0-2]?[0-9])|[*]))|(?:[/][0-9]+))?)*)[ ]+"
+ r"(?:(?:[?][ ]+(?:(?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])"
+ r"(?:(?:[-](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])(?:[/][0-9]+)?)|"
+ r"(?:[/][0-9]+))?(?:[,](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])"
+ r"(?:(?:[-](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])(?:[/][0-9]+)?)|"
+ r"(?:[/][0-9]+))?)*)[ ]+(?:(?:(?:[1-7]|(?:SUN|MON|TUE|WED|THU|FRI|SAT))[#][0-5])|"
+ r"(?:(?:(?:(?:[1-7]|(?:SUN|MON|TUE|WED|THU|FRI|SAT))L?)|[L*])(?:(?:[-](?:(?:(?:[1-7]|"
+ r"(?:SUN|MON|TUE|WED|THU|FRI|SAT))L?)|[L*]))|(?:[/][0-9]+))?(?:[,](?:(?:(?:[1-7]|"
+ r"(?:SUN|MON|TUE|WED|THU|FRI|SAT))L?)|[L*])(?:(?:[-](?:(?:(?:[1-7]|(?:SUN|MON|TUE|WED|THU|FRI|SAT))L?)|"
+ r"[L*]))|(?:[/][0-9]+))?)*)))|(?:(?:(?:(?:(?:[1-3]?[0-9])W?)|LW|[L*])(?:(?:[-](?:(?:(?:[1-3]?[0-9])W?)|"
+ r"LW|[L*]))|(?:[/][0-9]+))?(?:[,](?:(?:(?:[1-3]?[0-9])W?)|LW|[L*])(?:(?:[-](?:(?:(?:[1-3]?[0-9])W?)|"
+ r"LW|[L*]))|(?:[/][0-9]+))?)*)[ ]+(?:(?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|"
+ r"[*])(?:(?:[-](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])(?:[/][0-9]+)?)|"
+ r"(?:[/][0-9]+))?(?:[,](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])"
+ r"(?:(?:[-](?:(?:[1]?[0-9])|(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)|[*])(?:[/][0-9]+)?)|"
+ r"(?:[/][0-9]+))?)*)[ ]+[?]))[ ]+(?:(?:(?:[12][0-9]{3})|[*])(?:(?:[-](?:(?:[12][0-9]{3})|[*]))|"
+ r"(?:[/][0-9]+))?(?:[,](?:(?:[12][0-9]{3})|[*])(?:(?:[-](?:(?:[12][0-9]{3})|[*]))|(?:[/][0-9]+))?)*)[)])$"
+)
+RULE_SCHEDULE_CRON_REGEX = re.compile(CRON_REGEX)
+RULE_SCHEDULE_RATE_REGEX = re.compile(r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)")
+
+
+class RuleService:
+ name: RuleName
+ region: str
+ account_id: str
+ schedule_expression: ScheduleExpression | None
+ event_pattern: EventPattern | None
+ description: RuleDescription | None
+ role_arn: Arn | None
+ tags: TagList | None
+ event_bus_name: EventBusName | None
+ targets: TargetDict | None
+ managed_by: ManagedBy
+ rule: Rule
+
+ def __init__(self, rule: Rule):
+ self.rule = rule
+ if rule.schedule_expression:
+ self.schedule_cron = self._get_schedule_cron(rule.schedule_expression)
+ else:
+ self.schedule_cron = None
+
+ @classmethod
+ def create_rule_service(
+ cls,
+ name: RuleName,
+ region: Optional[str] = None,
+ account_id: Optional[str] = None,
+ schedule_expression: Optional[ScheduleExpression] = None,
+ event_pattern: Optional[EventPattern] = None,
+ state: Optional[RuleState] = None,
+ description: Optional[RuleDescription] = None,
+ role_arn: Optional[RoleArn] = None,
+ tags: Optional[TagList] = None,
+ event_bus_name: Optional[EventBusName] = None,
+ targets: Optional[TargetDict] = None,
+ managed_by: Optional[ManagedBy] = None,
+ ):
+ cls._validate_input(event_pattern, schedule_expression, event_bus_name)
+ # required to keep data and functionality separate for persistence
+ return cls(
+ Rule(
+ name,
+ region,
+ account_id,
+ schedule_expression,
+ event_pattern,
+ state,
+ description,
+ role_arn,
+ tags,
+ event_bus_name,
+ targets,
+ managed_by,
+ )
+ )
+
+ @property
+ def arn(self) -> Arn:
+ return self.rule.arn
+
+ @property
+ def state(self) -> RuleState:
+ return self.rule.state
+
+ def enable(self) -> None:
+ self.rule.state = RuleState.ENABLED
+
+ def disable(self) -> None:
+ self.rule.state = RuleState.DISABLED
+
+ def add_targets(self, targets: TargetList) -> PutTargetsResultEntryList:
+ failed_entries = self.validate_targets_input(targets)
+ for target in targets:
+ target_id = target["Id"]
+ if target_id not in self.rule.targets and self._check_target_limit_reached():
+ raise LimitExceededException(
+ "The requested resource exceeds the maximum number allowed."
+ )
+ target = Target(**target)
+ self.rule.targets[target_id] = target
+ return failed_entries
+
+ def remove_targets(
+ self, target_ids: TargetIdList, force: bool = False
+ ) -> RemoveTargetsResultEntryList:
+ delete_errors = []
+ for target_id in target_ids:
+ if target_id in self.rule.targets:
+ if self.rule.managed_by and not force:
+ delete_errors.append(
+ {
+ "TargetId": target_id,
+ "ErrorCode": "ManagedRuleException",
+ "ErrorMessage": f"Rule '{self.rule.name}' is managed by an AWS service can only be modified if force is True.",
+ }
+ )
+ else:
+ del self.rule.targets[target_id]
+ else:
+ delete_errors.append(
+ {
+ "TargetId": target_id,
+ "ErrorCode": "ResourceNotFoundException",
+ "ErrorMessage": f"Rule '{self.rule.name}' does not have a target with the Id '{target_id}'.",
+ }
+ )
+ return delete_errors
+
+ def create_schedule_job(self, schedule_job_sender_func: Callable) -> None:
+ cron = self.schedule_cron
+ state = self.rule.state != "DISABLED"
+ self.job_id = JobScheduler.instance().add_job(schedule_job_sender_func, cron, state)
+
+ def validate_targets_input(self, targets: TargetList) -> PutTargetsResultEntryList:
+ validation_errors = []
+ for index, target in enumerate(targets):
+ id = target.get("Id")
+ arn = target.get("Arn", "")
+ if not TARGET_ID_REGEX.match(id):
+ validation_errors.append(
+ {
+ "TargetId": id,
+ "ErrorCode": "ValidationException",
+ "ErrorMessage": f"Value '{id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must satisfy regular expression pattern: [\\.\\-_A-Za-z0-9]+",
+ }
+ )
+
+ if len(id) > 64:
+ validation_errors.append(
+ {
+ "TargetId": id,
+ "ErrorCode": "ValidationException",
+ "ErrorMessage": f"Value '{id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must have length less than or equal to 64",
+ }
+ )
+
+ if not TARGET_ARN_REGEX.match(arn):
+ validation_errors.append(
+ {
+ "TargetId": id,
+ "ErrorCode": "ValidationException",
+ "ErrorMessage": f"Parameter {arn} is not valid. Reason: Provided Arn is not in correct format.",
+ }
+ )
+
+ if ":sqs:" in arn and arn.endswith(".fifo") and not target.get("SqsParameters"):
+ validation_errors.append(
+ {
+ "TargetId": id,
+ "ErrorCode": "ValidationException",
+ "ErrorMessage": f"Parameter(s) SqsParameters must be specified for target: {id}.",
+ }
+ )
+
+ return validation_errors
+
+ @classmethod
+ def _validate_input(
+ cls,
+ event_pattern: Optional[EventPattern],
+ schedule_expression: Optional[ScheduleExpression],
+ event_bus_name: Optional[EventBusName] = "default",
+ ) -> None:
+ if not event_pattern and not schedule_expression:
+ raise ValidationException(
+ "Parameter(s) EventPattern or ScheduleExpression must be specified."
+ )
+
+ if schedule_expression:
+ if event_bus_name != "default":
+ raise ValidationException(
+ "ScheduleExpression is supported only on the default event bus."
+ )
+ if not (
+ RULE_SCHEDULE_CRON_REGEX.match(schedule_expression)
+ or RULE_SCHEDULE_RATE_REGEX.match(schedule_expression)
+ ):
+ raise ValidationException("Parameter ScheduleExpression is not valid.")
+
+ def _check_target_limit_reached(self) -> bool:
+ if len(self.rule.targets) >= 5:
+ return True
+ return False
+
+ def _get_schedule_cron(self, schedule_expression: ScheduleExpression) -> str:
+ try:
+ cron = convert_schedule_to_cron(schedule_expression)
+ return cron
+ except ValueError as e:
+ raise ValidationException("Parameter ScheduleExpression is not valid.") from e
+
+
+RuleServiceDict = dict[Arn, RuleService]
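Editor's note, not part of the patch: a quick sketch of how the schedule-expression validation in `_validate_input` behaves for common inputs, using only the module-level regexes defined above. The assertions are illustrative, not exhaustive.

# Illustrative checks against the module-level schedule regexes.
assert RULE_SCHEDULE_RATE_REGEX.match("rate(5 minutes)")
assert RULE_SCHEDULE_CRON_REGEX.match("cron(0 20 * * ? *)")
assert not RULE_SCHEDULE_RATE_REGEX.match("every 5 minutes")
# create_rule_service raises ValidationException for expressions matching neither
# pattern, and for any ScheduleExpression used on a non-default event bus.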
diff --git a/localstack-core/localstack/services/events/scheduler.py b/localstack-core/localstack/services/events/scheduler.py
new file mode 100644
index 0000000000000..c71833f402d0b
--- /dev/null
+++ b/localstack-core/localstack/services/events/scheduler.py
@@ -0,0 +1,136 @@
+import logging
+import re
+import threading
+
+from crontab import CronTab
+
+from localstack.utils.common import short_uid
+from localstack.utils.run import FuncThread
+
+LOG = logging.getLogger(__name__)
+
+CRON_REGEX = re.compile(r"\s*cron\s*\(([^\)]*)\)\s*")
+RATE_REGEX = re.compile(r"\s*rate\s*\(([^\)]*)\)\s*")
+
+
+def convert_schedule_to_cron(schedule):
+ """Convert Events schedule like "cron(0 20 * * ? *)" or "rate(5 minutes)" """
+ cron_match = CRON_REGEX.match(schedule)
+ if cron_match:
+ return cron_match.group(1)
+
+ rate_match = RATE_REGEX.match(schedule)
+ if rate_match:
+ rate = rate_match.group(1)
+ rate_value, rate_unit = re.split(r"\s+", rate.strip())
+ rate_value = int(rate_value)
+
+ if rate_value < 1:
+ raise ValueError("Rate value must be larger than 0")
+ # see https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-rate-expressions.html
+ if rate_value == 1 and rate_unit.endswith("s"):
+ raise ValueError("If the value is equal to 1, then the unit must be singular")
+ if rate_value > 1 and not rate_unit.endswith("s"):
+ raise ValueError("If the value is greater than 1, the unit must be plural")
+
+ if "minute" in rate_unit:
+ return f"*/{rate_value} * * * *"
+ if "hour" in rate_unit:
+ return f"0 */{rate_value} * * *"
+ if "day" in rate_unit:
+ return f"0 0 */{rate_value} * *"
+
+ # TODO: cover via test
+ # raise ValueError(f"Unable to parse events schedule expression: {schedule}")
+
+ return schedule
+
+
+class Job:
+ def __init__(self, job_func, schedule, enabled):
+ self.job_func = job_func
+ self.schedule = schedule
+ self.job_id = short_uid()
+ self.is_enabled = enabled
+
+ def run(self):
+ try:
+ if self.should_run_now() and self.is_enabled:
+ self.do_run()
+ except Exception as e:
+ LOG.debug("Unable to run scheduled function %s: %s", self.job_func, e)
+
+ def should_run_now(self):
+ schedule = CronTab(self.schedule)
+ delay_secs = schedule.next(
+ default_utc=True
+ ) # rule schedule cron expressions are evaluated against UTC by default
+ # TODO: trigger at the exact cron time instead of within a 60-second window
+ return delay_secs is not None and delay_secs < 60
+
+ def do_run(self):
+ FuncThread(self.job_func, name="events-job-run").start()
+
+
+class JobScheduler:
+ _instance = None
+
+ def __init__(self):
+ # TODO: introduce RLock for mutating jobs list
+ self.jobs = []
+ self.thread = None
+ self._stop_event = threading.Event()
+
+ def add_job(self, job_func, schedule, enabled=True):
+ job = Job(job_func, schedule, enabled=enabled)
+ self.jobs.append(job)
+ return job.job_id
+
+ def get_job(self, job_id) -> Job | None:
+ for job in self.jobs:
+ if job.job_id == job_id:
+ return job
+ return None
+
+ def disable_job(self, job_id):
+ for job in self.jobs:
+ if job.job_id == job_id:
+ job.is_enabled = False
+ break
+
+ def cancel_job(self, job_id):
+ self.jobs = [job for job in self.jobs if job.job_id != job_id]
+
+ def loop(self, *args):
+ while not self._stop_event.is_set():
+ try:
+ for job in list(self.jobs):
+ job.run()
+ except Exception:
+ pass
+ # This is a simple heuristic to cause the loop to run approximately every minute
+ # TODO: we should keep track of job execution times to avoid duplicate executions
+ self._stop_event.wait(timeout=59.9)
+
+ def start_loop(self):
+ self.thread = FuncThread(self.loop, name="events-jobscheduler-loop")
+ self.thread.start()
+
+ @classmethod
+ def instance(cls):
+ if not cls._instance:
+ cls._instance = JobScheduler()
+ return cls._instance
+
+ @classmethod
+ def start(cls):
+ instance = cls.instance()
+ if not instance.thread:
+ instance.start_loop()
+ return instance
+
+ @classmethod
+ def shutdown(cls):
+ instance = cls.instance()
+ if instance.thread:
+ instance._stop_event.set()
diff --git a/localstack-core/localstack/services/events/target.py b/localstack-core/localstack/services/events/target.py
new file mode 100644
index 0000000000000..b12691f28925e
--- /dev/null
+++ b/localstack-core/localstack/services/events/target.py
@@ -0,0 +1,715 @@
+import datetime
+import json
+import logging
+import re
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Set, Type
+from urllib.parse import urlencode
+
+import requests
+from botocore.client import BaseClient
+
+from localstack import config
+from localstack.aws.api.events import (
+ Arn,
+ InputTransformer,
+ RuleName,
+ Target,
+ TargetInputPath,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.events.api_destination import add_api_destination_authorization
+from localstack.services.events.models import (
+ FormattedEvent,
+ TransformedEvent,
+ ValidationException,
+)
+from localstack.services.events.utils import (
+ event_time_to_time_string,
+ get_trace_header_encoded_region_account,
+ is_nested_in_string,
+ to_json_str,
+)
+from localstack.utils import collections
+from localstack.utils.aws.arns import (
+ extract_account_id_from_arn,
+ extract_region_from_arn,
+ extract_service_from_arn,
+ firehose_name,
+ parse_arn,
+ sqs_queue_url_for_arn,
+)
+from localstack.utils.aws.client_types import ServicePrincipal
+from localstack.utils.aws.message_forwarding import (
+ add_target_http_parameters,
+)
+from localstack.utils.json import extract_jsonpath
+from localstack.utils.strings import to_bytes
+from localstack.utils.time import now_utc
+
+LOG = logging.getLogger(__name__)
+
+# https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-transform-target-input.html#eb-transform-input-predefined
+AWS_PREDEFINED_PLACEHOLDERS_STRING_VALUES = {
+ "aws.events.rule-arn",
+ "aws.events.rule-name",
+ "aws.events.event.ingestion-time",
+}
+AWS_PREDEFINED_PLACEHOLDERS_JSON_VALUES = {"aws.events.event", "aws.events.event.json"}
+
+PREDEFINED_PLACEHOLDERS: Set[str] = AWS_PREDEFINED_PLACEHOLDERS_STRING_VALUES.union(
+ AWS_PREDEFINED_PLACEHOLDERS_JSON_VALUES
+)
+
+TRANSFORMER_PLACEHOLDER_PATTERN = re.compile(r"<(.*?)>")
+
+
+def transform_event_with_target_input_path(
+ input_path: TargetInputPath, event: FormattedEvent
+) -> TransformedEvent:
+ formatted_event = extract_jsonpath(event, input_path)
+ return formatted_event
+
+
+def get_template_replacements(
+ input_transformer: InputTransformer, event: FormattedEvent
+) -> dict[str, Any]:
+ """Extracts values from the event using the input paths map keys and places them in the input template dict."""
+ template_replacements = {}
+ transformer_path_map = input_transformer.get("InputPathsMap", {})
+ for placeholder, transformer_path in transformer_path_map.items():
+ if placeholder in PREDEFINED_PLACEHOLDERS:
+ continue
+ value = extract_jsonpath(event, transformer_path)
+ if not value:
+ value = "" # default value is empty string
+ template_replacements[placeholder] = value
+ return template_replacements
+
+
+def replace_template_placeholders(
+ template: str, replacements: dict[str, Any], is_json_template: bool
+) -> TransformedEvent:
+ """Replace placeholders defined by in the template with the values from the replacements dict.
+ Can handle single template string or template dict."""
+
+ def replace_placeholder(match):
+ key = match.group(1)
+ value = replacements.get(key, "") # default to empty string for undefined placeholders
+ if isinstance(value, datetime.datetime):
+ return event_time_to_time_string(value)
+ if isinstance(value, dict):
+ json_str = to_json_str(value).replace('\\"', '"')
+ if is_json_template:
+ return json_str
+ return json_str.replace('"', "")
+ if isinstance(value, list):
+ if is_json_template:
+ return json.dumps(value)
+ return f"[{','.join(value)}]"
+ if is_nested_in_string(template, match):
+ return value
+ if is_json_template:
+ return json.dumps(value)
+ return value
+
+ formatted_template = TRANSFORMER_PLACEHOLDER_PATTERN.sub(replace_placeholder, template).replace(
+ "\\n", "\n"
+ )
+
+ if is_json_template:
+ try:
+ loaded_json_template = json.loads(formatted_template)
+ return loaded_json_template
+ except json.JSONDecodeError:
+ LOG.info(
+ json.dumps(
+ {
+ "InfoCode": "InternalInfoEvents at transform_event",
+ "InfoMessage": f"Replaced template is not valid json: {formatted_template}",
+ }
+ )
+ )
+ else:
+ return formatted_template[1:-1]
+
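+# Illustrative examples (not part of the original change) of the replacement behavior above:
+#   replace_template_placeholders('{"name": "<name>"}', {"name": "Alice"}, True)
+#   -> {"name": "Alice"}
+# For non-JSON templates the surrounding quotes are stripped from the result:
+#   replace_template_placeholders('"Hello <name>"', {"name": "Alice"}, False)
+#   -> 'Hello Alice'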
+
+class TargetSender(ABC):
+ target: Target
+ rule_arn: Arn
+ rule_name: RuleName
+ service: str
+
+ region: str # region of the event bus
+ account_id: str # account id of the event bus
+ target_region: str
+ target_account_id: str
+ _client: BaseClient | None
+
+ def __init__(
+ self,
+ target: Target,
+ rule_arn: Arn,
+ rule_name: RuleName,
+ service: str,
+ region: str,
+ account_id: str,
+ ):
+ self.target = target
+ self.rule_arn = rule_arn
+ self.rule_name = rule_name
+ self.service = service
+ self.region = region
+ self.account_id = account_id
+
+ self.target_region = extract_region_from_arn(self.target["Arn"])
+ self.target_account_id = extract_account_id_from_arn(self.target["Arn"])
+
+ self._validate_input(target)
+ self._client: BaseClient | None = None
+
+ @property
+ def arn(self):
+ return self.target["Arn"]
+
+ @property
+ def target_id(self):
+ return self.target["Id"]
+
+ @property
+ def unique_id(self):
+ """Necessary to distinguish between targets with the same ARN but for different rules.
+ The unique_id is a combination of the rule ARN and the Target Id.
+ This is necessary since input path and input transformer can be different for the same target ARN,
+ attached to different rules."""
+ return f"{self.rule_arn}-{self.target_id}"
+
+ @property
+ def client(self):
+ """Lazy initialization of internal botoclient factory."""
+ if self._client is None:
+ self._client = self._initialize_client()
+ return self._client
+
+ @abstractmethod
+ def send_event(self, event: FormattedEvent | TransformedEvent):
+ pass
+
+ def process_event(self, event: FormattedEvent):
+ """Processes the event and send it to the target."""
+ if input_ := self.target.get("Input"):
+ event = json.loads(input_)
+ if isinstance(event, dict):
+ event.pop("event-bus-name", None)
+ if not input_:
+ if input_path := self.target.get("InputPath"):
+ event = transform_event_with_target_input_path(input_path, event)
+ if input_transformer := self.target.get("InputTransformer"):
+ event = self.transform_event_with_target_input_transformer(input_transformer, event)
+ if event:
+ self.send_event(event)
+ else:
+ LOG.info("No event to send to target %s", self.target.get("Id"))
+
+ def transform_event_with_target_input_transformer(
+ self, input_transformer: InputTransformer, event: FormattedEvent
+ ) -> TransformedEvent:
+ input_template = input_transformer["InputTemplate"]
+ template_replacements = get_template_replacements(input_transformer, event)
+ predefined_template_replacements = self._get_predefined_template_replacements(event)
+ template_replacements.update(predefined_template_replacements)
+
+ is_json_template = input_template.strip().startswith("{")
+ populated_template = replace_template_placeholders(
+ input_template, template_replacements, is_json_template
+ )
+
+ return populated_template
+
+ def _validate_input(self, target: Target):
+ """Provide a default implementation extended for each target based on specifications."""
+ # TODO add For Lambda and Amazon SNS resources, EventBridge relies on resource-based policies.
+ if "InputPath" in target and "InputTransformer" in target:
+ raise ValidationException(
+ f"Only one of Input, InputPath, or InputTransformer must be provided for target {target.get('Id')}."
+ )
+ if input_transformer := target.get("InputTransformer"):
+ self._validate_input_transformer(input_transformer)
+
+ def _initialize_client(self) -> BaseClient:
+ """Initializes internal boto client.
+ If a role from a target is provided, the client will be initialized with the assumed role.
+ If no role is provided, the client will be initialized with the account ID and region.
+ In both cases EventBridge is set as the service principal."""
+ service_principal = ServicePrincipal.events
+ role_arn = self.target.get("RoleArn")
+ if role_arn: # required for cross account
+ # assumed role sessions expire after 6 hours in AWS, currently no expiration in LocalStack
+ client_factory = connect_to.with_assumed_role(
+ role_arn=role_arn,
+ service_principal=service_principal,
+ region_name=self.region,
+ )
+ else:
+ client_factory = connect_to(aws_access_key_id=self.account_id, region_name=self.region)
+ client = client_factory.get_client(self.service)
+ client = client.request_metadata(
+ service_principal=service_principal, source_arn=self.rule_arn
+ )
+ return client
+
+ def _validate_input_transformer(self, input_transformer: InputTransformer):
+ # TODO: cover via test
+ # if "InputTemplate" not in input_transformer:
+ # raise ValueError("InputTemplate is required for InputTransformer")
+ input_template = input_transformer["InputTemplate"]
+ input_paths_map = input_transformer.get("InputPathsMap", {})
+ placeholders = TRANSFORMER_PLACEHOLDER_PATTERN.findall(input_template)
+ for placeholder in placeholders:
+ if placeholder not in input_paths_map and placeholder not in PREDEFINED_PLACEHOLDERS:
+ raise ValidationException(
+ f"InputTemplate for target {self.target.get('Id')} contains invalid placeholder {placeholder}."
+ )
+
+ def _get_predefined_template_replacements(self, event: FormattedEvent) -> dict[str, Any]:
+ """Extracts predefined values from the event."""
+ predefined_template_replacements = {}
+ predefined_template_replacements["aws.events.rule-arn"] = self.rule_arn
+ predefined_template_replacements["aws.events.rule-name"] = self.rule_name
+ predefined_template_replacements["aws.events.event.ingestion-time"] = event["time"]
+ predefined_template_replacements["aws.events.event"] = {
+ "detailType" if k == "detail-type" else k: v # detail-type is is returned as detailType
+ for k, v in event.items()
+ if k != "detail" # detail is not part of .event placeholder
+ }
+ predefined_template_replacements["aws.events.event.json"] = event
+
+ return predefined_template_replacements
+
+
+TargetSenderDict = dict[str, TargetSender] # rule_arn-target_id as global unique id
+
+# Target Senders are ordered alphabetically by service name
+
+
+class ApiGatewayTargetSender(TargetSender):
+ """
+ ApiGatewayTargetSender is a TargetSender that sends events to an API Gateway target.
+ """
+
+ PROHIBITED_HEADERS = [
+ "authorization",
+ "connection",
+ "content-encoding",
+ "content-length",
+ "host",
+ "max-forwards",
+ "te",
+ "transfer-encoding",
+ "trailer",
+ "upgrade",
+ "via",
+ "www-authenticate",
+ "x-forwarded-for",
+ ] # https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-api-gateway-target.html
+
+ ALLOWED_HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
+
+ def send_event(self, event):
+ # Parse the ARN to extract api_id, stage_name, http_method, and resource path
+ # Example ARN: arn:{partition}:execute-api:{region}:{account_id}:{api_id}/{stage_name}/{method}/{resource_path}
+ arn_parts = parse_arn(self.target["Arn"])
+ api_gateway_info = arn_parts["resource"] # e.g., 'myapi/dev/POST/pets/*/*'
+ api_gateway_info_parts = api_gateway_info.split("/")
+
+ api_id = api_gateway_info_parts[0]
+ stage_name = api_gateway_info_parts[1]
+ http_method = api_gateway_info_parts[2].upper()
+ resource_path_parts = api_gateway_info_parts[3:] # may contain wildcards
+
+ if http_method not in self.ALLOWED_HTTP_METHODS:
+ LOG.error("Unsupported HTTP method: %s", http_method)
+ return
+
+ # Replace wildcards in resource path with PathParameterValues
+ path_params_values = self.target.get("HttpParameters", {}).get("PathParameterValues", [])
+ resource_path_segments = []
+ path_param_index = 0
+ for part in resource_path_parts:
+ if part == "*":
+ if path_param_index < len(path_params_values):
+ resource_path_segments.append(path_params_values[path_param_index])
+ path_param_index += 1
+ else:
+ # Use empty string if no path parameter is provided
+ resource_path_segments.append("")
+ else:
+ resource_path_segments.append(part)
+ resource_path = "/".join(resource_path_segments)
+
+ # Ensure resource path starts and ends with '/'
+ resource_path = f"/{resource_path.strip('/')}/"
+
+ # Construct query string parameters
+ query_params = self.target.get("HttpParameters", {}).get("QueryStringParameters", {})
+ query_string = urlencode(query_params) if query_params else ""
+
+ # Construct headers
+ headers = self.target.get("HttpParameters", {}).get("HeaderParameters", {})
+ headers = {k: v for k, v in headers.items() if k.lower() not in self.PROHIBITED_HEADERS}
+ # Add Host header to ensure proper routing in LocalStack
+ host = f"{api_id}.execute-api.localhost.localstack.cloud"
+ headers["Host"] = host
+
+ # Ensure Content-Type is set
+ headers.setdefault("Content-Type", "application/json")
+
+ # Construct the full URL using urljoin
+ from urllib.parse import urljoin
+
+ base_url = config.internal_service_url()
+ base_path = f"/{stage_name}"
+ full_path = urljoin(base_path + "/", resource_path.lstrip("/"))
+ url = urljoin(base_url + "/", full_path.lstrip("/"))
+
+ if query_string:
+ url += f"?{query_string}"
+
+ # Serialize the event, converting datetime objects to strings
+ event_json = json.dumps(event, default=str)
+
+ # Send the HTTP request
+ response = requests.request(
+ method=http_method, url=url, headers=headers, data=event_json, timeout=5
+ )
+ if not response.ok:
+ LOG.warning(
+ "API Gateway target invocation failed with status code %s, response: %s",
+ response.status_code,
+ response.text,
+ )
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via test
+ # if not collections.get_safe(target, "$.RoleArn"):
+ # raise ValueError("RoleArn is required for ApiGateway target")
+
+ def _get_predefined_template_replacements(self, event: Dict[str, Any]) -> Dict[str, Any]:
+ """Extracts predefined values from the event."""
+ predefined_template_replacements = {}
+ predefined_template_replacements["aws.events.rule-arn"] = self.rule_arn
+ predefined_template_replacements["aws.events.rule-name"] = self.rule_name
+ predefined_template_replacements["aws.events.event.ingestion-time"] = event.get("time", "")
+ predefined_template_replacements["aws.events.event"] = {
+ "detailType" if k == "detail-type" else k: v for k, v in event.items() if k != "detail"
+ }
+ predefined_template_replacements["aws.events.event.json"] = event
+
+ return predefined_template_replacements
+
+
+class AppSyncTargetSender(TargetSender):
+ def send_event(self, event):
+ raise NotImplementedError("AppSync target is not yet implemented")
+
+
+class BatchTargetSender(TargetSender):
+ def send_event(self, event):
+ raise NotImplementedError("Batch target is not yet implemented")
+
+ def _validate_input(self, target: Target):
+ # TODO: cover via test and fix (only required if we have BatchParameters)
+ # if not collections.get_safe(target, "$.BatchParameters.JobDefinition"):
+ # raise ValueError("BatchParameters.JobDefinition is required for Batch target")
+ # if not collections.get_safe(target, "$.BatchParameters.JobName"):
+ # raise ValueError("BatchParameters.JobName is required for Batch target")
+ pass
+
+
+class ECSTargetSender(TargetSender):
+ def send_event(self, event):
+ raise NotImplementedError("ECS target is a pro feature, please use LocalStack Pro")
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via test
+ # if not collections.get_safe(target, "$.EcsParameters.TaskDefinitionArn"):
+ # raise ValueError("EcsParameters.TaskDefinitionArn is required for ECS target")
+
+
+class EventsTargetSender(TargetSender):
+ def send_event(self, event):
+ # TODO: add validation and tests; EventBridge-to-EventBridge forwarding requires Detail, DetailType, and Source
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/events/client/put_events.html
+ source = self._get_source(event)
+ detail_type = self._get_detail_type(event)
+ detail = event.get("detail", event)
+ resources = self._get_resources(event)
+ entries = [
+ {
+ "EventBusName": self.target["Arn"], # use arn for target account and region
+ "Source": source,
+ "DetailType": detail_type,
+ "Detail": json.dumps(detail),
+ "Resources": resources,
+ }
+ ]
+ if encoded_original_id := get_trace_header_encoded_region_account(
+ event, self.region, self.account_id, self.target_region, self.target_account_id
+ ):
+ entries[0]["TraceHeader"] = encoded_original_id
+ self.client.put_events(Entries=entries)
+
+ def _get_source(self, event: FormattedEvent | TransformedEvent) -> str:
+ if isinstance(event, dict) and (source := event.get("source")):
+ return source
+ else:
+ return self.service or ""
+
+ def _get_detail_type(self, event: FormattedEvent | TransformedEvent) -> str:
+ if isinstance(event, dict) and (detail_type := event.get("detail-type")):
+ return detail_type
+ else:
+ return ""
+
+ def _get_resources(self, event: FormattedEvent | TransformedEvent) -> list[str]:
+ if isinstance(event, dict) and (resources := event.get("resources")):
+ return resources
+ else:
+ return []
+
+
+class EventsApiDestinationTargetSender(TargetSender):
+ def send_event(self, event):
+ """Send an event to an EventBridge API destination
+ See https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-api-destinations.html"""
+ target_arn = self.target["Arn"]
+ target_region = extract_region_from_arn(target_arn)
+ target_account_id = extract_account_id_from_arn(target_arn)
+ api_destination_name = target_arn.split(":")[-1].split("/")[1]
+
+ events_client = connect_to(
+ aws_access_key_id=target_account_id, region_name=target_region
+ ).events
+ destination = events_client.describe_api_destination(Name=api_destination_name)
+
+ # get destination endpoint details
+ method = destination.get("HttpMethod", "GET")
+ endpoint = destination.get("InvocationEndpoint")
+ state = destination.get("ApiDestinationState") or "ACTIVE"
+
+ LOG.debug(
+ 'Calling EventBridge API destination (state "%s"): %s %s', state, method, endpoint
+ )
+ headers = {
+ # default headers AWS sends with every api destination call
+ "User-Agent": "Amazon/EventBridge/ApiDestinations",
+ "Content-Type": "application/json; charset=utf-8",
+ "Range": "bytes=0-1048575",
+ "Accept-Encoding": "gzip,deflate",
+ "Connection": "close",
+ }
+
+ endpoint = add_api_destination_authorization(destination, headers, event)
+ if http_parameters := self.target.get("HttpParameters"):
+ endpoint = add_target_http_parameters(http_parameters, endpoint, headers, event)
+
+ result = requests.request(
+ method=method, url=endpoint, data=json.dumps(event or {}), headers=headers
+ )
+ if result.status_code >= 400:
+ LOG.debug(
+ "Received code %s forwarding events: %s %s", result.status_code, method, endpoint
+ )
+ if result.status_code == 429 or 500 <= result.status_code <= 600:
+ pass # TODO: retry logic (only retry on 429 and 5xx response status)
+
+
+class FirehoseTargetSender(TargetSender):
+ def send_event(self, event):
+ delivery_stream_name = firehose_name(self.target["Arn"])
+ self.client.put_record(
+ DeliveryStreamName=delivery_stream_name,
+ Record={"Data": to_bytes(to_json_str(event))},
+ )
+
+
+class KinesisTargetSender(TargetSender):
+ def send_event(self, event):
+ partition_key_path = collections.get_safe(
+ self.target,
+ "$.KinesisParameters.PartitionKeyPath",
+ default_value="$.id",
+ )
+ stream_name = self.target["Arn"].split("/")[-1]
+ partition_key = collections.get_safe(event, partition_key_path, event["id"])
+ self.client.put_record(
+ StreamName=stream_name,
+ Data=to_bytes(to_json_str(event)),
+ PartitionKey=partition_key,
+ )
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via tests
+ # if not collections.get_safe(target, "$.RoleArn"):
+ # raise ValueError("RoleArn is required for Kinesis target")
+ # if not collections.get_safe(target, "$.KinesisParameters.PartitionKeyPath"):
+ # raise ValueError("KinesisParameters.PartitionKeyPath is required for Kinesis target")
+
+
+class LambdaTargetSender(TargetSender):
+ def send_event(self, event):
+ self.client.invoke(
+ FunctionName=self.target["Arn"],
+ Payload=to_bytes(to_json_str(event)),
+ InvocationType="Event",
+ )
+
+
+class LogsTargetSender(TargetSender):
+ def send_event(self, event):
+ log_group_name = self.target["Arn"].split(":")[6]
+ log_stream_name = str(uuid.uuid4()) # Unique log stream name
+ self.client.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name)
+ self.client.put_log_events(
+ logGroupName=log_group_name,
+ logStreamName=log_stream_name,
+ logEvents=[
+ {
+ "timestamp": now_utc(millis=True),
+ "message": to_json_str(event),
+ }
+ ],
+ )
+
+
+class RedshiftTargetSender(TargetSender):
+ def send_event(self, event):
+ raise NotImplementedError("Redshift target is not yet implemented")
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via test
+ # if not collections.get_safe(target, "$.RedshiftDataParameters.Database"):
+ # raise ValueError("RedshiftDataParameters.Database is required for Redshift target")
+
+
+class SagemakerTargetSender(TargetSender):
+ def send_event(self, event):
+ raise NotImplementedError("Sagemaker target is not yet implemented")
+
+
+class SnsTargetSender(TargetSender):
+ def send_event(self, event):
+ self.client.publish(TopicArn=self.target["Arn"], Message=to_json_str(event))
+
+
+class SqsTargetSender(TargetSender):
+ def send_event(self, event):
+ queue_url = sqs_queue_url_for_arn(self.target["Arn"])
+ msg_group_id = self.target.get("SqsParameters", {}).get("MessageGroupId", None)
+ kwargs = {"MessageGroupId": msg_group_id} if msg_group_id else {}
+ self.client.send_message(
+ QueueUrl=queue_url,
+ MessageBody=to_json_str(event),
+ **kwargs,
+ )
+
+
+class StatesTargetSender(TargetSender):
+ """Step Functions Target Sender"""
+
+ def send_event(self, event):
+ self.service = "stepfunctions"
+ self.client.start_execution(
+ stateMachineArn=self.target["Arn"], name=event["id"], input=to_json_str(event)
+ )
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via test
+ # if not collections.get_safe(target, "$.RoleArn"):
+ # raise ValueError("RoleArn is required for StepFunctions target")
+
+
+class SystemsManagerSender(TargetSender):
+ """EC2 Run Command Target Sender"""
+
+ def send_event(self, event):
+ raise NotImplementedError("Systems Manager target is not yet implemented")
+
+ def _validate_input(self, target: Target):
+ super()._validate_input(target)
+ # TODO: cover via test
+ # if not collections.get_safe(target, "$.RoleArn"):
+ # raise ValueError(
+ # "RoleArn is required for SystemManager target to invoke a EC2 run command"
+ # )
+ # if not collections.get_safe(target, "$.RunCommandParameters.RunCommandTargets"):
+ # raise ValueError(
+ # "RunCommandParameters.RunCommandTargets is required for Systems Manager target"
+ # )
+
+
+class TargetSenderFactory:
+ # supported targets: https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-targets.html
+ target: Target
+ rule_arn: Arn
+ rule_name: RuleName
+ region: str
+ account_id: str
+
+ target_map = {
+ "apigateway": ApiGatewayTargetSender,
+ "appsync": AppSyncTargetSender,
+ "batch": BatchTargetSender,
+ "ecs": ECSTargetSender,
+ "events": EventsTargetSender,
+ "events_api_destination": EventsApiDestinationTargetSender,
+ "firehose": FirehoseTargetSender,
+ "kinesis": KinesisTargetSender,
+ "lambda": LambdaTargetSender,
+ "logs": LogsTargetSender,
+ "redshift": RedshiftTargetSender,
+ "sns": SnsTargetSender,
+ "sqs": SqsTargetSender,
+ "sagemaker": SagemakerTargetSender,
+ "ssm": SystemsManagerSender,
+ "states": StatesTargetSender,
+ "execute-api": ApiGatewayTargetSender,
+ # TODO custom endpoints via http target
+ }
+
+ def __init__(
+ self, target: Target, rule_arn: Arn, rule_name: RuleName, region: str, account_id: str
+ ):
+ self.target = target
+ self.rule_arn = rule_arn
+ self.rule_name = rule_name
+ self.region = region
+ self.account_id = account_id
+
+ @classmethod
+ def register_target_sender(cls, service_name: str, sender_class: Type[TargetSender]):
+ cls.target_map[service_name] = sender_class
+
+ def get_target_sender(self) -> TargetSender:
+ target_arn = self.target["Arn"]
+ service = extract_service_from_arn(target_arn)
+ if ":api-destination/" in target_arn or ":destination/" in target_arn:
+ service = "events_api_destination"
+ if service in self.target_map:
+ target_sender_class = self.target_map[service]
+ else:
+ raise Exception(f"Unsupported target for Service: {service}")
+ target_sender = target_sender_class(
+ self.target, self.rule_arn, self.rule_name, service, self.region, self.account_id
+ )
+ return target_sender
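+# Illustrative example (not part of the original change): a target ARN such as
+# "arn:aws:sqs:us-east-1:111111111111:my-queue" resolves to the service "sqs", so
+# TargetSenderFactory(target, rule_arn, rule_name, region, account_id).get_target_sender()
+# returns an SqsTargetSender instance.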
diff --git a/localstack-core/localstack/services/events/usage.py b/localstack-core/localstack/services/events/usage.py
new file mode 100644
index 0000000000000..fa51d185ce76e
--- /dev/null
+++ b/localstack-core/localstack/services/events/usage.py
@@ -0,0 +1,7 @@
+from localstack.utils.analytics.usage import UsageSetCounter
+
+# number of successful EventBridge rule invocations per target (e.g., aws:lambda)
+rule_invocation = UsageSetCounter("events:rule:invocation")
+
+# number of EventBridge rule errors per target (e.g., aws:lambda)
+rule_error = UsageSetCounter("events:rule:error")
diff --git a/localstack-core/localstack/services/events/utils.py b/localstack-core/localstack/services/events/utils.py
new file mode 100644
index 0000000000000..36258ac668acb
--- /dev/null
+++ b/localstack-core/localstack/services/events/utils.py
@@ -0,0 +1,295 @@
+import json
+import logging
+import re
+from datetime import datetime, timezone
+from typing import Any, Dict, Optional
+
+from botocore.utils import ArnParser
+
+from localstack.aws.api import RequestContext
+from localstack.aws.api.events import (
+ ArchiveName,
+ Arn,
+ ConnectionArn,
+ ConnectionName,
+ EventBusName,
+ EventBusNameOrArn,
+ EventTime,
+ PutEventsRequestEntry,
+ RuleArn,
+ Timestamp,
+)
+from localstack.services.events.models import (
+ FormattedEvent,
+ ResourceType,
+ TransformedEvent,
+ ValidationException,
+)
+from localstack.utils.aws.arns import ARN_PARTITION_REGEX, parse_arn
+from localstack.utils.strings import long_uid
+
+LOG = logging.getLogger(__name__)
+
+RULE_ARN_CUSTOM_EVENT_BUS_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:events:[a-z0-9-]+:\d{{12}}:rule/[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$"
+)
+
+RULE_ARN_ARCHIVE_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:events:[a-z0-9-]+:\d{{12}}:archive/[a-zA-Z0-9_-]+$"
+)
+ARCHIVE_NAME_ARN_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:events:[a-z0-9-]+:\d{{12}}:archive/(?P.+)$"
+)
+CONNECTION_NAME_ARN_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:events:[a-z0-9-]+:\d{{12}}:connection/(?P[^/]+)/(?P[^/]+)$"
+)
+
+TARGET_ID_PATTERN = re.compile(r"[\.\-_A-Za-z0-9]+")
+
+
+class EventJSONEncoder(json.JSONEncoder):
+ """This json encoder is used to serialize datetime object
+ of a eventbridge event to time strings."""
+
+ def default(self, obj):
+ if isinstance(obj, datetime):
+ return event_time_to_time_string(obj)
+ return super().default(obj)
+
+
+def to_json_str(obj: Any, separators: Optional[tuple[str, str]] = (",", ":")) -> str:
+ return json.dumps(obj, cls=EventJSONEncoder, separators=separators)
+
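+# Illustrative example (not part of the original change):
+#   to_json_str({"time": datetime(2024, 1, 1, tzinfo=timezone.utc)})
+#   -> '{"time":"2024-01-01T00:00:00Z"}'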
+
+def extract_region_and_account_id(
+ name_or_arn: EventBusNameOrArn, context: RequestContext
+) -> tuple[str, str]:
+ """Returns the region and account id from the arn,
+ or falls back on the region and account id of the context"""
+ account_id = None
+ region = None
+ if ArnParser.is_arn(name_or_arn):
+ parsed_arn = parse_arn(name_or_arn)
+ region = parsed_arn.get("region")
+ account_id = parsed_arn.get("account")
+ if not account_id or not region:
+ region = context.get("region")
+ account_id = context.get("account_id")
+ return region, account_id
+
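+# Illustrative examples (not part of the original change; ARN values are made up):
+#   extract_region_and_account_id("arn:aws:events:eu-west-1:111111111111:event-bus/my-bus", context)
+#   -> ("eu-west-1", "111111111111")
+#   extract_region_and_account_id("my-bus", context) -> (context region, context account id)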
+
+def extract_event_bus_name(
+ resource_arn_or_name: EventBusNameOrArn | RuleArn | None,
+) -> EventBusName:
+ """Return the event bus name. Input can be either an event bus name or ARN."""
+ if not resource_arn_or_name:
+ return "default"
+ if not re.match(f"{ARN_PARTITION_REGEX}:events", resource_arn_or_name):
+ return resource_arn_or_name
+ resource_type = get_resource_type(resource_arn_or_name)
+ if resource_type == ResourceType.EVENT_BUS:
+ return resource_arn_or_name.split("/")[-1]
+ if resource_type == ResourceType.RULE:
+ if bool(RULE_ARN_CUSTOM_EVENT_BUS_PATTERN.match(resource_arn_or_name)):
+ return resource_arn_or_name.split("rule/", 1)[1].split("/", 1)[0]
+ return "default"
+
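+# Illustrative examples (not part of the original change; ARN values are made up):
+#   extract_event_bus_name("my-bus") -> "my-bus"
+#   extract_event_bus_name("arn:aws:events:us-east-1:111111111111:event-bus/my-bus") -> "my-bus"
+#   extract_event_bus_name("arn:aws:events:us-east-1:111111111111:rule/my-bus/my-rule") -> "my-bus"
+#   extract_event_bus_name("arn:aws:events:us-east-1:111111111111:rule/my-rule") -> "default"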
+
+def extract_connection_name(
+ connection_arn: ConnectionArn,
+) -> ConnectionName:
+ match = CONNECTION_NAME_ARN_PATTERN.match(connection_arn)
+ if not match:
+ raise ValidationException(
+ f"Parameter {connection_arn} is not valid. Reason: Provided Arn is not in correct format."
+ )
+ return match.group("name")
+
+
+def extract_archive_name(arn: Arn) -> ArchiveName:
+ match = ARCHIVE_NAME_ARN_PATTERN.match(arn)
+ if not match:
+ raise ValidationException(
+ f"Parameter {arn} is not valid. Reason: Provided Arn is not in correct format."
+ )
+ return match.group("name")
+
+
+def is_archive_arn(arn: Arn) -> bool:
+ return bool(RULE_ARN_ARCHIVE_PATTERN.match(arn))
+
+
+def get_resource_type(arn: Arn) -> ResourceType:
+ parsed_arn = parse_arn(arn)
+ resource_type = parsed_arn["resource"].split("/", 1)[0]
+ if resource_type == "event-bus":
+ return ResourceType.EVENT_BUS
+ if resource_type == "rule":
+ return ResourceType.RULE
+ raise ValidationException(
+ f"Parameter {arn} is not valid. Reason: Provided Arn is not in correct format."
+ )
+
+
+def get_event_time(event: PutEventsRequestEntry) -> EventTime:
+ event_time = datetime.now(timezone.utc)
+ if event_timestamp := event.get("Time"):
+ try:
+ # use time from event if provided
+ event_time = event_timestamp.replace(tzinfo=timezone.utc)
+ except ValueError:
+ # use current time if event time is invalid
+ LOG.debug(
+ "Could not parse the `Time` parameter, falling back to current time for the following Event: '%s'",
+ event,
+ )
+ return event_time
+
+
+def event_time_to_time_string(event_time: EventTime) -> str:
+ return event_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+
+def convert_to_timezone_aware_datetime(
+ timestamp: Timestamp,
+) -> Timestamp:
+ if timestamp.tzinfo is None:
+ timestamp = timestamp.replace(tzinfo=timezone.utc)
+ return timestamp
+
+
+def recursive_remove_none_values_from_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Recursively removes keys with None values (and resulting empty lists or dicts) from a dictionary.
+ """
+ if not isinstance(d, dict):
+ return d
+
+ clean_dict = {}
+ for key, value in d.items():
+ if value is None:
+ continue
+ if isinstance(value, list):
+ nested_list = [recursive_remove_none_values_from_dict(item) for item in value]
+ nested_list = [item for item in nested_list if item]
+ if nested_list:
+ clean_dict[key] = nested_list
+ elif isinstance(value, dict):
+ nested_dict = recursive_remove_none_values_from_dict(value)
+ if nested_dict:
+ clean_dict[key] = nested_dict
+ else:
+ clean_dict[key] = value
+ return clean_dict
+
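+# Illustrative example (not part of the original change):
+#   recursive_remove_none_values_from_dict({"a": 1, "b": None, "c": {"d": None}, "e": []})
+#   -> {"a": 1}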
+
+def format_event(
+ event: PutEventsRequestEntry, region: str, account_id: str, event_bus_name: EventBusName
+) -> FormattedEvent:
+ # See https://docs.aws.amazon.com/AmazonS3/latest/userguide/ev-events.html
+ trace_header = event.get("TraceHeader")
+ message = {}
+ if trace_header:
+ try:
+ message = json.loads(trace_header)
+ except json.JSONDecodeError:
+ pass
+ message_id = message.get("original_id", str(long_uid()))
+ region = message.get("original_region", region)
+ account_id = message.get("original_account", account_id)
+ # Format the datetime to ISO-8601 string
+ event_time = get_event_time(event)
+ formatted_time = event_time_to_time_string(event_time)
+
+ formatted_event = {
+ "version": "0",
+ "id": message_id,
+ "detail-type": event.get("DetailType"),
+ "source": event.get("Source"),
+ "account": account_id,
+ "time": formatted_time,
+ "region": region,
+ "resources": event.get("Resources", []),
+ "detail": json.loads(event.get("Detail", "{}")),
+ "event-bus-name": event_bus_name, # current workaround for EventStudio extension
+ }
+ if replay_name := event.get("ReplayName"):
+ formatted_event["replay-name"] = replay_name # required for replay from archive
+
+ return formatted_event
+
+
+def re_format_event(event: FormattedEvent, event_bus_name: EventBusName) -> PutEventsRequestEntry:
+ """Transforms the event to the original event structure."""
+ re_formatted_event = {
+ "Source": event["source"],
+ "DetailType": event[
+ "detail-type"
+ ], # detail_type automatically interpreted as detail-type in typedict
+ "Detail": json.dumps(event["detail"]),
+ "Time": event["time"],
+ }
+ if event.get("resources"):
+ re_formatted_event["Resources"] = event["resources"]
+ if event_bus_name:
+ re_formatted_event["EventBusName"] = event_bus_name
+ if event.get("replay-name"):
+ re_formatted_event["ReplayName"] = event["replay_name"]
+ return re_formatted_event
+
+
+def get_trace_header_encoded_region_account(
+ event: PutEventsRequestEntry | FormattedEvent | TransformedEvent,
+ source_region: str,
+ source_account_id: str,
+ target_region: str,
+ target_account_id: str,
+) -> str | None:
+ """Encode the original region and account_id for cross-region and cross-account
+ event bus communication in the trace header. For event bus to event bus communication
+ in a different account the event id is preserved. This is not the case if the region differs."""
+ if event.get("TraceHeader"):
+ return None
+ if source_region != target_region and source_account_id != target_account_id:
+ return json.dumps(
+ {
+ "original_region": source_region,
+ "original_account": source_account_id,
+ }
+ )
+ if source_region != target_region:
+ return json.dumps({"original_region": source_region})
+ if source_account_id != target_account_id:
+ if original_id := event.get("id"):
+ return json.dumps({"original_id": original_id, "original_account": source_account_id})
+ else:
+ return json.dumps({"original_account": source_account_id})
+
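+# Illustrative examples (not part of the original change):
+#   same account, different region  -> '{"original_region": "us-east-1"}'
+#   same region, different account  -> '{"original_id": "<event id>", "original_account": "111111111111"}'
+#   both differ                     -> original_region and original_account only (event id not preserved)
+#   event already has a TraceHeader -> None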
+
+def is_nested_in_string(template: str, match: re.Match[str]) -> bool:
+ """
+ Determines if a match (string) is within quotes in the given template.
+
+ Examples:
+ True for "users-service/users/<userId>" # nested within larger string
+ True for "<userId>" # simple quoted placeholder
+ True for "Hello <name>" # nested within larger string
+ False for {"id": <userId>} # not in quotes at all
+ """
+ start = match.start()
+ end = match.end()
+
+ left_quote = template.rfind('"', 0, start)
+ right_quote = template.find('"', end)
+ next_comma = template.find(",", end)
+ next_brace = template.find("}", end)
+
+ # If no right quote, or if comma/brace comes before right quote, not nested
+ if (
+ right_quote == -1
+ or (next_comma != -1 and next_comma < right_quote)
+ or (next_brace != -1 and next_brace < right_quote)
+ ):
+ return False
+
+ return left_quote != -1
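+# Illustrative usage (not part of the original change), where `match` is produced by
+# re.search(r"<(.*?)>", template):
+#   is_nested_in_string('{"msg": "Hello <name>"}', match)  -> True   (placeholder inside a quoted string)
+#   is_nested_in_string('{"id": <userId>}', match)         -> False  (placeholder is a bare JSON value)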
diff --git a/localstack-core/localstack/services/events/v1/__init__.py b/localstack-core/localstack/services/events/v1/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/events/v1/models.py b/localstack-core/localstack/services/events/v1/models.py
new file mode 100644
index 0000000000000..4096215c82499
--- /dev/null
+++ b/localstack-core/localstack/services/events/v1/models.py
@@ -0,0 +1,11 @@
+from typing import Dict
+
+from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute
+
+
+class EventsStore(BaseStore):
+ # maps rule name to job_id
+ rule_scheduled_jobs: Dict[str, str] = LocalAttribute(default=dict)
+
+
+events_stores = AccountRegionBundle("events", EventsStore)
diff --git a/localstack-core/localstack/services/events/v1/provider.py b/localstack-core/localstack/services/events/v1/provider.py
new file mode 100644
index 0000000000000..bbcd4e0ac33eb
--- /dev/null
+++ b/localstack-core/localstack/services/events/v1/provider.py
@@ -0,0 +1,533 @@
+import datetime
+import json
+import logging
+import os
+import re
+import time
+from typing import Any, Dict, Optional
+
+from moto.events import events_backends
+from moto.events.responses import EventsHandler as MotoEventsHandler
+from werkzeug import Request
+from werkzeug.exceptions import NotFound
+
+from localstack import config
+from localstack.aws.api import RequestContext
+from localstack.aws.api.core import CommonServiceException, ServiceException
+from localstack.aws.api.events import (
+ Boolean,
+ ConnectionAuthorizationType,
+ ConnectionDescription,
+ ConnectionName,
+ ConnectivityResourceParameters,
+ CreateConnectionAuthRequestParameters,
+ CreateConnectionResponse,
+ EventBusNameOrArn,
+ EventPattern,
+ EventsApi,
+ PutRuleResponse,
+ PutTargetsResponse,
+ RoleArn,
+ RuleDescription,
+ RuleName,
+ RuleState,
+ ScheduleExpression,
+ String,
+ TagList,
+ TargetList,
+ TestEventPatternResponse,
+)
+from localstack.constants import APPLICATION_AMZ_JSON_1_1
+from localstack.http import route
+from localstack.services.edge import ROUTER
+from localstack.services.events.scheduler import JobScheduler
+from localstack.services.events.v1.models import EventsStore, events_stores
+from localstack.services.moto import call_moto
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.aws.arns import event_bus_arn, parse_arn
+from localstack.utils.aws.client_types import ServicePrincipal
+from localstack.utils.aws.message_forwarding import send_event_to_target
+from localstack.utils.collections import pick_attributes
+from localstack.utils.common import TMP_FILES, mkdir, save_file, truncate
+from localstack.utils.event_matcher import matches_event
+from localstack.utils.json import extract_jsonpath
+from localstack.utils.strings import long_uid, short_uid
+from localstack.utils.time import TIMESTAMP_FORMAT_TZ, timestamp
+
+LOG = logging.getLogger(__name__)
+
+# list of events used to run assertions during integration testing (not exposed to the user)
+TEST_EVENTS_CACHE = []
+EVENTS_TMP_DIR = "cw_events"
+DEFAULT_EVENT_BUS_NAME = "default"
+CONNECTION_NAME_PATTERN = re.compile("^[\\.\\-_A-Za-z0-9]+$")
+
+
+class ValidationException(ServiceException):
+ code: str = "ValidationException"
+ sender_fault: bool = True
+ status_code: int = 400
+
+
+class EventsProvider(EventsApi, ServiceLifecycleHook):
+ def __init__(self):
+ apply_patches()
+
+ def on_after_init(self):
+ ROUTER.add(self.trigger_scheduled_rule)
+
+ def on_before_start(self):
+ JobScheduler.start()
+
+ def on_before_stop(self):
+ JobScheduler.shutdown()
+
+ @route("/_aws/events/rules//trigger")
+ def trigger_scheduled_rule(self, request: Request, rule_arn: str):
+ """Developer endpoint to trigger a scheduled rule."""
+ arn_data = parse_arn(rule_arn)
+ account_id = arn_data["account"]
+ region = arn_data["region"]
+ rule_name = arn_data["resource"].split("/", maxsplit=1)[-1]
+
+ job_id = events_stores[account_id][region].rule_scheduled_jobs.get(rule_name)
+ if not job_id:
+ raise NotFound()
+ job = JobScheduler().instance().get_job(job_id)
+ if not job:
+ raise NotFound()
+
+ # TODO: once job scheduler is refactored, we can update the deadline of the task instead of running
+ # it here
+ job.run()
+
+ @staticmethod
+ def get_store(context: RequestContext) -> EventsStore:
+ return events_stores[context.account_id][context.region]
+
+ def test_event_pattern(
+ self, context: RequestContext, event_pattern: EventPattern, event: String, **kwargs
+ ) -> TestEventPatternResponse:
+ """Test event pattern uses EventBridge event pattern matching:
+ https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns.html
+ """
+ result = matches_event(event_pattern, event)
+ return TestEventPatternResponse(Result=result)
+
+ @staticmethod
+ def get_scheduled_rule_func(
+ store: EventsStore,
+ rule_name: RuleName,
+ event_bus_name_or_arn: Optional[EventBusNameOrArn] = None,
+ ):
+ def func(*args, **kwargs):
+ account_id = store._account_id
+ region = store._region_name
+ moto_backend = events_backends[account_id][region]
+ event_bus_name = get_event_bus_name(event_bus_name_or_arn)
+ event_bus = moto_backend.event_buses[event_bus_name]
+ rule = event_bus.rules.get(rule_name)
+ if not rule:
+ LOG.info("Unable to find rule `%s` for event bus `%s`", rule_name, event_bus_name)
+ return
+ if rule.targets:
+ LOG.debug(
+ "Notifying %s targets in response to triggered Events rule %s",
+ len(rule.targets),
+ rule_name,
+ )
+
+ default_event = {
+ "version": "0",
+ "id": long_uid(),
+ "detail-type": "Scheduled Event",
+ "source": "aws.events",
+ "account": account_id,
+ "time": timestamp(format=TIMESTAMP_FORMAT_TZ),
+ "region": region,
+ "resources": [rule.arn],
+ "detail": {},
+ }
+
+ for target in rule.targets:
+ arn = target.get("Arn")
+
+ if input_ := target.get("Input"):
+ event = json.loads(input_)
+ else:
+ event = default_event
+ if target.get("InputPath"):
+ event = filter_event_with_target_input_path(target, event)
+ if input_transformer := target.get("InputTransformer"):
+ event = process_event_with_input_transformer(input_transformer, event)
+
+ attr = pick_attributes(target, ["$.SqsParameters", "$.KinesisParameters"])
+
+ try:
+ send_event_to_target(
+ arn,
+ event,
+ target_attributes=attr,
+ role=target.get("RoleArn"),
+ target=target,
+ source_arn=rule.arn,
+ source_service=ServicePrincipal.events,
+ )
+ except Exception as e:
+ LOG.info(
+ "Unable to send event notification %s to target %s: %s",
+ truncate(event),
+ target,
+ e,
+ )
+
+ return func
+
+ @staticmethod
+ def convert_schedule_to_cron(schedule):
+ """Convert Events schedule like "cron(0 20 * * ? *)" or "rate(5 minutes)" """
+ cron_regex = r"\s*cron\s*\(([^\)]*)\)\s*"
+ if re.match(cron_regex, schedule):
+ cron = re.sub(cron_regex, r"\1", schedule)
+ return cron
+ rate_regex = r"\s*rate\s*\(([^\)]*)\)\s*"
+ if re.match(rate_regex, schedule):
+ rate = re.sub(rate_regex, r"\1", schedule)
+ value, unit = re.split(r"\s+", rate.strip())
+
+ value = int(value)
+ if value < 1:
+ raise ValueError("Rate value must be larger than 0")
+ # see https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-rate-expressions.html
+ if value == 1 and unit.endswith("s"):
+ raise ValueError("If the value is equal to 1, then the unit must be singular")
+ if value > 1 and not unit.endswith("s"):
+ raise ValueError("If the value is greater than 1, the unit must be plural")
+
+ if "minute" in unit:
+ return "*/%s * * * *" % value
+ if "hour" in unit:
+ return "0 */%s * * *" % value
+ if "day" in unit:
+ return "0 0 */%s * *" % value
+ raise ValueError("Unable to parse events schedule expression: %s" % schedule)
+ return schedule
+
+ @staticmethod
+ def put_rule_job_scheduler(
+ store: EventsStore,
+ name: Optional[RuleName],
+ state: Optional[RuleState],
+ schedule_expression: Optional[ScheduleExpression],
+ event_bus_name_or_arn: Optional[EventBusNameOrArn] = None,
+ ):
+ if not schedule_expression:
+ return
+
+ try:
+ cron = EventsProvider.convert_schedule_to_cron(schedule_expression)
+ except ValueError as e:
+ LOG.error("Error parsing schedule expression: %s", e)
+ raise ValidationException("Parameter ScheduleExpression is not valid.") from e
+
+ job_func = EventsProvider.get_scheduled_rule_func(
+ store, name, event_bus_name_or_arn=event_bus_name_or_arn
+ )
+ LOG.debug("Adding new scheduled Events rule with cron schedule %s", cron)
+
+ enabled = state != "DISABLED"
+ job_id = JobScheduler.instance().add_job(job_func, cron, enabled)
+ rule_scheduled_jobs = store.rule_scheduled_jobs
+ rule_scheduled_jobs[name] = job_id
+
+ def put_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ schedule_expression: ScheduleExpression = None,
+ event_pattern: EventPattern = None,
+ state: RuleState = None,
+ description: RuleDescription = None,
+ role_arn: RoleArn = None,
+ tags: TagList = None,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> PutRuleResponse:
+ store = self.get_store(context)
+ self.put_rule_job_scheduler(
+ store, name, state, schedule_expression, event_bus_name_or_arn=event_bus_name
+ )
+ return call_moto(context)
+
+ def delete_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ force: Boolean = None,
+ **kwargs,
+ ) -> None:
+ rule_scheduled_jobs = self.get_store(context).rule_scheduled_jobs
+ job_id = rule_scheduled_jobs.get(name)
+ if job_id:
+ LOG.debug("Removing scheduled Events: %s | job_id: %s", name, job_id)
+ JobScheduler.instance().cancel_job(job_id=job_id)
+ call_moto(context)
+
+ def disable_rule(
+ self,
+ context: RequestContext,
+ name: RuleName,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> None:
+ rule_scheduled_jobs = self.get_store(context).rule_scheduled_jobs
+ job_id = rule_scheduled_jobs.get(name)
+ if job_id:
+ LOG.debug("Disabling Rule: %s | job_id: %s", name, job_id)
+ JobScheduler.instance().disable_job(job_id=job_id)
+ call_moto(context)
+
+ def create_connection(
+ self,
+ context: RequestContext,
+ name: ConnectionName,
+ authorization_type: ConnectionAuthorizationType,
+ auth_parameters: CreateConnectionAuthRequestParameters,
+ description: ConnectionDescription = None,
+ invocation_connectivity_parameters: ConnectivityResourceParameters = None,
+ **kwargs,
+ ) -> CreateConnectionResponse:
+ errors = []
+
+ if not CONNECTION_NAME_PATTERN.match(name):
+ error = f"{name} at 'name' failed to satisfy: Member must satisfy regular expression pattern: [\\.\\-_A-Za-z0-9]+"
+ errors.append(error)
+
+ if len(name) > 64:
+ error = f"{name} at 'name' failed to satisfy: Member must have length less than or equal to 64"
+ errors.append(error)
+
+ if authorization_type not in ["BASIC", "API_KEY", "OAUTH_CLIENT_CREDENTIALS"]:
+ error = f"{authorization_type} at 'authorizationType' failed to satisfy: Member must satisfy enum value set: [BASIC, OAUTH_CLIENT_CREDENTIALS, API_KEY]"
+ errors.append(error)
+
+ if len(errors) > 0:
+ error_description = "; ".join(errors)
+ error_plural = "errors" if len(errors) > 1 else "error"
+ errors_amount = len(errors)
+ message = f"{errors_amount} validation {error_plural} detected: {error_description}"
+ raise CommonServiceException(message=message, code="ValidationException")
+
+ return call_moto(context)
+
+ def put_targets(
+ self,
+ context: RequestContext,
+ rule: RuleName,
+ targets: TargetList,
+ event_bus_name: EventBusNameOrArn = None,
+ **kwargs,
+ ) -> PutTargetsResponse:
+ validation_errors = []
+
+ id_regex = re.compile(r"^[\.\-_A-Za-z0-9]+$")
+ for index, target in enumerate(targets):
+ id = target.get("Id")
+ if not id_regex.match(id):
+ validation_errors.append(
+ f"Value '{id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must satisfy regular expression pattern: [\\.\\-_A-Za-z0-9]+"
+ )
+
+ if len(id) > 64:
+ validation_errors.append(
+ f"Value '{id}' at 'targets.{index + 1}.member.id' failed to satisfy constraint: Member must have length less than or equal to 64"
+ )
+
+ if validation_errors:
+ errors_message = "; ".join(validation_errors)
+ message = f"{len(validation_errors)} validation {'errors' if len(validation_errors) > 1 else 'error'} detected: {errors_message}"
+ raise CommonServiceException(message=message, code="ValidationException")
+
+ return call_moto(context)
+
+
+def _get_events_tmp_dir():
+ return os.path.join(config.dirs.tmp, EVENTS_TMP_DIR)
+
+
+def _create_and_register_temp_dir():
+ tmp_dir = _get_events_tmp_dir()
+ if not os.path.exists(tmp_dir):
+ mkdir(tmp_dir)
+ TMP_FILES.append(tmp_dir)
+ return tmp_dir
+
+
+def _dump_events_to_files(events_with_added_uuid):
+ try:
+ _create_and_register_temp_dir()
+ current_time_millis = int(round(time.time() * 1000))
+ for event in events_with_added_uuid:
+ target = os.path.join(
+ _get_events_tmp_dir(),
+ "%s_%s" % (current_time_millis, event["uuid"]),
+ )
+ save_file(target, json.dumps(event["event"]))
+ except Exception as e:
+ LOG.info("Unable to dump events to tmp dir %s: %s", _get_events_tmp_dir(), e)
+
+
+def filter_event_based_on_event_format(
+ self, rule_name: str, event_bus_name: str, event: dict[str, Any]
+):
+ rule_information = self.events_backend.describe_rule(
+ rule_name, event_bus_arn(event_bus_name, self.current_account, self.region)
+ )
+
+ if not rule_information:
+ LOG.info('Unable to find rule "%s" in backend: %s', rule_name, rule_information)
+ return False
+ if rule_information.event_pattern._pattern:
+ event_pattern = rule_information.event_pattern._pattern
+ if not matches_event(event_pattern, event):
+ return False
+ return True
+
+
+def filter_event_with_target_input_path(target: Dict, event: Dict) -> Dict:
+ input_path = target.get("InputPath")
+ if input_path:
+ event = extract_jsonpath(event, input_path)
+ return event
+
+
+def process_event_with_input_transformer(input_transformer: Dict, event: Dict) -> Dict:
+ """
+ Process the event with the input transformer of the target,
+ replacing the message with the populated InputTemplate.
+ See https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-transform-target-input.html
+ """
+ try:
+ input_paths = input_transformer["InputPathsMap"]
+ input_template = input_transformer["InputTemplate"]
+ except KeyError as e:
+ LOG.error("%s key does not exist in input_transformer.", e)
+ raise e
+ for key, path in input_paths.items():
+ value = extract_jsonpath(event, path)
+ if not value:
+ value = ""
+ input_template = input_template.replace(f"<{key}>", value)
+ templated_event = re.sub('"', "", input_template)
+ return templated_event
+
+
+def process_events(event: Dict, targets: list[Dict]):
+ for target in targets:
+ arn = target["Arn"]
+ changed_event = filter_event_with_target_input_path(target, event)
+ if input_transformer := target.get("InputTransformer"):
+ changed_event = process_event_with_input_transformer(input_transformer, changed_event)
+ if target.get("Input"):
+ changed_event = json.loads(target.get("Input"))
+ try:
+ send_event_to_target(
+ arn,
+ changed_event,
+ pick_attributes(target, ["$.SqsParameters", "$.KinesisParameters"]),
+ role=target.get("RoleArn"),
+ target=target,
+ source_service=ServicePrincipal.events,
+ source_arn=target.get("RuleArn"),
+ )
+ except Exception as e:
+ LOG.info(
+ "Unable to send event notification %s to target %s: %s",
+ truncate(event),
+ target,
+ e,
+ )
+
+
+def get_event_bus_name(event_bus_name_or_arn: Optional[EventBusNameOrArn] = None) -> str:
+ event_bus_name_or_arn = event_bus_name_or_arn or DEFAULT_EVENT_BUS_NAME
+ return event_bus_name_or_arn.split("/")[-1]
+
+
+# specific logic for put_events which forwards matching events to target listeners
+def events_handler_put_events(self):
+ entries = self._get_param("Entries")
+
+ # keep track of events for local integration testing
+ if config.is_local_test_mode():
+ TEST_EVENTS_CACHE.extend(entries)
+
+ events = [{"event": event, "uuid": str(long_uid())} for event in entries]
+
+ _dump_events_to_files(events)
+
+ for event_envelope in events:
+ event = event_envelope["event"]
+ event_bus_name = get_event_bus_name(event.get("EventBusName"))
+ event_bus = self.events_backend.event_buses.get(event_bus_name)
+ if not event_bus:
+ continue
+
+ matching_rules = [
+ r
+ for r in event_bus.rules.values()
+ if r.event_bus_name == event_bus_name and not r.scheduled_expression
+ ]
+ if not matching_rules:
+ continue
+
+ event_time = datetime.datetime.utcnow()
+ if event_timestamp := event.get("Time"):
+ try:
+ # if provided, use the time from event
+ event_time = datetime.datetime.utcfromtimestamp(event_timestamp)
+ except ValueError:
+ # if we can't parse it, pass and keep using `utcnow`
+ LOG.debug(
+ "Could not parse the `Time` parameter, falling back to `utcnow` for the following Event: '%s'",
+ event,
+ )
+
+ # See https://docs.aws.amazon.com/AmazonS3/latest/userguide/ev-events.html
+ formatted_event = {
+ "version": "0",
+ "id": event_envelope["uuid"],
+ "detail-type": event.get("DetailType"),
+ "source": event.get("Source"),
+ "account": self.current_account,
+ "time": event_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
+ "region": self.region,
+ "resources": event.get("Resources", []),
+ "detail": json.loads(event.get("Detail", "{}")),
+ }
+
+ targets = []
+ for rule in matching_rules:
+ if filter_event_based_on_event_format(self, rule.name, event_bus_name, formatted_event):
+ rule_targets, _ = self.events_backend.list_targets_by_rule(
+ rule.name, event_bus_arn(event_bus_name, self.current_account, self.region)
+ )
+ targets.extend([{"RuleArn": rule.arn} | target for target in rule_targets])
+ # process event
+ process_events(formatted_event, targets)
+
+ content = {
+ "FailedEntryCount": 0, # TODO: dynamically set proper value when refactoring
+ "Entries": [{"EventId": event["uuid"]} for event in events],
+ }
+
+ self.response_headers.update(
+ {"Content-Type": APPLICATION_AMZ_JSON_1_1, "x-amzn-RequestId": short_uid()}
+ )
+
+ return json.dumps(content), self.response_headers
+
+
+def apply_patches():
+ MotoEventsHandler.put_events = events_handler_put_events
diff --git a/localstack-core/localstack/services/firehose/__init__.py b/localstack-core/localstack/services/firehose/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/firehose/mappers.py b/localstack-core/localstack/services/firehose/mappers.py
new file mode 100644
index 0000000000000..f262db136020a
--- /dev/null
+++ b/localstack-core/localstack/services/firehose/mappers.py
@@ -0,0 +1,172 @@
+from datetime import datetime
+from typing import cast
+
+from localstack.aws.api.firehose import (
+ AmazonopensearchserviceDestinationConfiguration,
+ AmazonopensearchserviceDestinationDescription,
+ AmazonopensearchserviceDestinationUpdate,
+ ElasticsearchDestinationConfiguration,
+ ElasticsearchDestinationDescription,
+ ElasticsearchDestinationUpdate,
+ ExtendedS3DestinationConfiguration,
+ ExtendedS3DestinationDescription,
+ ExtendedS3DestinationUpdate,
+ HttpEndpointDestinationConfiguration,
+ HttpEndpointDestinationDescription,
+ HttpEndpointDestinationUpdate,
+ KinesisStreamSourceConfiguration,
+ KinesisStreamSourceDescription,
+ RedshiftDestinationConfiguration,
+ RedshiftDestinationDescription,
+ S3DestinationConfiguration,
+ S3DestinationDescription,
+ S3DestinationUpdate,
+ SourceDescription,
+ VpcConfigurationDescription,
+)
+
+
+def convert_es_config_to_desc(
+ configuration: ElasticsearchDestinationConfiguration,
+) -> ElasticsearchDestinationDescription:
+ if configuration is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(ElasticsearchDestinationDescription, configuration)
+ # Only specifically handle keys which are named differently or their values differ (version and clusterconfig)
+ result["S3DestinationDescription"] = convert_s3_config_to_desc(
+ configuration["S3Configuration"]
+ )
+ if "VpcConfiguration" in configuration:
+ result["VpcConfigurationDescription"] = cast(
+ VpcConfigurationDescription, configuration["VpcConfiguration"]
+ )
+ result.pop("S3Configuration", None)
+ result.pop("VpcConfiguration", None)
+ return result
+
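+# Note (added for clarity, not part of the original change): the converters in this module rely on the
+# *Configuration/*Update and *Description TypedDicts sharing most keys, so a typecast plus a few key
+# renames (e.g. S3Configuration -> S3DestinationDescription) is sufficient.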
+
+def convert_es_update_to_desc(
+ update: ElasticsearchDestinationUpdate,
+) -> ElasticsearchDestinationDescription:
+ if update is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(ElasticsearchDestinationDescription, update)
+ # Only specifically handle keys which are named differently or whose values differ (version and clusterconfig)
+ if "S3Update" in update:
+ result["S3DestinationDescription"] = cast(S3DestinationDescription, update["S3Update"])
+ result.pop("S3Update", None)
+ return result
+
+
+def convert_opensearch_config_to_desc(
+ configuration: AmazonopensearchserviceDestinationConfiguration,
+) -> AmazonopensearchserviceDestinationDescription:
+ if configuration is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(AmazonopensearchserviceDestinationDescription, configuration)
+ # Only specifically handle keys which are named differently or whose values differ (version and clusterconfig)
+ if "S3Configuration" in configuration:
+ result["S3DestinationDescription"] = convert_s3_config_to_desc(
+ configuration["S3Configuration"]
+ )
+ if "VpcConfiguration" in configuration:
+ result["VpcConfigurationDescription"] = cast(
+ VpcConfigurationDescription, configuration["VpcConfiguration"]
+ )
+ result.pop("S3Configuration", None)
+ result.pop("VpcConfiguration", None)
+ return result
+
+
+def convert_opensearch_update_to_desc(
+ update: AmazonopensearchserviceDestinationUpdate,
+) -> AmazonopensearchserviceDestinationDescription:
+ if update is not None:
+ # Just take the whole typed dict and typecast it to our target type
+ result = cast(AmazonopensearchserviceDestinationDescription, update)
+ # Only specifically handle keys which are named differently or whose values differ (version and clusterconfig)
+ if "S3Update" in update:
+ result["S3DestinationDescription"] = cast(S3DestinationDescription, update["S3Update"])
+ result.pop("S3Update", None)
+ return result
+
+
+def convert_s3_config_to_desc(
+ configuration: S3DestinationConfiguration,
+) -> S3DestinationDescription:
+ if configuration:
+ return cast(S3DestinationDescription, configuration)
+
+
+def convert_s3_update_to_desc(update: S3DestinationUpdate) -> S3DestinationDescription:
+ if update:
+ return cast(S3DestinationDescription, update)
+
+
+def convert_extended_s3_config_to_desc(
+ configuration: ExtendedS3DestinationConfiguration,
+) -> ExtendedS3DestinationDescription:
+ if configuration:
+ result = cast(ExtendedS3DestinationDescription, configuration)
+ if "S3BackupConfiguration" in configuration:
+ result["S3BackupDescription"] = convert_s3_config_to_desc(
+ configuration["S3BackupConfiguration"]
+ )
+ result.pop("S3BackupConfiguration", None)
+ return result
+
+
+def convert_extended_s3_update_to_desc(
+ update: ExtendedS3DestinationUpdate,
+) -> ExtendedS3DestinationDescription:
+ if update:
+ result = cast(ExtendedS3DestinationDescription, update)
+ if "S3BackupUpdate" in update:
+ result["S3BackupDescription"] = convert_s3_update_to_desc(update["S3BackupUpdate"])
+ result.pop("S3BackupUpdate", None)
+ return result
+
+
+def convert_http_config_to_desc(
+ configuration: HttpEndpointDestinationConfiguration,
+) -> HttpEndpointDestinationDescription:
+ if configuration:
+ result = cast(HttpEndpointDestinationDescription, configuration)
+ if "S3Configuration" in configuration:
+ result["S3DestinationDescription"] = convert_s3_config_to_desc(
+ configuration["S3Configuration"]
+ )
+ result.pop("S3Configuration", None)
+ return result
+
+
+def convert_http_update_to_desc(
+ update: HttpEndpointDestinationUpdate,
+) -> HttpEndpointDestinationDescription:
+ if update:
+ result = cast(HttpEndpointDestinationDescription, update)
+ if "S3Update" in update:
+ result["S3DestinationDescription"] = convert_s3_update_to_desc(update["S3Update"])
+ result.pop("S3Update", None)
+ return result
+
+
+def convert_source_config_to_desc(
+ configuration: KinesisStreamSourceConfiguration,
+) -> SourceDescription:
+ if configuration:
+ result = cast(KinesisStreamSourceDescription, configuration)
+ result["DeliveryStartTimestamp"] = datetime.now()
+ return SourceDescription(KinesisStreamSourceDescription=result)
+
+
+def convert_redshift_config_to_desc(
+ configuration: RedshiftDestinationConfiguration,
+) -> RedshiftDestinationDescription:
+ if configuration is not None:
+ result = cast(RedshiftDestinationDescription, configuration)
+ result["S3DestinationDescription"] = convert_s3_config_to_desc(
+ configuration["S3Configuration"]
+ )
+ result.pop("S3Configuration", None)
+ return result
diff --git a/localstack-core/localstack/services/firehose/models.py b/localstack-core/localstack/services/firehose/models.py
new file mode 100644
index 0000000000000..ef2e395ef9229
--- /dev/null
+++ b/localstack-core/localstack/services/firehose/models.py
@@ -0,0 +1,21 @@
+from typing import Dict
+
+from localstack.aws.api.firehose import DeliveryStreamDescription
+from localstack.services.stores import (
+ AccountRegionBundle,
+ BaseStore,
+ CrossRegionAttribute,
+ LocalAttribute,
+)
+from localstack.utils.tagging import TaggingService
+
+
+class FirehoseStore(BaseStore):
+ # maps delivery stream names to DeliveryStreamDescription
+ delivery_streams: Dict[str, DeliveryStreamDescription] = LocalAttribute(default=dict)
+
+ # static tagging service instance
+ TAGS = CrossRegionAttribute(default=TaggingService)
+
+
+firehose_stores = AccountRegionBundle("firehose", FirehoseStore)
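+
+
+# Illustrative usage (example identifiers, not part of this change): stores are scoped per account
+# and region, e.g.
+#   store = firehose_stores["000000000000"]["us-east-1"]
+#   store.delivery_streams["my-stream"]  # -> DeliveryStreamDescription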
diff --git a/localstack-core/localstack/services/firehose/provider.py b/localstack-core/localstack/services/firehose/provider.py
new file mode 100644
index 0000000000000..6f56dca1ddf03
--- /dev/null
+++ b/localstack-core/localstack/services/firehose/provider.py
@@ -0,0 +1,957 @@
+import base64
+import functools
+import json
+import logging
+import os
+import re
+import threading
+import time
+import uuid
+from datetime import datetime
+from typing import Dict, List
+from urllib.parse import urlparse
+
+import requests
+
+from localstack.aws.api import RequestContext
+from localstack.aws.api.firehose import (
+ AmazonOpenSearchServerlessDestinationConfiguration,
+ AmazonOpenSearchServerlessDestinationUpdate,
+ AmazonopensearchserviceDestinationConfiguration,
+ AmazonopensearchserviceDestinationDescription,
+ AmazonopensearchserviceDestinationUpdate,
+ BooleanObject,
+ CreateDeliveryStreamOutput,
+ DatabaseSourceConfiguration,
+ DeleteDeliveryStreamOutput,
+ DeliveryStreamDescription,
+ DeliveryStreamEncryptionConfigurationInput,
+ DeliveryStreamName,
+ DeliveryStreamStatus,
+ DeliveryStreamType,
+ DeliveryStreamVersionId,
+ DescribeDeliveryStreamInputLimit,
+ DescribeDeliveryStreamOutput,
+ DestinationDescription,
+ DestinationDescriptionList,
+ DestinationId,
+ ElasticsearchDestinationConfiguration,
+ ElasticsearchDestinationDescription,
+ ElasticsearchDestinationUpdate,
+ ElasticsearchS3BackupMode,
+ ExtendedS3DestinationConfiguration,
+ ExtendedS3DestinationUpdate,
+ FirehoseApi,
+ HttpEndpointDestinationConfiguration,
+ HttpEndpointDestinationUpdate,
+ IcebergDestinationConfiguration,
+ IcebergDestinationUpdate,
+ InvalidArgumentException,
+ KinesisStreamSourceConfiguration,
+ ListDeliveryStreamsInputLimit,
+ ListDeliveryStreamsOutput,
+ ListTagsForDeliveryStreamInputLimit,
+ ListTagsForDeliveryStreamOutput,
+ ListTagsForDeliveryStreamOutputTagList,
+ MSKSourceConfiguration,
+ PutRecordBatchOutput,
+ PutRecordBatchRequestEntryList,
+ PutRecordBatchResponseEntry,
+ PutRecordOutput,
+ Record,
+ RedshiftDestinationConfiguration,
+ RedshiftDestinationDescription,
+ RedshiftDestinationUpdate,
+ ResourceNotFoundException,
+ S3DestinationConfiguration,
+ S3DestinationDescription,
+ S3DestinationUpdate,
+ SnowflakeDestinationConfiguration,
+ SnowflakeDestinationUpdate,
+ SplunkDestinationConfiguration,
+ SplunkDestinationUpdate,
+ TagDeliveryStreamInputTagList,
+ TagDeliveryStreamOutput,
+ TagKey,
+ TagKeyList,
+ UntagDeliveryStreamOutput,
+ UpdateDestinationOutput,
+)
+from localstack.aws.connect import connect_to
+from localstack.services.firehose.mappers import (
+ convert_es_config_to_desc,
+ convert_es_update_to_desc,
+ convert_extended_s3_config_to_desc,
+ convert_extended_s3_update_to_desc,
+ convert_http_config_to_desc,
+ convert_http_update_to_desc,
+ convert_opensearch_config_to_desc,
+ convert_opensearch_update_to_desc,
+ convert_redshift_config_to_desc,
+ convert_s3_config_to_desc,
+ convert_s3_update_to_desc,
+ convert_source_config_to_desc,
+)
+from localstack.services.firehose.models import FirehoseStore, firehose_stores
+from localstack.utils.aws.arns import (
+ extract_account_id_from_arn,
+ extract_region_from_arn,
+ firehose_stream_arn,
+ opensearch_domain_name,
+ s3_bucket_name,
+)
+from localstack.utils.aws.client_types import ServicePrincipal
+from localstack.utils.collections import select_from_typed_dict
+from localstack.utils.common import (
+ TIMESTAMP_FORMAT_MICROS,
+ first_char_to_lower,
+ keys_to_lower,
+ now_utc,
+ short_uid,
+ timestamp,
+ to_bytes,
+ to_str,
+ truncate,
+)
+from localstack.utils.kinesis import kinesis_connector
+from localstack.utils.kinesis.kinesis_connector import KinesisProcessorThread
+from localstack.utils.run import run_for_max_seconds
+
+LOG = logging.getLogger(__name__)
+
+# global sequence number counter for Firehose records (these are very large long values in AWS)
+SEQUENCE_NUMBER = 49546986683135544286507457936321625675700192471156785154
+SEQUENCE_NUMBER_MUTEX = threading.RLock()
+
+
+def next_sequence_number() -> int:
+ """Increase and return the next global sequence number."""
+ global SEQUENCE_NUMBER
+ with SEQUENCE_NUMBER_MUTEX:
+ SEQUENCE_NUMBER += 1
+ return SEQUENCE_NUMBER
+
+
+def _get_description_or_raise_not_found(
+ context, delivery_stream_name: str
+) -> DeliveryStreamDescription:
+ store = FirehoseProvider.get_store(context.account_id, context.region)
+ delivery_stream_description = store.delivery_streams.get(delivery_stream_name)
+ if not delivery_stream_description:
+ raise ResourceNotFoundException(
+ f"Firehose {delivery_stream_name} under account {context.account_id} not found."
+ )
+ return delivery_stream_description
+
+
+def get_opensearch_endpoint(domain_arn: str) -> str:
+ """
+ Get an OpenSearch cluster endpoint by describing the cluster associated with the domain_arn
+ :param domain_arn: ARN of the cluster.
+ :returns: cluster endpoint
+ :raises: ValueError if the domain_arn is malformed
+ """
+ account_id = extract_account_id_from_arn(domain_arn)
+ region_name = extract_region_from_arn(domain_arn)
+ if region_name is None:
+ raise ValueError("unable to parse region from opensearch domain ARN")
+ opensearch_client = connect_to(aws_access_key_id=account_id, region_name=region_name).opensearch
+ domain_name = opensearch_domain_name(domain_arn)
+ info = opensearch_client.describe_domain(DomainName=domain_name)
+ base_domain = info["DomainStatus"]["Endpoint"]
+ # Add the URL scheme "http" if it's not set yet. HTTPS might not be enabled for all instances,
+ # e.g. when the endpoint strategy is PORT or a custom OpenSearch/Elasticsearch instance is used
+ endpoint = base_domain if base_domain.startswith("http") else f"http://{base_domain}"
+ return endpoint
+
+
+def get_search_db_connection(endpoint: str, region_name: str):
+ """
+ Get a connection to an Elasticsearch or OpenSearch DB
+ :param endpoint: cluster endpoint
+ :param region_name: cluster region e.g. us-east-1
+ """
+ from opensearchpy import OpenSearch, RequestsHttpConnection
+ from requests_aws4auth import AWS4Auth
+
+ verify_certs = False
+ use_ssl = False
+ # use ssl?
+ if "https://" in endpoint:
+ use_ssl = True
+ # TODO remove this condition once ssl certs are available for .es.localhost.localstack.cloud domains
+ endpoint_netloc = urlparse(endpoint).netloc
+ if not re.match(r"^.*(localhost(\.localstack\.cloud)?)(:\d+)?$", endpoint_netloc):
+ verify_certs = True
+
+ LOG.debug("Creating ES client with endpoint %s", endpoint)
+ if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ:
+ access_key = os.environ.get("AWS_ACCESS_KEY_ID")
+ secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+ session_token = os.environ.get("AWS_SESSION_TOKEN")
+ awsauth = AWS4Auth(access_key, secret_key, region_name, "es", session_token=session_token)
+ connection_class = RequestsHttpConnection
+ return OpenSearch(
+ hosts=[endpoint],
+ verify_certs=verify_certs,
+ use_ssl=use_ssl,
+ connection_class=connection_class,
+ http_auth=awsauth,
+ )
+ return OpenSearch(hosts=[endpoint], verify_certs=verify_certs, use_ssl=use_ssl)
+
+
+def _drop_keys_in_destination_descriptions_not_in_output_types(
+ destinations: list,
+) -> list[dict]:
+ """For supported destinations, drops the keys in the description not defined in the respective destination description return type"""
+ for destination in destinations:
+ if amazon_open_search_service_destination_description := destination.get(
+ "AmazonopensearchserviceDestinationDescription"
+ ):
+ destination["AmazonopensearchserviceDestinationDescription"] = select_from_typed_dict(
+ AmazonopensearchserviceDestinationDescription,
+ amazon_open_search_service_destination_description,
+ filter=True,
+ )
+ if elasticsearch_destination_description := destination.get(
+ "ElasticsearchDestinationDescription"
+ ):
+ destination["ElasticsearchDestinationDescription"] = select_from_typed_dict(
+ ElasticsearchDestinationDescription,
+ elasticsearch_destination_description,
+ filter=True,
+ )
+ if http_endpoint_destination_description := destination.get(
+ "HttpEndpointDestinationDescription"
+ ):
+ destination["HttpEndpointDestinationDescription"] = select_from_typed_dict(
+ HttpEndpointDestinationConfiguration,
+ http_endpoint_destination_description,
+ filter=True,
+ )
+ if redshift_destination_description := destination.get("RedshiftDestinationDescription"):
+ destination["RedshiftDestinationDescription"] = select_from_typed_dict(
+ RedshiftDestinationDescription,
+ redshift_destination_description,
+ filter=True,
+ )
+ if s3_destination_description := destination.get("S3DestinationDescription"):
+ destination["S3DestinationDescription"] = select_from_typed_dict(
+ S3DestinationDescription, s3_destination_description, filter=True
+ )
+
+ return destinations
+
+
+class FirehoseProvider(FirehoseApi):
+ # maps a delivery_stream_arn to its kinesis thread; the arn encodes account id and region
+ kinesis_listeners: dict[str, KinesisProcessorThread]
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.kinesis_listeners = {}
+
+ @staticmethod
+ def get_store(account_id: str, region_name: str) -> FirehoseStore:
+ return firehose_stores[account_id][region_name]
+
+ def create_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ delivery_stream_type: DeliveryStreamType = None,
+ kinesis_stream_source_configuration: KinesisStreamSourceConfiguration = None,
+ delivery_stream_encryption_configuration_input: DeliveryStreamEncryptionConfigurationInput = None,
+ s3_destination_configuration: S3DestinationConfiguration = None,
+ extended_s3_destination_configuration: ExtendedS3DestinationConfiguration = None,
+ redshift_destination_configuration: RedshiftDestinationConfiguration = None,
+ elasticsearch_destination_configuration: ElasticsearchDestinationConfiguration = None,
+ amazonopensearchservice_destination_configuration: AmazonopensearchserviceDestinationConfiguration = None,
+ splunk_destination_configuration: SplunkDestinationConfiguration = None,
+ http_endpoint_destination_configuration: HttpEndpointDestinationConfiguration = None,
+ tags: TagDeliveryStreamInputTagList = None,
+ amazon_open_search_serverless_destination_configuration: AmazonOpenSearchServerlessDestinationConfiguration = None,
+ msk_source_configuration: MSKSourceConfiguration = None,
+ snowflake_destination_configuration: SnowflakeDestinationConfiguration = None,
+ iceberg_destination_configuration: IcebergDestinationConfiguration = None,
+ database_source_configuration: DatabaseSourceConfiguration = None,
+ **kwargs,
+ ) -> CreateDeliveryStreamOutput:
+ # TODO add support for database_source_configuration
+ store = self.get_store(context.account_id, context.region)
+
+ destinations: DestinationDescriptionList = []
+ if elasticsearch_destination_configuration:
+ destinations.append(
+ DestinationDescription(
+ DestinationId=short_uid(),
+ ElasticsearchDestinationDescription=convert_es_config_to_desc(
+ elasticsearch_destination_configuration
+ ),
+ )
+ )
+ if amazonopensearchservice_destination_configuration:
+ db_description = convert_opensearch_config_to_desc(
+ amazonopensearchservice_destination_configuration
+ )
+ destinations.append(
+ DestinationDescription(
+ DestinationId=short_uid(),
+ AmazonopensearchserviceDestinationDescription=db_description,
+ )
+ )
+ if s3_destination_configuration or extended_s3_destination_configuration:
+ destinations.append(
+ DestinationDescription(
+ DestinationId=short_uid(),
+ S3DestinationDescription=convert_s3_config_to_desc(
+ s3_destination_configuration
+ ),
+ ExtendedS3DestinationDescription=convert_extended_s3_config_to_desc(
+ extended_s3_destination_configuration
+ ),
+ )
+ )
+ if http_endpoint_destination_configuration:
+ destinations.append(
+ DestinationDescription(
+ DestinationId=short_uid(),
+ HttpEndpointDestinationDescription=convert_http_config_to_desc(
+ http_endpoint_destination_configuration
+ ),
+ )
+ )
+ if splunk_destination_configuration:
+ LOG.warning(
+ "Delivery stream contains a splunk destination (which is currently not supported)."
+ )
+ if redshift_destination_configuration:
+ destinations.append(
+ DestinationDescription(
+ DestinationId=short_uid(),
+ RedshiftDestinationDescription=convert_redshift_config_to_desc(
+ redshift_destination_configuration
+ ),
+ )
+ )
+ if amazon_open_search_serverless_destination_configuration:
+ LOG.warning(
+ "Delivery stream contains a opensearch serverless destination (which is currently not supported)."
+ )
+
+ stream = DeliveryStreamDescription(
+ DeliveryStreamName=delivery_stream_name,
+ DeliveryStreamARN=firehose_stream_arn(
+ stream_name=delivery_stream_name,
+ account_id=context.account_id,
+ region_name=context.region,
+ ),
+ DeliveryStreamStatus=DeliveryStreamStatus.ACTIVE,
+ DeliveryStreamType=delivery_stream_type,
+ HasMoreDestinations=False,
+ VersionId="1",
+ CreateTimestamp=datetime.now(),
+ Destinations=destinations,
+ Source=convert_source_config_to_desc(kinesis_stream_source_configuration),
+ )
+ delivery_stream_arn = stream["DeliveryStreamARN"]
+ store.TAGS.tag_resource(delivery_stream_arn, tags)
+ store.delivery_streams[delivery_stream_name] = stream
+
+ if delivery_stream_type == DeliveryStreamType.KinesisStreamAsSource:
+ if not kinesis_stream_source_configuration:
+ raise InvalidArgumentException("Missing delivery stream configuration")
+ kinesis_stream_arn = kinesis_stream_source_configuration["KinesisStreamARN"]
+ kinesis_stream_name = kinesis_stream_arn.split(":stream/")[1]
+
+ def _startup():
+ stream["DeliveryStreamStatus"] = DeliveryStreamStatus.CREATING
+ try:
+ listener_function = functools.partial(
+ self._process_records,
+ context.account_id,
+ context.region,
+ delivery_stream_name,
+ )
+ process = kinesis_connector.listen_to_kinesis(
+ stream_name=kinesis_stream_name,
+ account_id=context.account_id,
+ region_name=context.region,
+ listener_func=listener_function,
+ wait_until_started=True,
+ ddb_lease_table_suffix=f"-firehose-{delivery_stream_name}",
+ )
+
+ self.kinesis_listeners[delivery_stream_arn] = process
+ stream["DeliveryStreamStatus"] = DeliveryStreamStatus.ACTIVE
+ except Exception as e:
+ LOG.warning(
+ "Unable to create Firehose delivery stream %s: %s",
+ delivery_stream_name,
+ e,
+ )
+ stream["DeliveryStreamStatus"] = DeliveryStreamStatus.CREATING_FAILED
+
+ run_for_max_seconds(25, _startup)
+ return CreateDeliveryStreamOutput(DeliveryStreamARN=stream["DeliveryStreamARN"])
+
+ def delete_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ allow_force_delete: BooleanObject = None,
+ **kwargs,
+ ) -> DeleteDeliveryStreamOutput:
+ store = self.get_store(context.account_id, context.region)
+ delivery_stream_description = store.delivery_streams.pop(delivery_stream_name, {})
+ if not delivery_stream_description:
+ raise ResourceNotFoundException(
+ f"Firehose {delivery_stream_name} under account {context.account_id} not found."
+ )
+
+ delivery_stream_arn = firehose_stream_arn(
+ stream_name=delivery_stream_name,
+ account_id=context.account_id,
+ region_name=context.region,
+ )
+ if kinesis_process := self.kinesis_listeners.pop(delivery_stream_arn, None):
+ LOG.debug("Stopping kinesis listener for %s", delivery_stream_name)
+ kinesis_process.stop()
+
+ return DeleteDeliveryStreamOutput()
+
+ def describe_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ limit: DescribeDeliveryStreamInputLimit = None,
+ exclusive_start_destination_id: DestinationId = None,
+ **kwargs,
+ ) -> DescribeDeliveryStreamOutput:
+ delivery_stream_description = _get_description_or_raise_not_found(
+ context, delivery_stream_name
+ )
+ if destinations := delivery_stream_description.get("Destinations"):
+ delivery_stream_description["Destinations"] = (
+ _drop_keys_in_destination_descriptions_not_in_output_types(destinations)
+ )
+
+ return DescribeDeliveryStreamOutput(DeliveryStreamDescription=delivery_stream_description)
+
+ def list_delivery_streams(
+ self,
+ context: RequestContext,
+ limit: ListDeliveryStreamsInputLimit = None,
+ delivery_stream_type: DeliveryStreamType = None,
+ exclusive_start_delivery_stream_name: DeliveryStreamName = None,
+ **kwargs,
+ ) -> ListDeliveryStreamsOutput:
+ store = self.get_store(context.account_id, context.region)
+ delivery_stream_names = []
+ for name, stream in store.delivery_streams.items():
+ delivery_stream_names.append(stream["DeliveryStreamName"])
+ return ListDeliveryStreamsOutput(
+ DeliveryStreamNames=delivery_stream_names, HasMoreDeliveryStreams=False
+ )
+
+ def put_record(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ record: Record,
+ **kwargs,
+ ) -> PutRecordOutput:
+ record = self._reencode_record(record)
+ return self._put_record(context.account_id, context.region, delivery_stream_name, record)
+
+ def put_record_batch(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ records: PutRecordBatchRequestEntryList,
+ **kwargs,
+ ) -> PutRecordBatchOutput:
+ records = self._reencode_records(records)
+ return PutRecordBatchOutput(
+ FailedPutCount=0,
+ RequestResponses=self._put_records(
+ context.account_id, context.region, delivery_stream_name, records
+ ),
+ )
+
+ def tag_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ tags: TagDeliveryStreamInputTagList,
+ **kwargs,
+ ) -> TagDeliveryStreamOutput:
+ store = self.get_store(context.account_id, context.region)
+ delivery_stream_description = _get_description_or_raise_not_found(
+ context, delivery_stream_name
+ )
+ store.TAGS.tag_resource(delivery_stream_description["DeliveryStreamARN"], tags)
+ return TagDeliveryStreamOutput()
+
+ def list_tags_for_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ exclusive_start_tag_key: TagKey = None,
+ limit: ListTagsForDeliveryStreamInputLimit = None,
+ **kwargs,
+ ) -> ListTagsForDeliveryStreamOutput:
+ store = self.get_store(context.account_id, context.region)
+ delivery_stream_description = _get_description_or_raise_not_found(
+ context, delivery_stream_name
+ )
+ # The tagging service returns a dictionary with the given root name
+ tags = store.TAGS.list_tags_for_resource(
+ arn=delivery_stream_description["DeliveryStreamARN"], root_name="root"
+ )
+ # Extract the actual list of tags for the typed response
+ tag_list: ListTagsForDeliveryStreamOutputTagList = tags["root"]
+ return ListTagsForDeliveryStreamOutput(Tags=tag_list, HasMoreTags=False)
+
+ def untag_delivery_stream(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ tag_keys: TagKeyList,
+ **kwargs,
+ ) -> UntagDeliveryStreamOutput:
+ store = self.get_store(context.account_id, context.region)
+ delivery_stream_description = _get_description_or_raise_not_found(
+ context, delivery_stream_name
+ )
+ # The tagging service returns a dictionary with the given root name
+ store.TAGS.untag_resource(
+ arn=delivery_stream_description["DeliveryStreamARN"], tag_names=tag_keys
+ )
+ return UntagDeliveryStreamOutput()
+
+ def update_destination(
+ self,
+ context: RequestContext,
+ delivery_stream_name: DeliveryStreamName,
+ current_delivery_stream_version_id: DeliveryStreamVersionId,
+ destination_id: DestinationId,
+ s3_destination_update: S3DestinationUpdate = None,
+ extended_s3_destination_update: ExtendedS3DestinationUpdate = None,
+ redshift_destination_update: RedshiftDestinationUpdate = None,
+ elasticsearch_destination_update: ElasticsearchDestinationUpdate = None,
+ amazonopensearchservice_destination_update: AmazonopensearchserviceDestinationUpdate = None,
+ splunk_destination_update: SplunkDestinationUpdate = None,
+ http_endpoint_destination_update: HttpEndpointDestinationUpdate = None,
+ amazon_open_search_serverless_destination_update: AmazonOpenSearchServerlessDestinationUpdate = None,
+ snowflake_destination_update: SnowflakeDestinationUpdate = None,
+ iceberg_destination_update: IcebergDestinationUpdate = None,
+ **kwargs,
+ ) -> UpdateDestinationOutput:
+ delivery_stream_description = _get_description_or_raise_not_found(
+ context, delivery_stream_name
+ )
+ destinations = delivery_stream_description["Destinations"]
+ try:
+ destination = next(filter(lambda d: d["DestinationId"] == destination_id, destinations))
+ except StopIteration:
+ destination = DestinationDescription(DestinationId=destination_id)
+ delivery_stream_description["Destinations"].append(destination)
+
+ if elasticsearch_destination_update:
+ destination["ElasticsearchDestinationDescription"] = convert_es_update_to_desc(
+ elasticsearch_destination_update
+ )
+
+ if amazonopensearchservice_destination_update:
+ destination["AmazonopensearchserviceDestinationDescription"] = (
+ convert_opensearch_update_to_desc(amazonopensearchservice_destination_update)
+ )
+
+ if s3_destination_update:
+ destination["S3DestinationDescription"] = convert_s3_update_to_desc(
+ s3_destination_update
+ )
+
+ if extended_s3_destination_update:
+ destination["ExtendedS3DestinationDescription"] = convert_extended_s3_update_to_desc(
+ extended_s3_destination_update
+ )
+
+ if http_endpoint_destination_update:
+ destination["HttpEndpointDestinationDescription"] = convert_http_update_to_desc(
+ http_endpoint_destination_update
+ )
+ # TODO: add feature update redshift destination
+
+ return UpdateDestinationOutput()
+
+ def _reencode_record(self, record: Record) -> Record:
+ """
+ The ASF decodes the record's data automatically, but most of the service integrations (Kinesis, Lambda, HTTP)
+ work with the base64-encoded data, so it is re-encoded here.
+ """
+ if "Data" in record:
+ record["Data"] = base64.b64encode(record["Data"])
+ return record
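+ # Illustrative example (assumed payload): ASF hands in record["Data"] as raw bytes, e.g.
+ # b'{"k": 1}', which is re-encoded above to b'eyJrIjogMX0=' so that downstream integrations
+ # (Kinesis, Lambda, HTTP) keep receiving base64 data.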
+
+ def _reencode_records(self, records: List[Record]) -> List[Record]:
+ return [self._reencode_record(r) for r in records]
+
+ def _process_records(
+ self,
+ account_id: str,
+ region_name: str,
+ fh_d_stream: str,
+ records: List[Record],
+ ):
+ """Process the given records from the underlying Kinesis stream"""
+ return self._put_records(account_id, region_name, fh_d_stream, records)
+
+ def _put_record(
+ self,
+ account_id: str,
+ region_name: str,
+ delivery_stream_name: str,
+ record: Record,
+ ) -> PutRecordOutput:
+ """Put a record to the firehose stream from a PutRecord API call"""
+ result = self._put_records(account_id, region_name, delivery_stream_name, [record])
+ return PutRecordOutput(RecordId=result[0]["RecordId"])
+
+ def _put_records(
+ self,
+ account_id: str,
+ region_name: str,
+ delivery_stream_name: str,
+ unprocessed_records: List[Record],
+ ) -> List[PutRecordBatchResponseEntry]:
+ """Put a list of records to the firehose stream - either directly from a PutRecord API call, or
+ received from an underlying Kinesis stream (if 'KinesisStreamAsSource' is configured)"""
+ store = self.get_store(account_id, region_name)
+ delivery_stream_description = store.delivery_streams.get(delivery_stream_name)
+ if not delivery_stream_description:
+ raise ResourceNotFoundException(
+ f"Firehose {delivery_stream_name} under account {account_id} not found."
+ )
+
+ # preprocess records, add any missing attributes
+ self._add_missing_record_attributes(unprocessed_records)
+
+ for destination in delivery_stream_description.get("Destinations", []):
+ # apply processing steps to incoming items
+ proc_config = {}
+ for child in destination.values():
+ proc_config = (
+ isinstance(child, dict) and child.get("ProcessingConfiguration") or proc_config
+ )
+ records = list(unprocessed_records)
+ if proc_config.get("Enabled") is not False:
+ for processor in proc_config.get("Processors", []):
+ # TODO: run processors asynchronously, to avoid request timeouts on PutRecord API calls
+ records = self._preprocess_records(processor, records)
+
+ if "ElasticsearchDestinationDescription" in destination:
+ self._put_to_search_db(
+ "ElasticSearch",
+ destination["ElasticsearchDestinationDescription"],
+ delivery_stream_name,
+ records,
+ unprocessed_records,
+ region_name,
+ )
+ if "AmazonopensearchserviceDestinationDescription" in destination:
+ self._put_to_search_db(
+ "OpenSearch",
+ destination["AmazonopensearchserviceDestinationDescription"],
+ delivery_stream_name,
+ records,
+ unprocessed_records,
+ region_name,
+ )
+ if "S3DestinationDescription" in destination:
+ s3_dest_desc = (
+ destination["S3DestinationDescription"]
+ or destination["ExtendedS3DestinationDescription"]
+ )
+ self._put_records_to_s3_bucket(delivery_stream_name, records, s3_dest_desc)
+ if "HttpEndpointDestinationDescription" in destination:
+ http_dest = destination["HttpEndpointDestinationDescription"]
+ end_point = http_dest["EndpointConfiguration"]
+ url = end_point["Url"]
+ record_to_send = {
+ "requestId": str(uuid.uuid4()),
+ "timestamp": (int(time.time())),
+ "records": [],
+ }
+ for record in records:
+ data = record.get("Data") or record.get("data")
+ record_to_send["records"].append({"data": to_str(data)})
+ headers = {
+ "Content-Type": "application/json",
+ }
+ try:
+ requests.post(url, json=record_to_send, headers=headers)
+ except Exception as e:
+ LOG.exception("Unable to put Firehose records to HTTP endpoint %s.", url)
+ raise e
+ if "RedshiftDestinationDescription" in destination:
+ s3_dest_desc = destination["RedshiftDestinationDescription"][
+ "S3DestinationDescription"
+ ]
+ self._put_records_to_s3_bucket(delivery_stream_name, records, s3_dest_desc)
+
+ redshift_dest_desc = destination["RedshiftDestinationDescription"]
+ self._put_to_redshift(records, redshift_dest_desc)
+ return [
+ PutRecordBatchResponseEntry(RecordId=str(uuid.uuid4())) for _ in unprocessed_records
+ ]
+
+ def _put_to_search_db(
+ self,
+ db_flavor,
+ db_description,
+ delivery_stream_name,
+ records,
+ unprocessed_records,
+ region_name,
+ ):
+ """
+ Sends Firehose records to an Elasticsearch or OpenSearch database.
+ """
+ search_db_index = db_description["IndexName"]
+ domain_arn = db_description.get("DomainARN")
+ cluster_endpoint = db_description.get("ClusterEndpoint")
+ if cluster_endpoint is None:
+ cluster_endpoint = get_opensearch_endpoint(domain_arn)
+
+ db_connection = get_search_db_connection(cluster_endpoint, region_name)
+
+ if db_description.get("S3BackupMode") == ElasticsearchS3BackupMode.AllDocuments:
+ s3_dest_desc = db_description.get("S3DestinationDescription")
+ if s3_dest_desc:
+ try:
+ self._put_records_to_s3_bucket(
+ stream_name=delivery_stream_name,
+ records=unprocessed_records,
+ s3_destination_description=s3_dest_desc,
+ )
+ except Exception as e:
+ LOG.warning("Unable to backup unprocessed records to S3. Error: %s", e)
+ else:
+ LOG.warning("Passed S3BackupMode without S3Configuration. Cannot backup...")
+ elif db_description.get("S3BackupMode") == ElasticsearchS3BackupMode.FailedDocumentsOnly:
+ # TODO support FailedDocumentsOnly as well
+ LOG.warning("S3BackupMode FailedDocumentsOnly is set but currently not supported.")
+ for record in records:
+ obj_id = uuid.uuid4()
+
+ data = "{}"
+ # DirectPut
+ if "Data" in record:
+ data = base64.b64decode(record["Data"])
+ # KinesisAsSource
+ elif "data" in record:
+ data = base64.b64decode(record["data"])
+
+ try:
+ body = json.loads(data)
+ except Exception as e:
+ LOG.warning("%s only allows json input data!", db_flavor)
+ raise e
+
+ if LOG.isEnabledFor(logging.DEBUG):
+ LOG.debug(
+ "Publishing to %s destination. Data: %s",
+ db_flavor,
+ truncate(data, max_length=300),
+ )
+ try:
+ db_connection.create(index=search_db_index, id=obj_id, body=body)
+ except Exception as e:
+ LOG.exception("Unable to put record to stream %s.", delivery_stream_name)
+ raise e
+
+ def _add_missing_record_attributes(self, records: List[Dict]) -> None:
+ def _get_entry(obj, key):
+ return obj.get(key) or obj.get(first_char_to_lower(key))
+
+ for record in records:
+ if not _get_entry(record, "ApproximateArrivalTimestamp"):
+ record["ApproximateArrivalTimestamp"] = int(now_utc(millis=True))
+ if not _get_entry(record, "KinesisRecordMetadata"):
+ record["kinesisRecordMetadata"] = {
+ "shardId": "shardId-000000000000",
+ # not really documented what AWS is using internally - simply using a random UUID here
+ "partitionKey": str(uuid.uuid4()),
+ "approximateArrivalTimestamp": timestamp(
+ float(_get_entry(record, "ApproximateArrivalTimestamp")) / 1000,
+ format=TIMESTAMP_FORMAT_MICROS,
+ ),
+ "sequenceNumber": next_sequence_number(),
+ "subsequenceNumber": "",
+ }
+
+ def _preprocess_records(self, processor: Dict, records: List[Record]) -> List[Dict]:
+ """Preprocess the list of records by calling the given processor (e.g., Lamnda function)."""
+ proc_type = processor.get("Type")
+ parameters = processor.get("Parameters", [])
+ parameters = {p["ParameterName"]: p["ParameterValue"] for p in parameters}
+ if proc_type == "Lambda":
+ lambda_arn = parameters.get("LambdaArn")
+ # TODO: add support for other parameters, e.g., NumberOfRetries, BufferSizeInMBs, BufferIntervalInSeconds, ...
+ records = keys_to_lower(records)
+ # Convert the record data to string (for json serialization)
+ for record in records:
+ if "data" in record:
+ record["data"] = to_str(record["data"])
+ if "Data" in record:
+ record["Data"] = to_str(record["Data"])
+ event = {"records": records}
+ event = to_bytes(json.dumps(event))
+
+ account_id = extract_account_id_from_arn(lambda_arn)
+ region_name = extract_region_from_arn(lambda_arn)
+ client = connect_to(aws_access_key_id=account_id, region_name=region_name).lambda_
+
+ response = client.invoke(FunctionName=lambda_arn, Payload=event)
+ result = json.load(response["Payload"])
+ records = result.get("records", []) if result else []
+ else:
+ LOG.warning("Unsupported Firehose processor type '%s'", proc_type)
+ return records
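+ # Illustrative sketch (assumed example payload, not part of this change): for a "Lambda"
+ # processor the transformation function is invoked with an event of the form
+ #   {"records": [{"data": "<record payload as string>", ...}]}
+ # and its response payload is expected to contain a "records" list again, which then replaces
+ # the original records for the remaining delivery steps.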
+
+ def _put_records_to_s3_bucket(
+ self,
+ stream_name: str,
+ records: List[Dict],
+ s3_destination_description: S3DestinationDescription,
+ ):
+ bucket = s3_bucket_name(s3_destination_description["BucketARN"])
+ prefix = s3_destination_description.get("Prefix", "")
+ file_extension = s3_destination_description.get("FileExtension", "")
+
+ if role_arn := s3_destination_description.get("RoleARN"):
+ factory = connect_to.with_assumed_role(
+ role_arn=role_arn, service_principal=ServicePrincipal.firehose
+ )
+ else:
+ factory = connect_to()
+ s3 = factory.s3.request_metadata(
+ source_arn=stream_name, service_principal=ServicePrincipal.firehose
+ )
+ batched_data = b"".join([base64.b64decode(r.get("Data") or r.get("data")) for r in records])
+
+ obj_path = self._get_s3_object_path(stream_name, prefix, file_extension)
+ try:
+ LOG.debug("Publishing to S3 destination: %s. Data: %s", bucket, batched_data)
+ s3.put_object(Bucket=bucket, Key=obj_path, Body=batched_data)
+ except Exception as e:
+ LOG.exception(
+ "Unable to put records %s to s3 bucket.",
+ records,
+ )
+ raise e
+
+ def _get_s3_object_path(self, stream_name, prefix, file_extension):
+ # See https://aws.amazon.com/kinesis/data-firehose/faqs/#Data_delivery
+ # Path prefix pattern: myApp/YYYY/MM/DD/HH/
+ # Object name pattern: DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
+ if not prefix.endswith("/") and prefix != "":
+ prefix = prefix + "/"
+ pattern = "{pre}%Y/%m/%d/%H/{name}-%Y-%m-%d-%H-%M-%S-{rand}"
+ path = pattern.format(pre=prefix, name=stream_name, rand=str(uuid.uuid4()))
+ path = timestamp(format=path)
+
+ if file_extension:
+ path += file_extension
+
+ return path
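+ # Illustrative example (assumed values): for stream_name="my-stream", prefix="myApp/" and
+ # file_extension=".json", the generated object key looks roughly like
+ #   myApp/2024/01/01/12/my-stream-2024-01-01-12-30-45-<uuid>.json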
+
+ def _put_to_redshift(
+ self,
+ records: List[Dict],
+ redshift_destination_description: RedshiftDestinationDescription,
+ ):
+ jdbcurl = redshift_destination_description.get("ClusterJDBCURL")
+ cluster_id = self._get_cluster_id_from_jdbc_url(jdbcurl)
+ db_name = jdbcurl.split("/")[-1]
+ table_name = redshift_destination_description.get("CopyCommand").get("DataTableName")
+
+ rows_to_insert = [self._prepare_records_for_redshift(record) for record in records]
+ columns_placeholder_str = self._extract_columns(records[0])
+ sql_insert_statement = f"INSERT INTO {table_name} VALUES ({columns_placeholder_str})"
+
+ execute_statement = {
+ "Sql": sql_insert_statement,
+ "Database": db_name,
+ "ClusterIdentifier": cluster_id, # cluster_identifier in cluster create
+ }
+
+ role_arn = redshift_destination_description.get("RoleARN")
+ account_id = extract_account_id_from_arn(role_arn)
+ region_name = self._get_region_from_jdbc_url(jdbcurl)
+ redshift_data = connect_to(
+ aws_access_key_id=account_id, region_name=region_name
+ ).redshift_data
+
+ for row_to_insert in rows_to_insert: # redshift_data only allows single-row inserts
+ try:
+ LOG.debug(
+ "Publishing to Redshift destination: %s. Data: %s",
+ jdbcurl,
+ row_to_insert,
+ )
+ redshift_data.execute_statement(Parameters=row_to_insert, **execute_statement)
+ except Exception as e:
+ LOG.exception(
+ "Unable to put records %s to redshift cluster.",
+ row_to_insert,
+ )
+ raise e
+
+ def _get_cluster_id_from_jdbc_url(self, jdbc_url: str) -> str:
+ pattern = r"://(.*?)\."
+ match = re.search(pattern, jdbc_url)
+ if match:
+ return match.group(1)
+ else:
+ raise ValueError(f"Unable to extract cluster id from jdbc url: {jdbc_url}")
+
+ def _get_region_from_jdbc_url(self, jdbc_url: str) -> str | None:
+ match = re.search(r"://(?:[^.]+\.){2}([^.]+)\.", jdbc_url)
+ if match:
+ return match.group(1)
+ else:
+ LOG.debug("Cannot extract region from JDBC url '%s'", jdbc_url)
+ return None
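+ # Illustrative example (assumed URL): for a JDBC URL such as
+ #   jdbc:redshift://my-cluster.abc123.us-east-1.redshift.amazonaws.com:5439/dev
+ # _get_cluster_id_from_jdbc_url returns "my-cluster" and _get_region_from_jdbc_url
+ # returns "us-east-1".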
+
+ def _decode_record(self, record: Dict) -> Dict:
+ data = base64.b64decode(record.get("Data") or record.get("data"))
+ data = to_str(data)
+ data = json.loads(data)
+ return data
+
+ def _prepare_records_for_redshift(self, record: Dict) -> List[Dict]:
+ data = self._decode_record(record)
+
+ parameters = []
+ for key, value in data.items():
+ if isinstance(value, str):
+ value = value.replace("\t", " ")
+ value = value.replace("\n", " ")
+ elif value is None:
+ value = "NULL"
+ else:
+ value = str(value)
+ parameters.append({"name": key, "value": value})
+ # required to work with execute_statement in community (moto) and ext (localstack native)
+
+ return parameters
+
+ def _extract_columns(self, record: Dict) -> str:
+ data = self._decode_record(record)
+ placeholders = [f":{key}" for key in data]
+ placeholder_str = ", ".join(placeholders)
+ return placeholder_str
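+
+ # Illustrative sketch (assumed record contents, not part of this change): for a decoded record
+ # {"id": 1, "name": "foo"}, _extract_columns yields ":id, :name", so _put_to_redshift builds
+ #   INSERT INTO <DataTableName> VALUES (:id, :name)
+ # and _prepare_records_for_redshift supplies Parameters=[{"name": "id", "value": "1"},
+ # {"name": "name", "value": "foo"}] to the Redshift Data API execute_statement call.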
diff --git a/localstack-core/localstack/services/iam/__init__.py b/localstack-core/localstack/services/iam/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/iam/iam_patches.py b/localstack-core/localstack/services/iam/iam_patches.py
new file mode 100644
index 0000000000000..5b672ac86059a
--- /dev/null
+++ b/localstack-core/localstack/services/iam/iam_patches.py
@@ -0,0 +1,153 @@
+import threading
+from typing import Dict, List, Optional
+
+from moto.iam.models import (
+ AccessKey,
+ AWSManagedPolicy,
+ IAMBackend,
+ InlinePolicy,
+ Policy,
+)
+from moto.iam.models import Role as MotoRole
+from moto.iam.policy_validation import VALID_STATEMENT_ELEMENTS
+
+from localstack import config
+from localstack.constants import TAG_KEY_CUSTOM_ID
+from localstack.utils.patch import patch
+
+ADDITIONAL_MANAGED_POLICIES = {
+ "AWSLambdaExecute": {
+ "Arn": "arn:aws:iam::aws:policy/AWSLambdaExecute",
+ "Path": "/",
+ "CreateDate": "2017-10-20T17:23:10+00:00",
+ "DefaultVersionId": "v4",
+ "Document": {
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": ["logs:*"],
+ "Resource": "arn:aws:logs:*:*:*",
+ },
+ {
+ "Effect": "Allow",
+ "Action": ["s3:GetObject", "s3:PutObject"],
+ "Resource": "arn:aws:s3:::*",
+ },
+ ],
+ },
+ "UpdateDate": "2019-05-20T18:22:18+00:00",
+ }
+}
+
+IAM_PATCHED = False
+IAM_PATCH_LOCK = threading.RLock()
+
+
+def apply_iam_patches():
+ global IAM_PATCHED
+
+ # prevent patching multiple times, as this is called from both STS and IAM (for now)
+ with IAM_PATCH_LOCK:
+ if IAM_PATCHED:
+ return
+
+ IAM_PATCHED = True
+
+ # support service linked roles
+ moto_role_og_arn_prop = MotoRole.arn
+
+ @property
+ def moto_role_arn(self):
+ return getattr(self, "service_linked_role_arn", None) or moto_role_og_arn_prop.__get__(self)
+
+ MotoRole.arn = moto_role_arn
+
+ # Add missing managed policies
+ # TODO this might not be necessary
+ @patch(IAMBackend._init_aws_policies)
+ def _init_aws_policies_extended(_init_aws_policies, self):
+ loaded_policies = _init_aws_policies(self)
+ loaded_policies.extend(
+ [
+ AWSManagedPolicy.from_data(name, self.account_id, self.region_name, d)
+ for name, d in ADDITIONAL_MANAGED_POLICIES.items()
+ ]
+ )
+ return loaded_policies
+
+ if "Principal" not in VALID_STATEMENT_ELEMENTS:
+ VALID_STATEMENT_ELEMENTS.append("Principal")
+
+ # patch policy __init__ to set document as attribute
+
+ @patch(Policy.__init__)
+ def policy__init__(
+ fn,
+ self,
+ name,
+ account_id,
+ region,
+ default_version_id=None,
+ description=None,
+ document=None,
+ **kwargs,
+ ):
+ fn(self, name, account_id, region, default_version_id, description, document, **kwargs)
+ self.document = document
+ if "tags" in kwargs and TAG_KEY_CUSTOM_ID in kwargs["tags"]:
+ self.id = kwargs["tags"][TAG_KEY_CUSTOM_ID]["Value"]
+
+ @patch(IAMBackend.create_role)
+ def iam_backend_create_role(
+ fn,
+ self,
+ role_name: str,
+ assume_role_policy_document: str,
+ path: str,
+ permissions_boundary: Optional[str],
+ description: str,
+ tags: List[Dict[str, str]],
+ max_session_duration: Optional[str],
+ linked_service: Optional[str] = None,
+ ):
+ role = fn(
+ self,
+ role_name,
+ assume_role_policy_document,
+ path,
+ permissions_boundary,
+ description,
+ tags,
+ max_session_duration,
+ linked_service,
+ )
+ new_id_tag = [tag for tag in (tags or []) if tag["Key"] == TAG_KEY_CUSTOM_ID]
+ if new_id_tag:
+ new_id = new_id_tag[0]["Value"]
+ old_id = role.id
+ role.id = new_id
+ self.roles[new_id] = self.roles.pop(old_id)
+ return role
+
+ @patch(InlinePolicy.unapply_policy)
+ def inline_policy_unapply_policy(fn, self, backend):
+ try:
+ fn(self, backend)
+ except Exception:
+ # A role can actually be deleted before its inline policy is deleted in CloudFormation
+ pass
+
+ @patch(AccessKey.__init__)
+ def access_key__init__(
+ fn,
+ self,
+ user_name: Optional[str],
+ prefix: str,
+ account_id: str,
+ status: str = "Active",
+ **kwargs,
+ ):
+ if not config.PARITY_AWS_ACCESS_KEY_ID:
+ prefix = "L" + prefix[1:]
+ fn(self, user_name, prefix, account_id, status, **kwargs)
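+
+ # Illustrative note (example key ids): with PARITY_AWS_ACCESS_KEY_ID disabled, a generated
+ # access key id such as "AKIA..." is rewritten above to "LKIA...", so that keys issued by
+ # LocalStack are distinguishable from real AWS access keys.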
diff --git a/localstack-core/localstack/services/iam/provider.py b/localstack-core/localstack/services/iam/provider.py
new file mode 100644
index 0000000000000..7adca335e82da
--- /dev/null
+++ b/localstack-core/localstack/services/iam/provider.py
@@ -0,0 +1,422 @@
+import json
+import re
+from datetime import datetime
+from typing import Dict, List
+from urllib.parse import quote
+
+from moto.iam.models import (
+ IAMBackend,
+ filter_items_with_path_prefix,
+ iam_backends,
+)
+from moto.iam.models import Role as MotoRole
+
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.iam import (
+ ActionNameListType,
+ ActionNameType,
+ AttachedPermissionsBoundary,
+ ContextEntryListType,
+ CreateRoleRequest,
+ CreateRoleResponse,
+ CreateServiceLinkedRoleResponse,
+ CreateUserResponse,
+ DeleteServiceLinkedRoleResponse,
+ DeletionTaskIdType,
+ DeletionTaskStatusType,
+ EvaluationResult,
+ GetServiceLinkedRoleDeletionStatusResponse,
+ GetUserResponse,
+ IamApi,
+ InvalidInputException,
+ ListInstanceProfileTagsResponse,
+ ListRolesResponse,
+ MalformedPolicyDocumentException,
+ NoSuchEntityException,
+ PolicyEvaluationDecisionType,
+ ResourceHandlingOptionType,
+ ResourceNameListType,
+ ResourceNameType,
+ Role,
+ SimulatePolicyResponse,
+ SimulationPolicyListType,
+ Tag,
+ User,
+ arnType,
+ customSuffixType,
+ existingUserNameType,
+ groupNameType,
+ instanceProfileNameType,
+ markerType,
+ maxItemsType,
+ pathPrefixType,
+ pathType,
+ policyDocumentType,
+ roleDescriptionType,
+ roleNameType,
+ tagKeyListType,
+ tagListType,
+ userNameType,
+)
+from localstack.aws.connect import connect_to
+from localstack.constants import INTERNAL_AWS_SECRET_ACCESS_KEY
+from localstack.services.iam.iam_patches import apply_iam_patches
+from localstack.services.moto import call_moto
+from localstack.utils.aws.request_context import extract_access_key_id_from_auth_header
+from localstack.utils.common import short_uid
+
+SERVICE_LINKED_ROLE_PATH_PREFIX = "/aws-service-role"
+
+
+POLICY_ARN_REGEX = re.compile(r"arn:[^:]+:iam::(?:\d{12}|aws):policy/.*")
+
+
+def get_iam_backend(context: RequestContext) -> IAMBackend:
+ return iam_backends[context.account_id][context.partition]
+
+
+class IamProvider(IamApi):
+ def __init__(self):
+ apply_iam_patches()
+
+ @handler("CreateRole", expand=False)
+ def create_role(
+ self, context: RequestContext, request: CreateRoleRequest
+ ) -> CreateRoleResponse:
+ try:
+ json.loads(request["AssumeRolePolicyDocument"])
+ except json.JSONDecodeError:
+ raise MalformedPolicyDocumentException("This policy contains invalid Json")
+ result = call_moto(context)
+
+ if not request.get("MaxSessionDuration") and result["Role"].get("MaxSessionDuration"):
+ result["Role"].pop("MaxSessionDuration")
+
+ if "RoleLastUsed" in result["Role"] and not result["Role"]["RoleLastUsed"]:
+ # not part of the AWS response if it's empty
+ # FIXME: RoleLastUsed did not seem well supported when this check was added
+ result["Role"].pop("RoleLastUsed")
+
+ return result
+
+ @staticmethod
+ def build_evaluation_result(
+ action_name: ActionNameType, resource_name: ResourceNameType, policy_statements: List[Dict]
+ ) -> EvaluationResult:
+ eval_res = EvaluationResult()
+ eval_res["EvalActionName"] = action_name
+ eval_res["EvalResourceName"] = resource_name
+ eval_res["EvalDecision"] = PolicyEvaluationDecisionType.explicitDeny
+ for statement in policy_statements:
+ # TODO Implement evaluation logic here
+ if (
+ action_name in statement["Action"]
+ and resource_name in statement["Resource"]
+ and statement["Effect"] == "Allow"
+ ):
+ eval_res["EvalDecision"] = PolicyEvaluationDecisionType.allowed
+ eval_res["MatchedStatements"] = [] # TODO: add support for statement compilation.
+ return eval_res
+
+ def simulate_principal_policy(
+ self,
+ context: RequestContext,
+ policy_source_arn: arnType,
+ action_names: ActionNameListType,
+ policy_input_list: SimulationPolicyListType = None,
+ permissions_boundary_policy_input_list: SimulationPolicyListType = None,
+ resource_arns: ResourceNameListType = None,
+ resource_policy: policyDocumentType = None,
+ resource_owner: ResourceNameType = None,
+ caller_arn: ResourceNameType = None,
+ context_entries: ContextEntryListType = None,
+ resource_handling_option: ResourceHandlingOptionType = None,
+ max_items: maxItemsType = None,
+ marker: markerType = None,
+ **kwargs,
+ ) -> SimulatePolicyResponse:
+ backend = get_iam_backend(context)
+ policy = backend.get_policy(policy_source_arn)
+ policy_version = backend.get_policy_version(policy_source_arn, policy.default_version_id)
+ try:
+ policy_statements = json.loads(policy_version.document).get("Statement", [])
+ except Exception:
+ raise NoSuchEntityException("Policy not found")
+
+ evaluations = [
+ self.build_evaluation_result(action_name, resource_arn, policy_statements)
+ for action_name in action_names
+ for resource_arn in resource_arns
+ ]
+
+ response = SimulatePolicyResponse()
+ response["IsTruncated"] = False
+ response["EvaluationResults"] = evaluations
+ return response
+
+ def delete_policy(self, context: RequestContext, policy_arn: arnType, **kwargs) -> None:
+ backend = get_iam_backend(context)
+ if backend.managed_policies.get(policy_arn):
+ backend.managed_policies.pop(policy_arn, None)
+ else:
+ raise NoSuchEntityException("Policy {0} was not found.".format(policy_arn))
+
+ def detach_role_policy(
+ self, context: RequestContext, role_name: roleNameType, policy_arn: arnType, **kwargs
+ ) -> None:
+ backend = get_iam_backend(context)
+ try:
+ role = backend.get_role(role_name)
+ policy = role.managed_policies[policy_arn]
+ policy.detach_from(role)
+ except KeyError:
+ raise NoSuchEntityException("Policy {0} was not found.".format(policy_arn))
+
+ @staticmethod
+ def moto_role_to_role_type(moto_role: MotoRole) -> Role:
+ role = Role()
+ role["Path"] = moto_role.path
+ role["RoleName"] = moto_role.name
+ role["RoleId"] = moto_role.id
+ role["Arn"] = moto_role.arn
+ role["CreateDate"] = moto_role.create_date
+ if moto_role.assume_role_policy_document:
+ role["AssumeRolePolicyDocument"] = moto_role.assume_role_policy_document
+ if moto_role.description:
+ role["Description"] = moto_role.description
+ if moto_role.max_session_duration:
+ role["MaxSessionDuration"] = moto_role.max_session_duration
+ if moto_role.permissions_boundary:
+ role["PermissionsBoundary"] = moto_role.permissions_boundary
+ if moto_role.tags:
+ role["Tags"] = [Tag(Key=k, Value=v) for k, v in moto_role.tags.items()]
+ # role["RoleLastUsed"]: # TODO: add support
+ return role
+
+ def list_roles(
+ self,
+ context: RequestContext,
+ path_prefix: pathPrefixType = None,
+ marker: markerType = None,
+ max_items: maxItemsType = None,
+ **kwargs,
+ ) -> ListRolesResponse:
+ backend = get_iam_backend(context)
+ moto_roles = backend.roles.values()
+ if path_prefix:
+ moto_roles = filter_items_with_path_prefix(path_prefix, moto_roles)
+ moto_roles = sorted(moto_roles, key=lambda role: role.id)
+
+ response_roles = []
+ for moto_role in moto_roles:
+ response_role = self.moto_role_to_role_type(moto_role)
+ # Permission boundary should not be a part of the response
+ response_role.pop("PermissionsBoundary", None)
+ response_roles.append(response_role)
+ if path_prefix: # TODO: this is consistent with the patch it migrates, but should add tests for this.
+ response_role["AssumeRolePolicyDocument"] = quote(
+ json.dumps(moto_role.assume_role_policy_document or {})
+ )
+
+ return ListRolesResponse(Roles=response_roles, IsTruncated=False)
+
+ def update_group(
+ self,
+ context: RequestContext,
+ group_name: groupNameType,
+ new_path: pathType = None,
+ new_group_name: groupNameType = None,
+ **kwargs,
+ ) -> None:
+ new_group_name = new_group_name or group_name
+ backend = get_iam_backend(context)
+ group = backend.get_group(group_name)
+ group.path = new_path
+ group.name = new_group_name
+ backend.groups[new_group_name] = backend.groups.pop(group_name)
+
+ def list_instance_profile_tags(
+ self,
+ context: RequestContext,
+ instance_profile_name: instanceProfileNameType,
+ marker: markerType = None,
+ max_items: maxItemsType = None,
+ **kwargs,
+ ) -> ListInstanceProfileTagsResponse:
+ backend = get_iam_backend(context)
+ profile = backend.get_instance_profile(instance_profile_name)
+ response = ListInstanceProfileTagsResponse()
+ response["Tags"] = [Tag(Key=k, Value=v) for k, v in profile.tags.items()]
+ return response
+
+ def tag_instance_profile(
+ self,
+ context: RequestContext,
+ instance_profile_name: instanceProfileNameType,
+ tags: tagListType,
+ **kwargs,
+ ) -> None:
+ backend = get_iam_backend(context)
+ profile = backend.get_instance_profile(instance_profile_name)
+ value_by_key = {tag["Key"]: tag["Value"] for tag in tags}
+ profile.tags.update(value_by_key)
+
+ def untag_instance_profile(
+ self,
+ context: RequestContext,
+ instance_profile_name: instanceProfileNameType,
+ tag_keys: tagKeyListType,
+ **kwargs,
+ ) -> None:
+ backend = get_iam_backend(context)
+ profile = backend.get_instance_profile(instance_profile_name)
+ for tag in tag_keys:
+ profile.tags.pop(tag, None)
+
+ def create_service_linked_role(
+ self,
+ context: RequestContext,
+ aws_service_name: groupNameType,
+ description: roleDescriptionType = None,
+ custom_suffix: customSuffixType = None,
+ **kwargs,
+ ) -> CreateServiceLinkedRoleResponse:
+ # TODO: test
+ # TODO: how to support "CustomSuffix" API request parameter?
+ policy_doc = json.dumps(
+ {
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Principal": {"Service": aws_service_name},
+ "Action": "sts:AssumeRole",
+ }
+ ],
+ }
+ )
+ path = f"{SERVICE_LINKED_ROLE_PATH_PREFIX}/{aws_service_name}"
+ role_name = f"r-{short_uid()}"
+ backend = get_iam_backend(context)
+ role = backend.create_role(
+ role_name=role_name,
+ assume_role_policy_document=policy_doc,
+ path=path,
+ permissions_boundary="",
+ description=description,
+ tags={},
+ max_session_duration=3600,
+ )
+ role.service_linked_role_arn = "arn:{0}:iam::{1}:role/aws-service-role/{2}/{3}".format(
+ context.partition, context.account_id, aws_service_name, role.name
+ )
+
+ res_role = self.moto_role_to_role_type(role)
+ return CreateServiceLinkedRoleResponse(Role=res_role)
+
+ def delete_service_linked_role(
+ self, context: RequestContext, role_name: roleNameType, **kwargs
+ ) -> DeleteServiceLinkedRoleResponse:
+ # TODO: test
+ backend = get_iam_backend(context)
+ backend.delete_role(role_name)
+ return DeleteServiceLinkedRoleResponse(DeletionTaskId=short_uid())
+
+ def get_service_linked_role_deletion_status(
+ self, context: RequestContext, deletion_task_id: DeletionTaskIdType, **kwargs
+ ) -> GetServiceLinkedRoleDeletionStatusResponse:
+ # TODO: test
+ return GetServiceLinkedRoleDeletionStatusResponse(Status=DeletionTaskStatusType.SUCCEEDED)
+
+ def put_user_permissions_boundary(
+ self,
+ context: RequestContext,
+ user_name: userNameType,
+ permissions_boundary: arnType,
+ **kwargs,
+ ) -> None:
+ if user := get_iam_backend(context).users.get(user_name):
+ user.permissions_boundary = permissions_boundary
+ else:
+ raise NoSuchEntityException()
+
+ def delete_user_permissions_boundary(
+ self, context: RequestContext, user_name: userNameType, **kwargs
+ ) -> None:
+ if user := get_iam_backend(context).users.get(user_name):
+ if hasattr(user, "permissions_boundary"):
+ delattr(user, "permissions_boundary")
+ else:
+ raise NoSuchEntityException()
+
+ def create_user(
+ self,
+ context: RequestContext,
+ user_name: userNameType,
+ path: pathType = None,
+ permissions_boundary: arnType = None,
+ tags: tagListType = None,
+ **kwargs,
+ ) -> CreateUserResponse:
+ response = call_moto(context=context)
+ user = get_iam_backend(context).get_user(user_name)
+ if permissions_boundary:
+ user.permissions_boundary = permissions_boundary
+ response["User"]["PermissionsBoundary"] = AttachedPermissionsBoundary(
+ PermissionsBoundaryArn=permissions_boundary,
+ PermissionsBoundaryType="Policy",
+ )
+ return response
+
+ def get_user(
+ self, context: RequestContext, user_name: existingUserNameType = None, **kwargs
+ ) -> GetUserResponse:
+ response = call_moto(context=context)
+ moto_user_name = response["User"]["UserName"]
+ moto_user = get_iam_backend(context).users.get(moto_user_name)
+ # if the user does not exist and no user name was given, fall back to the caller identity
+ if not moto_user and not user_name:
+ access_key_id = extract_access_key_id_from_auth_header(context.request.headers)
+ sts_client = connect_to(
+ region_name=context.region,
+ aws_access_key_id=access_key_id,
+ aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
+ ).sts
+ caller_identity = sts_client.get_caller_identity()
+ caller_arn = caller_identity["Arn"]
+ if caller_arn.endswith(":root"):
+ return GetUserResponse(
+ User=User(
+ UserId=context.account_id,
+ Arn=caller_arn,
+ CreateDate=datetime.now(),
+ PasswordLastUsed=datetime.now(),
+ )
+ )
+ else:
+ raise CommonServiceException(
+ "ValidationError",
+ "Must specify userName when calling with non-User credentials",
+ )
+
+ if hasattr(moto_user, "permissions_boundary") and moto_user.permissions_boundary:
+ response["User"]["PermissionsBoundary"] = AttachedPermissionsBoundary(
+ PermissionsBoundaryArn=moto_user.permissions_boundary,
+ PermissionsBoundaryType="Policy",
+ )
+
+ return response
+
+ def attach_role_policy(
+ self, context: RequestContext, role_name: roleNameType, policy_arn: arnType, **kwargs
+ ) -> None:
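+ # validate the policy ARN up front so malformed ARNs are rejected before delegating to moto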
+ if not POLICY_ARN_REGEX.match(policy_arn):
+ raise InvalidInputException(f"ARN {policy_arn} is not valid.")
+ return call_moto(context=context)
+
+ def attach_user_policy(
+ self, context: RequestContext, user_name: userNameType, policy_arn: arnType, **kwargs
+ ) -> None:
+ if not POLICY_ARN_REGEX.match(policy_arn):
+ raise InvalidInputException(f"ARN {policy_arn} is not valid.")
+ return call_moto(context=context)
diff --git a/localstack-core/localstack/services/iam/resource_providers/__init__.py b/localstack-core/localstack/services/iam/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.py
new file mode 100644
index 0000000000000..a945e5af67a47
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.py
@@ -0,0 +1,116 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMAccessKeyProperties(TypedDict):
+ UserName: Optional[str]
+ Id: Optional[str]
+ SecretAccessKey: Optional[str]
+ Serial: Optional[int]
+ Status: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMAccessKeyProvider(ResourceProvider[IAMAccessKeyProperties]):
+ TYPE = "AWS::IAM::AccessKey" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMAccessKeyProperties],
+ ) -> ProgressEvent[IAMAccessKeyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - UserName
+
+ Create-only properties:
+ - /properties/UserName
+ - /properties/Serial
+
+ Read-only properties:
+ - /properties/SecretAccessKey
+ - /properties/Id
+
+ """
+ # TODO: what values can model['Serial'] take on initial create?
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+
+ access_key = iam_client.create_access_key(UserName=model["UserName"])
+ model["SecretAccessKey"] = access_key["AccessKey"]["SecretAccessKey"]
+ model["Id"] = access_key["AccessKey"]["AccessKeyId"]
+
+ if model.get("Status") == "Inactive":
+ # can be "Active" or "Inactive"
+ # by default the created access key has Status "Active", but if user set Inactive this needs to be adjusted
+ iam_client.update_access_key(
+ AccessKeyId=model["Id"], UserName=model["UserName"], Status=model["Status"]
+ )
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[IAMAccessKeyProperties],
+ ) -> ProgressEvent[IAMAccessKeyProperties]:
+ """
+ Fetch resource information
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMAccessKeyProperties],
+ ) -> ProgressEvent[IAMAccessKeyProperties]:
+ """
+ Delete a resource
+ """
+ iam_client = request.aws_client_factory.iam
+ model = request.previous_state
+ iam_client.delete_access_key(AccessKeyId=model["Id"], UserName=model["UserName"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[IAMAccessKeyProperties],
+ ) -> ProgressEvent[IAMAccessKeyProperties]:
+ """
+ Update a resource
+ """
+ iam_client = request.aws_client_factory.iam
+
+ # FIXME: replacement should be handled in engine before here
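+ # UserName and Serial are create-only properties, so a change to either requires replacing the key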
+ user_name_changed = request.desired_state["UserName"] != request.previous_state["UserName"]
+ serial_changed = request.desired_state.get("Serial") != request.previous_state.get("Serial")
+ if user_name_changed or serial_changed:
+ # recreate the key
+ self.delete(request)
+ create_event = self.create(request)
+ return create_event
+
+ iam_client.update_access_key(
+ AccessKeyId=request.previous_state["Id"],
+ UserName=request.previous_state["UserName"],
+ Status=request.desired_state["Status"],
+ )
+ old_model = request.previous_state
+ old_model["Status"] = request.desired_state["Status"]
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=old_model)
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.schema.json
new file mode 100644
index 0000000000000..4925db7a9d608
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey.schema.json
@@ -0,0 +1,36 @@
+{
+ "typeName": "AWS::IAM::AccessKey",
+ "description": "Resource Type definition for AWS::IAM::AccessKey",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "SecretAccessKey": {
+ "type": "string"
+ },
+ "Serial": {
+ "type": "integer"
+ },
+ "Status": {
+ "type": "string"
+ },
+ "UserName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "UserName"
+ ],
+ "readOnlyProperties": [
+ "/properties/SecretAccessKey",
+ "/properties/Id"
+ ],
+ "createOnlyProperties": [
+ "/properties/UserName",
+ "/properties/Serial"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey_plugin.py
new file mode 100644
index 0000000000000..a54ee6f94b3db
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_accesskey_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMAccessKeyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::AccessKey"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_accesskey import (
+ IAMAccessKeyProvider,
+ )
+
+ self.factory = IAMAccessKeyProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.py
new file mode 100644
index 0000000000000..69c2b15ab1bfe
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.py
@@ -0,0 +1,152 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMGroupProperties(TypedDict):
+ Arn: Optional[str]
+ GroupName: Optional[str]
+ Id: Optional[str]
+ ManagedPolicyArns: Optional[list[str]]
+ Path: Optional[str]
+ Policies: Optional[list[Policy]]
+
+
+class Policy(TypedDict):
+ PolicyDocument: Optional[dict]
+ PolicyName: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMGroupProvider(ResourceProvider[IAMGroupProperties]):
+ TYPE = "AWS::IAM::Group" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMGroupProperties],
+ ) -> ProgressEvent[IAMGroupProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Create-only properties:
+ - /properties/GroupName
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/Id
+ """
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+
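+ # fall back to a name derived from the stack name and logical resource id when GroupName is omitted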
+ group_name = model.get("GroupName")
+ if not group_name:
+ group_name = util.generate_default_name(request.stack_name, request.logical_resource_id)
+ model["GroupName"] = group_name
+
+ create_group_result = iam_client.create_group(
+ **util.select_attributes(model, ["GroupName", "Path"])
+ )
+ model["Id"] = create_group_result["Group"][
+ "GroupName"
+ ] # a bit weird that this is not the GroupId
+ model["Arn"] = create_group_result["Group"]["Arn"]
+
+ for managed_policy in model.get("ManagedPolicyArns", []):
+ iam_client.attach_group_policy(GroupName=group_name, PolicyArn=managed_policy)
+
+ for inline_policy in model.get("Policies", []):
+ doc = json.dumps(inline_policy.get("PolicyDocument"))
+ iam_client.put_group_policy(
+ GroupName=group_name,
+ PolicyName=inline_policy.get("PolicyName"),
+ PolicyDocument=doc,
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[IAMGroupProperties],
+ ) -> ProgressEvent[IAMGroupProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMGroupProperties],
+ ) -> ProgressEvent[IAMGroupProperties]:
+ """
+ Delete a resource
+ """
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+
+ # first we need to detach and delete any attached policies
+ for managed_policy in model.get("ManagedPolicyArns", []):
+ iam_client.detach_group_policy(GroupName=model["GroupName"], PolicyArn=managed_policy)
+
+ for inline_policy in model.get("Policies", []):
+ iam_client.delete_group_policy(
+ GroupName=model["GroupName"],
+ PolicyName=inline_policy.get("PolicyName"),
+ )
+
+ # now we can delete the actual group
+ iam_client.delete_group(GroupName=model["GroupName"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[IAMGroupProperties],
+ ) -> ProgressEvent[IAMGroupProperties]:
+ """
+ Update a resource
+ """
+ # TODO: note: while the resource implemented "update_resource" previously, it didn't actually work
+ # so leaving it out here for now
+ # iam.update_group(
+ # GroupName=props.get("GroupName"),
+ # NewPath=props.get("NewPath") or "",
+ # NewGroupName=props.get("NewGroupName") or "",
+ # )
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[IAMGroupProperties],
+ ) -> ProgressEvent[IAMGroupProperties]:
+ resources = request.aws_client_factory.iam.list_groups()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ IAMGroupProperties(Id=resource["GroupName"]) for resource in resources["Groups"]
+ ],
+ )
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.schema.json
new file mode 100644
index 0000000000000..e31b0e5594b3f
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group.schema.json
@@ -0,0 +1,61 @@
+{
+ "typeName": "AWS::IAM::Group",
+ "description": "Resource Type definition for AWS::IAM::Group",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "GroupName": {
+ "type": "string"
+ },
+ "ManagedPolicyArns": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Path": {
+ "type": "string"
+ },
+ "Policies": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/Policy"
+ }
+ }
+ },
+ "definitions": {
+ "Policy": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PolicyDocument": {
+ "type": "object"
+ },
+ "PolicyName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "PolicyDocument",
+ "PolicyName"
+ ]
+ }
+ },
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/Id"
+ ],
+ "createOnlyProperties": [
+ "/properties/GroupName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_group_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group_plugin.py
new file mode 100644
index 0000000000000..24af55af719b1
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_group_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMGroupProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::Group"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_group import IAMGroupProvider
+
+ self.factory = IAMGroupProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.py
new file mode 100644
index 0000000000000..b65f5f079d0ff
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.py
@@ -0,0 +1,136 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMInstanceProfileProperties(TypedDict):
+ Roles: Optional[list[str]]
+ Arn: Optional[str]
+ InstanceProfileName: Optional[str]
+ Path: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMInstanceProfileProvider(ResourceProvider[IAMInstanceProfileProperties]):
+ TYPE = "AWS::IAM::InstanceProfile" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMInstanceProfileProperties],
+ ) -> ProgressEvent[IAMInstanceProfileProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/InstanceProfileName
+
+ Required properties:
+ - Roles
+
+ Create-only properties:
+ - /properties/InstanceProfileName
+ - /properties/Path
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - iam:CreateInstanceProfile
+ - iam:PassRole
+ - iam:AddRoleToInstanceProfile
+ - iam:GetInstanceProfile
+
+ """
+ model = request.desired_state
+ iam = request.aws_client_factory.iam
+
+ # defaults
+ role_name = model.get("InstanceProfileName")
+ if not role_name:
+ role_name = util.generate_default_name(request.stack_name, request.logical_resource_id)
+ model["InstanceProfileName"] = role_name
+
+ response = iam.create_instance_profile(
+ **util.select_attributes(
+ model,
+ [
+ "InstanceProfileName",
+ "Path",
+ ],
+ ),
+ )
+ for role_name in model.get("Roles", []):
+ iam.add_role_to_instance_profile(
+ InstanceProfileName=model["InstanceProfileName"], RoleName=role_name
+ )
+ model["Arn"] = response["InstanceProfile"]["Arn"]
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[IAMInstanceProfileProperties],
+ ) -> ProgressEvent[IAMInstanceProfileProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - iam:GetInstanceProfile
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMInstanceProfileProperties],
+ ) -> ProgressEvent[IAMInstanceProfileProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - iam:GetInstanceProfile
+ - iam:RemoveRoleFromInstanceProfile
+ - iam:DeleteInstanceProfile
+ """
+ iam = request.aws_client_factory.iam
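+ # all roles have to be removed from the instance profile before IAM allows deleting it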
+ instance_profile = iam.get_instance_profile(
+ InstanceProfileName=request.previous_state["InstanceProfileName"]
+ )
+ for role in instance_profile["InstanceProfile"]["Roles"]:
+ iam.remove_role_from_instance_profile(
+ InstanceProfileName=request.previous_state["InstanceProfileName"],
+ RoleName=role["RoleName"],
+ )
+ iam.delete_instance_profile(
+ InstanceProfileName=request.previous_state["InstanceProfileName"]
+ )
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[IAMInstanceProfileProperties],
+ ) -> ProgressEvent[IAMInstanceProfileProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - iam:PassRole
+ - iam:RemoveRoleFromInstanceProfile
+ - iam:AddRoleToInstanceProfile
+ - iam:GetInstanceProfile
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.schema.json
new file mode 100644
index 0000000000000..f04a6751c1691
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile.schema.json
@@ -0,0 +1,77 @@
+{
+ "typeName": "AWS::IAM::InstanceProfile",
+ "description": "Resource Type definition for AWS::IAM::InstanceProfile",
+ "additionalProperties": false,
+ "properties": {
+ "Path": {
+ "type": "string",
+ "description": "The path to the instance profile."
+ },
+ "Roles": {
+ "type": "array",
+ "description": "The name of the role to associate with the instance profile. Only one role can be assigned to an EC2 instance at a time, and all applications on the instance share the same role and permissions.",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "InstanceProfileName": {
+ "type": "string",
+ "description": "The name of the instance profile to create."
+ },
+ "Arn": {
+ "type": "string",
+ "description": "The Amazon Resource Name (ARN) of the instance profile."
+ }
+ },
+ "taggable": false,
+ "required": [
+ "Roles"
+ ],
+ "createOnlyProperties": [
+ "/properties/InstanceProfileName",
+ "/properties/Path"
+ ],
+ "primaryIdentifier": [
+ "/properties/InstanceProfileName"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "iam:CreateInstanceProfile",
+ "iam:PassRole",
+ "iam:AddRoleToInstanceProfile",
+ "iam:GetInstanceProfile"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "iam:GetInstanceProfile"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "iam:PassRole",
+ "iam:RemoveRoleFromInstanceProfile",
+ "iam:AddRoleToInstanceProfile",
+ "iam:GetInstanceProfile"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "iam:GetInstanceProfile",
+ "iam:RemoveRoleFromInstanceProfile",
+ "iam:DeleteInstanceProfile"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "iam:ListInstanceProfiles"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile_plugin.py
new file mode 100644
index 0000000000000..875b729a55323
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_instanceprofile_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMInstanceProfileProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::InstanceProfile"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_instanceprofile import (
+ IAMInstanceProfileProvider,
+ )
+
+ self.factory = IAMInstanceProfileProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.py
new file mode 100644
index 0000000000000..0bca0e5a02169
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.py
@@ -0,0 +1,117 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMManagedPolicyProperties(TypedDict):
+ PolicyDocument: Optional[dict]
+ Description: Optional[str]
+ Groups: Optional[list[str]]
+ Id: Optional[str]
+ ManagedPolicyName: Optional[str]
+ Path: Optional[str]
+ Roles: Optional[list[str]]
+ Users: Optional[list[str]]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMManagedPolicyProvider(ResourceProvider[IAMManagedPolicyProperties]):
+ TYPE = "AWS::IAM::ManagedPolicy" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMManagedPolicyProperties],
+ ) -> ProgressEvent[IAMManagedPolicyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - PolicyDocument
+
+ Create-only properties:
+ - /properties/ManagedPolicyName
+ - /properties/Description
+ - /properties/Path
+
+ Read-only properties:
+ - /properties/Id
+
+ """
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+ group_name = model.get("ManagedPolicyName")
+ if not group_name:
+ group_name = util.generate_default_name(request.stack_name, request.logical_resource_id)
+ model["ManagedPolicyName"] = group_name
+
+ policy_doc = json.dumps(util.remove_none_values(model["PolicyDocument"]))
+ policy = iam_client.create_policy(
+ PolicyName=model["ManagedPolicyName"], PolicyDocument=policy_doc
+ )
+ model["Id"] = policy["Policy"]["Arn"]
+ policy_arn = policy["Policy"]["Arn"]
+ for role in model.get("Roles", []):
+ iam_client.attach_role_policy(RoleName=role, PolicyArn=policy_arn)
+ for user in model.get("Users", []):
+ iam_client.attach_user_policy(UserName=user, PolicyArn=policy_arn)
+ for group in model.get("Groups", []):
+ iam_client.attach_group_policy(GroupName=group, PolicyArn=policy_arn)
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[IAMManagedPolicyProperties],
+ ) -> ProgressEvent[IAMManagedPolicyProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMManagedPolicyProperties],
+ ) -> ProgressEvent[IAMManagedPolicyProperties]:
+ """
+ Delete a resource
+ """
+ iam_client = request.aws_client_factory.iam
+ model = request.previous_state
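+ # a managed policy has to be detached from all roles, users, and groups before it can be deleted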
+
+ for role in model.get("Roles", []):
+ iam_client.detach_role_policy(RoleName=role, PolicyArn=model["Id"])
+ for user in model.get("Users", []):
+ iam_client.detach_user_policy(UserName=user, PolicyArn=model["Id"])
+ for group in model.get("Groups", []):
+ iam_client.detach_group_policy(GroupName=group, PolicyArn=model["Id"])
+
+ iam_client.delete_policy(PolicyArn=model["Id"])
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def update(
+ self,
+ request: ResourceRequest[IAMManagedPolicyProperties],
+ ) -> ProgressEvent[IAMManagedPolicyProperties]:
+ """
+ Update a resource
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.schema.json
new file mode 100644
index 0000000000000..da6d25ca321bf
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy.schema.json
@@ -0,0 +1,57 @@
+{
+ "typeName": "AWS::IAM::ManagedPolicy",
+ "description": "Resource Type definition for AWS::IAM::ManagedPolicy",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "Groups": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "ManagedPolicyName": {
+ "type": "string"
+ },
+ "Path": {
+ "type": "string"
+ },
+ "PolicyDocument": {
+ "type": "object"
+ },
+ "Roles": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Users": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "required": [
+ "PolicyDocument"
+ ],
+ "createOnlyProperties": [
+ "/properties/ManagedPolicyName",
+ "/properties/Description",
+ "/properties/Path"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy_plugin.py
new file mode 100644
index 0000000000000..d33ce61ef26b5
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_managedpolicy_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMManagedPolicyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::ManagedPolicy"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_managedpolicy import (
+ IAMManagedPolicyProvider,
+ )
+
+ self.factory = IAMManagedPolicyProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.py
new file mode 100644
index 0000000000000..97fdb19341b57
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.py
@@ -0,0 +1,143 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+import random
+import string
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMPolicyProperties(TypedDict):
+ PolicyDocument: Optional[dict]
+ PolicyName: Optional[str]
+ Groups: Optional[list[str]]
+ Id: Optional[str]
+ Roles: Optional[list[str]]
+ Users: Optional[list[str]]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMPolicyProvider(ResourceProvider[IAMPolicyProperties]):
+ TYPE = "AWS::IAM::Policy" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMPolicyProperties],
+ ) -> ProgressEvent[IAMPolicyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - PolicyDocument
+ - PolicyName
+
+ Read-only properties:
+ - /properties/Id
+
+ """
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+
+ policy_doc = json.dumps(util.remove_none_values(model["PolicyDocument"]))
+ policy_name = model["PolicyName"]
+
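+ # an inline AWS::IAM::Policy must be attached to at least one group, role, or user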
+ if not any([model.get("Roles"), model.get("Users"), model.get("Groups")]):
+ return ProgressEvent(
+ status=OperationStatus.FAILED,
+ resource_model={},
+ error_code="InvalidRequest",
+ message="At least one of [Groups,Roles,Users] must be non-empty.",
+ )
+
+ for role in model.get("Roles", []):
+ iam_client.put_role_policy(
+ RoleName=role, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+ for user in model.get("Users", []):
+ iam_client.put_user_policy(
+ UserName=user, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+ for group in model.get("Groups", []):
+ iam_client.put_group_policy(
+ GroupName=group, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+
+ # the physical resource ID here has a bit of a weird format
+ # e.g. 'stack-fnSe-1OKWZIBB89193' where fnSe are the first 4 characters of the LogicalResourceId (or name?)
+ suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=13))
+ model["Id"] = f"stack-{model.get('PolicyName', '')[:4]}-{suffix}"
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[IAMPolicyProperties],
+ ) -> ProgressEvent[IAMPolicyProperties]:
+ """
+ Fetch resource information
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMPolicyProperties],
+ ) -> ProgressEvent[IAMPolicyProperties]:
+ """
+ Delete a resource
+ """
+ iam = request.aws_client_factory.iam
+
+ model = request.previous_state
+ policy_name = request.previous_state["PolicyName"]
+ for role in model.get("Roles", []):
+ iam.delete_role_policy(RoleName=role, PolicyName=policy_name)
+ for user in model.get("Users", []):
+ iam.delete_user_policy(UserName=user, PolicyName=policy_name)
+ for group in model.get("Groups", []):
+ iam.delete_group_policy(GroupName=group, PolicyName=policy_name)
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[IAMPolicyProperties],
+ ) -> ProgressEvent[IAMPolicyProperties]:
+ """
+ Update a resource
+ """
+ iam_client = request.aws_client_factory.iam
+ model = request.desired_state
+ # FIXME: this wasn't properly implemented before as well, still needs to be rewritten
+ policy_doc = json.dumps(util.remove_none_values(model["PolicyDocument"]))
+ policy_name = model["PolicyName"]
+
+ for role in model.get("Roles", []):
+ iam_client.put_role_policy(
+ RoleName=role, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+ for user in model.get("Users", []):
+ iam_client.put_user_policy(
+ UserName=user, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+ for group in model.get("Groups", []):
+ iam_client.put_group_policy(
+ GroupName=group, PolicyName=policy_name, PolicyDocument=policy_doc
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={**request.previous_state, **request.desired_state},
+ )
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.schema.json
new file mode 100644
index 0000000000000..1b6a5fb438e4b
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy.schema.json
@@ -0,0 +1,47 @@
+{
+ "typeName": "AWS::IAM::Policy",
+ "description": "Resource Type definition for AWS::IAM::Policy",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "Groups": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "PolicyDocument": {
+ "type": "object"
+ },
+ "PolicyName": {
+ "type": "string"
+ },
+ "Roles": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Users": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "required": [
+ "PolicyDocument",
+ "PolicyName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy_plugin.py
new file mode 100644
index 0000000000000..a3fdd7e9c9dc3
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_policy_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMPolicyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::Policy"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_policy import IAMPolicyProvider
+
+ self.factory = IAMPolicyProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.py
new file mode 100644
index 0000000000000..3a3cb8aa63466
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.py
@@ -0,0 +1,245 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+from localstack.utils.functions import call_safe
+
+
+class IAMRoleProperties(TypedDict):
+ AssumeRolePolicyDocument: Optional[dict | str]
+ Arn: Optional[str]
+ Description: Optional[str]
+ ManagedPolicyArns: Optional[list[str]]
+ MaxSessionDuration: Optional[int]
+ Path: Optional[str]
+ PermissionsBoundary: Optional[str]
+ Policies: Optional[list[Policy]]
+ RoleId: Optional[str]
+ RoleName: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Policy(TypedDict):
+ PolicyDocument: Optional[str | dict]
+ PolicyName: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+IAM_POLICY_VERSION = "2012-10-17"
+
+
+class IAMRoleProvider(ResourceProvider[IAMRoleProperties]):
+ TYPE = "AWS::IAM::Role" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMRoleProperties],
+ ) -> ProgressEvent[IAMRoleProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/RoleName
+
+ Required properties:
+ - AssumeRolePolicyDocument
+
+ Create-only properties:
+ - /properties/Path
+ - /properties/RoleName
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/RoleId
+
+ IAM permissions required:
+ - iam:CreateRole
+ - iam:PutRolePolicy
+ - iam:AttachRolePolicy
+ - iam:GetRolePolicy <- not in use right now
+
+ """
+ model = request.desired_state
+ iam = request.aws_client_factory.iam
+
+ # defaults
+ role_name = model.get("RoleName")
+ if not role_name:
+ role_name = util.generate_default_name(request.stack_name, request.logical_resource_id)
+ model["RoleName"] = role_name
+
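+ # the trust policy may arrive as a dict from the template, while the IAM API expects a JSON
+ # string, so it is excluded from the pass-through kwargs and serialized explicitly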
+ create_role_response = iam.create_role(
+ **{
+ k: v
+ for k, v in model.items()
+ if k not in ["ManagedPolicyArns", "Policies", "AssumeRolePolicyDocument"]
+ },
+ AssumeRolePolicyDocument=json.dumps(model["AssumeRolePolicyDocument"]),
+ )
+
+ # attach managed policies
+ policy_arns = model.get("ManagedPolicyArns", [])
+ for arn in policy_arns:
+ iam.attach_role_policy(RoleName=role_name, PolicyArn=arn)
+
+ # add inline policies
+ inline_policies = model.get("Policies", [])
+ for policy in inline_policies:
+ if not isinstance(policy, dict):
+ request.logger.info(
+ 'Invalid format of policy for IAM role "%s": %s',
+ model.get("RoleName"),
+ policy,
+ )
+ continue
+ pol_name = policy.get("PolicyName")
+
+ # get policy document - make sure we're resolving references in the policy doc
+ doc = dict(policy["PolicyDocument"])
+ doc = util.remove_none_values(doc)
+
+ doc["Version"] = doc.get("Version") or IAM_POLICY_VERSION
+ statements = doc["Statement"]
+ statements = statements if isinstance(statements, list) else [statements]
+ for statement in statements:
+ if isinstance(statement.get("Resource"), list):
+ # filter out empty resource strings
+ statement["Resource"] = [r for r in statement["Resource"] if r]
+ doc = json.dumps(doc)
+ iam.put_role_policy(
+ RoleName=model["RoleName"],
+ PolicyName=pol_name,
+ PolicyDocument=doc,
+ )
+ model["Arn"] = create_role_response["Role"]["Arn"]
+ model["RoleId"] = create_role_response["Role"]["RoleId"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[IAMRoleProperties],
+ ) -> ProgressEvent[IAMRoleProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - iam:GetRole
+ - iam:ListAttachedRolePolicies
+ - iam:ListRolePolicies
+ - iam:GetRolePolicy
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMRoleProperties],
+ ) -> ProgressEvent[IAMRoleProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - iam:DeleteRole
+ - iam:DetachRolePolicy
+ - iam:DeleteRolePolicy
+ - iam:GetRole
+ - iam:ListAttachedRolePolicies
+ - iam:ListRolePolicies
+ """
+ iam_client = request.aws_client_factory.iam
+ role_name = request.previous_state["RoleName"]
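+ # IAM refuses to delete a role that still has attached or inline policies, so clean those up first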
+
+ # detach managed policies
+ for policy in iam_client.list_attached_role_policies(RoleName=role_name).get(
+ "AttachedPolicies", []
+ ):
+ call_safe(
+ iam_client.detach_role_policy,
+ kwargs={"RoleName": role_name, "PolicyArn": policy["PolicyArn"]},
+ )
+
+ # delete inline policies
+ for inline_policy_name in iam_client.list_role_policies(RoleName=role_name).get(
+ "PolicyNames", []
+ ):
+ call_safe(
+ iam_client.delete_role_policy,
+ kwargs={"RoleName": role_name, "PolicyName": inline_policy_name},
+ )
+
+ iam_client.delete_role(RoleName=role_name)
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model={})
+
+ def update(
+ self,
+ request: ResourceRequest[IAMRoleProperties],
+ ) -> ProgressEvent[IAMRoleProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - iam:UpdateRole
+ - iam:UpdateRoleDescription
+ - iam:UpdateAssumeRolePolicy
+ - iam:DetachRolePolicy
+ - iam:AttachRolePolicy
+ - iam:DeleteRolePermissionsBoundary
+ - iam:PutRolePermissionsBoundary
+ - iam:DeleteRolePolicy
+ - iam:PutRolePolicy
+ - iam:TagRole
+ - iam:UntagRole
+ """
+ props = request.desired_state
+ _states = request.previous_state
+
+ # note that we're using permissions that are not technically allowed here due to the currently broken change detection
+ props_policy = props.get("AssumeRolePolicyDocument")
+ # technically a RoleName change implies a replacement rather than an in-place update and
+ # shouldn't reach this handler; for now it is handled here by recreating the role
+ new_role_name = props.get("RoleName")
+ name_changed = new_role_name and new_role_name != _states["RoleName"]
+
+ # new_role_name = props.get("RoleName", _states.get("RoleName"))
+ policy_changed = props_policy and props_policy != _states.get(
+ "AssumeRolePolicyDocument", ""
+ )
+ managed_policy_arns_changed = props.get("ManagedPolicyArns", []) != _states.get(
+ "ManagedPolicyArns", []
+ )
+ if name_changed or policy_changed or managed_policy_arns_changed:
+ # TODO: do a proper update instead of replacement
+ self.delete(request)
+ return self.create(request)
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=request.previous_state)
+ # raise Exception("why was a change even detected?")
+
+ def list(
+ self,
+ request: ResourceRequest[IAMRoleProperties],
+ ) -> ProgressEvent[IAMRoleProperties]:
+ resources = request.aws_client_factory.iam.list_roles()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ IAMRoleProperties(RoleName=resource["RoleName"]) for resource in resources["Roles"]
+ ],
+ )
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.schema.json
new file mode 100644
index 0000000000000..a7b8a4489cc59
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role.schema.json
@@ -0,0 +1,183 @@
+{
+ "typeName": "AWS::IAM::Role",
+ "$schema": "https://raw.githubusercontent.com/aws-cloudformation/cloudformation-resource-schema/master/src/main/resources/schema/provider.definition.schema.v1.json",
+ "description": "Resource Type definition for AWS::IAM::Role",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-iam.git",
+ "definitions": {
+ "Policy": {
+ "description": "The inline policy document that is embedded in the specified IAM role.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PolicyDocument": {
+ "description": "The policy document.",
+ "type": [
+ "string",
+ "object"
+ ]
+ },
+ "PolicyName": {
+ "description": "The friendly name (not ARN) identifying the policy.",
+ "type": "string"
+ }
+ },
+ "required": [
+ "PolicyName",
+ "PolicyDocument"
+ ]
+ },
+ "Tag": {
+ "description": "A key-value pair to associate with a resource.",
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -."
+ },
+ "Value": {
+ "type": "string",
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -."
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ],
+ "additionalProperties": false
+ }
+ },
+ "properties": {
+ "Arn": {
+ "description": "The Amazon Resource Name (ARN) for the role.",
+ "type": "string"
+ },
+ "AssumeRolePolicyDocument": {
+ "description": "The trust policy that is associated with this role.",
+ "type": [
+ "object",
+ "string"
+ ]
+ },
+ "Description": {
+ "description": "A description of the role that you provide.",
+ "type": "string"
+ },
+ "ManagedPolicyArns": {
+ "description": "A list of Amazon Resource Names (ARNs) of the IAM managed policies that you want to attach to the role. ",
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "MaxSessionDuration": {
+ "description": "The maximum session duration (in seconds) that you want to set for the specified role. If you do not specify a value for this setting, the default maximum of one hour is applied. This setting can have a value from 1 hour to 12 hours. ",
+ "type": "integer"
+ },
+ "Path": {
+ "description": "The path to the role.",
+ "type": "string"
+ },
+ "PermissionsBoundary": {
+ "description": "The ARN of the policy used to set the permissions boundary for the role.",
+ "type": "string"
+ },
+ "Policies": {
+ "description": "Adds or updates an inline policy document that is embedded in the specified IAM role. ",
+ "type": "array",
+ "insertionOrder": false,
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Policy"
+ }
+ },
+ "RoleId": {
+ "description": "The stable and unique string identifying the role.",
+ "type": "string"
+ },
+ "RoleName": {
+ "description": "A name for the IAM role, up to 64 characters in length.",
+ "type": "string"
+ },
+ "Tags": {
+ "description": "A list of tags that are attached to the role.",
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "AssumeRolePolicyDocument"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/RoleId"
+ ],
+ "createOnlyProperties": [
+ "/properties/Path",
+ "/properties/RoleName"
+ ],
+ "primaryIdentifier": [
+ "/properties/RoleName"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": false,
+ "tagProperty": "/properties/Tags"
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "iam:CreateRole",
+ "iam:PutRolePolicy",
+ "iam:AttachRolePolicy",
+ "iam:GetRolePolicy"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "iam:GetRole",
+ "iam:ListAttachedRolePolicies",
+ "iam:ListRolePolicies",
+ "iam:GetRolePolicy"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "iam:UpdateRole",
+ "iam:UpdateRoleDescription",
+ "iam:UpdateAssumeRolePolicy",
+ "iam:DetachRolePolicy",
+ "iam:AttachRolePolicy",
+ "iam:DeleteRolePermissionsBoundary",
+ "iam:PutRolePermissionsBoundary",
+ "iam:DeleteRolePolicy",
+ "iam:PutRolePolicy",
+ "iam:TagRole",
+ "iam:UntagRole"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "iam:DeleteRole",
+ "iam:DetachRolePolicy",
+ "iam:DeleteRolePolicy",
+ "iam:GetRole",
+ "iam:ListAttachedRolePolicies",
+ "iam:ListRolePolicies"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "iam:ListRoles"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_role_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role_plugin.py
new file mode 100644
index 0000000000000..d6c7059f611eb
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_role_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMRoleProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::Role"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_role import IAMRoleProvider
+
+ self.factory = IAMRoleProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.py
new file mode 100644
index 0000000000000..233f9554efcc0
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.py
@@ -0,0 +1,133 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMServerCertificateProperties(TypedDict):
+ Arn: Optional[str]
+ CertificateBody: Optional[str]
+ CertificateChain: Optional[str]
+ Path: Optional[str]
+ PrivateKey: Optional[str]
+ ServerCertificateName: Optional[str]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMServerCertificateProvider(ResourceProvider[IAMServerCertificateProperties]):
+ TYPE = "AWS::IAM::ServerCertificate" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMServerCertificateProperties],
+ ) -> ProgressEvent[IAMServerCertificateProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/ServerCertificateName
+
+
+
+ Create-only properties:
+ - /properties/ServerCertificateName
+ - /properties/PrivateKey
+ - /properties/CertificateBody
+ - /properties/CertificateChain
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - iam:UploadServerCertificate
+ - iam:GetServerCertificate
+
+ """
+ model = request.desired_state
+ if not model.get("ServerCertificateName"):
+ model["ServerCertificateName"] = util.generate_default_name_without_stack(
+ request.logical_resource_id
+ )
+
+ create_params = util.select_attributes(
+ model,
+ [
+ "ServerCertificateName",
+ "PrivateKey",
+ "CertificateBody",
+ "CertificateChain",
+ "Path",
+ "Tags",
+ ],
+ )
+
+ # Create the resource
+ certificate = request.aws_client_factory.iam.upload_server_certificate(**create_params)
+ model["Arn"] = certificate["ServerCertificateMetadata"]["Arn"]
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[IAMServerCertificateProperties],
+ ) -> ProgressEvent[IAMServerCertificateProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - iam:GetServerCertificate
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMServerCertificateProperties],
+ ) -> ProgressEvent[IAMServerCertificateProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - iam:DeleteServerCertificate
+ """
+ model = request.desired_state
+ request.aws_client_factory.iam.delete_server_certificate(
+ ServerCertificateName=model["ServerCertificateName"]
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[IAMServerCertificateProperties],
+ ) -> ProgressEvent[IAMServerCertificateProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - iam:TagServerCertificate
+ - iam:UntagServerCertificate
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.schema.json
new file mode 100644
index 0000000000000..b0af6c74c2da9
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate.schema.json
@@ -0,0 +1,129 @@
+{
+ "typeName": "AWS::IAM::ServerCertificate",
+ "description": "Resource Type definition for AWS::IAM::ServerCertificate",
+ "additionalProperties": false,
+ "properties": {
+ "CertificateBody": {
+ "minLength": 1,
+ "maxLength": 16384,
+ "pattern": "[\\u0009\\u000A\\u000D\\u0020-\\u00FF]+",
+ "type": "string"
+ },
+ "CertificateChain": {
+ "minLength": 1,
+ "maxLength": 2097152,
+ "pattern": "[\\u0009\\u000A\\u000D\\u0020-\\u00FF]+",
+ "type": "string"
+ },
+ "ServerCertificateName": {
+ "minLength": 1,
+ "maxLength": 128,
+ "pattern": "[\\w+=,.@-]+",
+ "type": "string"
+ },
+ "Path": {
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "(\\u002F)|(\\u002F[\\u0021-\\u007F]+\\u002F)",
+ "type": "string"
+ },
+ "PrivateKey": {
+ "minLength": 1,
+ "maxLength": 16384,
+ "pattern": "[\\u0009\\u000A\\u000D\\u0020-\\u00FF]+",
+ "type": "string"
+ },
+ "Arn": {
+ "description": "Amazon Resource Name (ARN) of the server certificate",
+ "minLength": 1,
+ "maxLength": 1600,
+ "type": "string"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ }
+ },
+ "definitions": {
+ "Tag": {
+ "description": "A key-value pair to associate with a resource.",
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ },
+ "Key": {
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 128
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/ServerCertificateName",
+ "/properties/PrivateKey",
+ "/properties/CertificateBody",
+ "/properties/CertificateChain"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "writeOnlyProperties": [
+ "/properties/PrivateKey",
+ "/properties/CertificateBody",
+ "/properties/CertificateChain"
+ ],
+ "primaryIdentifier": [
+ "/properties/ServerCertificateName"
+ ],
+ "handlers": {
+ "create": {
+ "permissions": [
+ "iam:UploadServerCertificate",
+ "iam:GetServerCertificate"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "iam:GetServerCertificate"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "iam:TagServerCertificate",
+ "iam:UntagServerCertificate"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "iam:DeleteServerCertificate"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "iam:ListServerCertificates",
+ "iam:GetServerCertificate"
+ ]
+ }
+ },
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": false
+ }
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate_plugin.py
new file mode 100644
index 0000000000000..13723bd73ce2b
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servercertificate_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMServerCertificateProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::ServerCertificate"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_servercertificate import (
+ IAMServerCertificateProvider,
+ )
+
+ self.factory = IAMServerCertificateProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.py
new file mode 100644
index 0000000000000..2437966df10e7
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.py
@@ -0,0 +1,95 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMServiceLinkedRoleProperties(TypedDict):
+ AWSServiceName: Optional[str]
+ CustomSuffix: Optional[str]
+ Description: Optional[str]
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMServiceLinkedRoleProvider(ResourceProvider[IAMServiceLinkedRoleProperties]):
+ TYPE = "AWS::IAM::ServiceLinkedRole" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMServiceLinkedRoleProperties],
+ ) -> ProgressEvent[IAMServiceLinkedRoleProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - AWSServiceName
+
+ Create-only properties:
+ - /properties/CustomSuffix
+ - /properties/AWSServiceName
+
+ Read-only properties:
+ - /properties/Id
+
+ """
+ model = request.desired_state
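+ # the resource properties map directly onto the CreateServiceLinkedRole parameters, so the
+ # model can be passed through as keyword arguments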
+ response = request.aws_client_factory.iam.create_service_linked_role(**model)
+ model["Id"] = response["Role"]["RoleName"] # TODO
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[IAMServiceLinkedRoleProperties],
+ ) -> ProgressEvent[IAMServiceLinkedRoleProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMServiceLinkedRoleProperties],
+ ) -> ProgressEvent[IAMServiceLinkedRoleProperties]:
+ """
+ Delete a resource
+ """
+ request.aws_client_factory.iam.delete_service_linked_role(
+ RoleName=request.previous_state["Id"]
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[IAMServiceLinkedRoleProperties],
+ ) -> ProgressEvent[IAMServiceLinkedRoleProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.schema.json
new file mode 100644
index 0000000000000..4472358b498b1
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole.schema.json
@@ -0,0 +1,32 @@
+{
+ "typeName": "AWS::IAM::ServiceLinkedRole",
+ "description": "Resource Type definition for AWS::IAM::ServiceLinkedRole",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "CustomSuffix": {
+ "type": "string"
+ },
+ "Description": {
+ "type": "string"
+ },
+ "AWSServiceName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "AWSServiceName"
+ ],
+ "createOnlyProperties": [
+ "/properties/CustomSuffix",
+ "/properties/AWSServiceName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole_plugin.py
new file mode 100644
index 0000000000000..e81cc105f85c1
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_servicelinkedrole_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMServiceLinkedRoleProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::ServiceLinkedRole"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_servicelinkedrole import (
+ IAMServiceLinkedRoleProvider,
+ )
+
+ self.factory = IAMServiceLinkedRoleProvider
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.py
new file mode 100644
index 0000000000000..8600522013b39
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.py
@@ -0,0 +1,158 @@
+# LocalStack Resource Provider Scaffolding v1
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class IAMUserProperties(TypedDict):
+ Arn: Optional[str]
+ Groups: Optional[list[str]]
+ Id: Optional[str]
+ LoginProfile: Optional[LoginProfile]
+ ManagedPolicyArns: Optional[list[str]]
+ Path: Optional[str]
+ PermissionsBoundary: Optional[str]
+ Policies: Optional[list[Policy]]
+ Tags: Optional[list[Tag]]
+ UserName: Optional[str]
+
+
+class Policy(TypedDict):
+ PolicyDocument: Optional[dict]
+ PolicyName: Optional[str]
+
+
+class LoginProfile(TypedDict):
+ Password: Optional[str]
+ PasswordResetRequired: Optional[bool]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class IAMUserProvider(ResourceProvider[IAMUserProperties]):
+ TYPE = "AWS::IAM::User" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[IAMUserProperties],
+ ) -> ProgressEvent[IAMUserProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Create-only properties:
+ - /properties/UserName
+
+ Read-only properties:
+ - /properties/Id
+ - /properties/Arn
+ """
+ model = request.desired_state
+ iam_client = request.aws_client_factory.iam
+ # TODO: validations
+ # TODO: idempotency
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+
+ # Set defaults
+ if not model.get("UserName"):
+ model["UserName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+
+ # actually create the resource
+ # note: technically we could make this synchronous, but for the sake of this being an example it is intentionally "asynchronous" and returns IN_PROGRESS
+
+ # this example uses a helper utility; check out the module for more helpful utilities and add your own!
+ iam_client.create_user(
+ **util.select_attributes(model, ["UserName", "Path", "PermissionsBoundary", "Tags"])
+ )
+
+ # alternatively you can also just do:
+ # iam_client.create_user(
+ # UserName=model["UserName"],
+ # Path=model["Path"],
+ # PermissionsBoundary=model["PermissionsBoundary"],
+ # Tags=model["Tags"],
+ # )
+
+ # this kind of logic below was previously done in either a result_handler or a custom "_post_create" function
+ for group in model.get("Groups", []):
+ iam_client.add_user_to_group(GroupName=group, UserName=model["UserName"])
+
+ for policy_arn in model.get("ManagedPolicyArns", []):
+ iam_client.attach_user_policy(UserName=model["UserName"], PolicyArn=policy_arn)
+
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ get_response = iam_client.get_user(UserName=model["UserName"])
+ model["Id"] = get_response["User"]["UserName"] # this is the ref / physical resource id
+ model["Arn"] = get_response["User"]["Arn"]
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def read(
+ self,
+ request: ResourceRequest[IAMUserProperties],
+ ) -> ProgressEvent[IAMUserProperties]:
+ """
+ Fetch resource information
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[IAMUserProperties],
+ ) -> ProgressEvent[IAMUserProperties]:
+ """
+ Delete a resource
+ """
+ iam_client = request.aws_client_factory.iam
+ iam_client.delete_user(UserName=request.desired_state["Id"])
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=request.previous_state)
+
+ def update(
+ self,
+ request: ResourceRequest[IAMUserProperties],
+ ) -> ProgressEvent[IAMUserProperties]:
+ """
+ Update a resource
+ """
+ # return ProgressEvent(OperationStatus.SUCCESS, request.desired_state)
+ raise NotImplementedError
+
+ def list(
+ self,
+ request: ResourceRequest[IAMUserProperties],
+ ) -> ProgressEvent[IAMUserProperties]:
+ resources = request.aws_client_factory.iam.list_users()
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_models=[
+ IAMUserProperties(Id=resource["UserName"]) for resource in resources["Users"]
+ ],
+ )
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.schema.json b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.schema.json
new file mode 100644
index 0000000000000..aabdb1c81ddbf
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user.schema.json
@@ -0,0 +1,112 @@
+{
+ "typeName": "AWS::IAM::User",
+ "description": "Resource Type definition for AWS::IAM::User",
+ "additionalProperties": false,
+ "properties": {
+ "Path": {
+ "type": "string"
+ },
+ "ManagedPolicyArns": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Policies": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Policy"
+ }
+ },
+ "UserName": {
+ "type": "string"
+ },
+ "Groups": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "type": "string"
+ }
+ },
+ "Id": {
+ "type": "string"
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "LoginProfile": {
+ "$ref": "#/definitions/LoginProfile"
+ },
+ "Tags": {
+ "type": "array",
+ "uniqueItems": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "PermissionsBoundary": {
+ "type": "string"
+ }
+ },
+ "definitions": {
+ "Policy": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PolicyDocument": {
+ "type": "object"
+ },
+ "PolicyName": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "PolicyName",
+ "PolicyDocument"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Value": {
+ "type": "string"
+ },
+ "Key": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Value",
+ "Key"
+ ]
+ },
+ "LoginProfile": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "PasswordResetRequired": {
+ "type": "boolean"
+ },
+ "Password": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "Password"
+ ]
+ }
+ },
+ "createOnlyProperties": [
+ "/properties/UserName"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ],
+ "readOnlyProperties": [
+ "/properties/Id",
+ "/properties/Arn"
+ ]
+}
diff --git a/localstack-core/localstack/services/iam/resource_providers/aws_iam_user_plugin.py b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user_plugin.py
new file mode 100644
index 0000000000000..60acd8fc1493c
--- /dev/null
+++ b/localstack-core/localstack/services/iam/resource_providers/aws_iam_user_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class IAMUserProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::IAM::User"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.iam.resource_providers.aws_iam_user import IAMUserProvider
+
+ self.factory = IAMUserProvider
diff --git a/localstack-core/localstack/services/internal.py b/localstack-core/localstack/services/internal.py
new file mode 100644
index 0000000000000..85c4de12ff351
--- /dev/null
+++ b/localstack-core/localstack/services/internal.py
@@ -0,0 +1,344 @@
+"""Module for localstack internal resources, such as health, graph, or _localstack/cloudformation/deploy."""
+
+import logging
+import os
+import re
+import time
+from collections import defaultdict
+from datetime import datetime
+
+from plux import PluginManager
+from werkzeug.exceptions import NotFound
+
+from localstack import config, constants
+from localstack.deprecations import deprecated_endpoint
+from localstack.http import Request, Resource, Response, Router
+from localstack.http.dispatcher import handler_dispatcher
+from localstack.runtime.legacy import signal_supervisor_restart
+from localstack.utils.analytics.metadata import (
+ get_client_metadata,
+ get_localstack_edition,
+ is_license_activated,
+)
+from localstack.utils.collections import merge_recursive
+from localstack.utils.functions import call_safe
+from localstack.utils.numbers import is_number
+from localstack.utils.objects import singleton_factory
+
+LOG = logging.getLogger(__name__)
+
+HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"]
+
+
+class DeprecatedResource:
+ """
+ Resource class which wraps a given resource in the deprecated_endpoint (i.e. logs deprecation warnings on every
+ invocation).
+ """
+
+ def __init__(self, resource, previous_path: str, deprecation_version: str, new_path: str):
+ for http_method in HTTP_METHODS:
+ fn_name = f"on_{http_method.lower()}"
+ fn = getattr(resource, fn_name, None)
+ if fn:
+ wrapped = deprecated_endpoint(
+ fn,
+ previous_path=previous_path,
+ deprecation_version=deprecation_version,
+ new_path=new_path,
+ )
+ setattr(self, fn_name, wrapped)
+
+
+class HealthResource:
+ """
+ Resource for the LocalStack /health endpoint. It provides access to the service states and other components of
+ localstack. Arbitrary data can be PUT into the health state, which supports use cases like the
+ run_startup_scripts function in docker-entrypoint.sh setting the status of the init scripts feature.
+ """
+
+ def __init__(self, service_manager) -> None:
+ super().__init__()
+ self.service_manager = service_manager
+ self.state = {}
+
+ def on_post(self, request: Request):
+ data = request.get_json(True, True)
+ if not data:
+ return Response("invalid request", 400)
+
+ # backdoor API to support restarting the instance
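+ # e.g. a JSON body of {"action": "restart"} signals a supervisor restart; {"action": "kill"} exits the runtime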
+ if data.get("action") == "restart":
+ signal_supervisor_restart()
+ elif data.get("action") == "kill":
+ from localstack.runtime import get_current_runtime
+
+ get_current_runtime().exit(0)
+
+ return Response("ok", 200)
+
+ def on_get(self, request: Request):
+ path = request.path
+
+ reload = "reload" in path
+
+ # get service state
+ if reload:
+ self.service_manager.check_all()
+ services = {
+ service: state.value for service, state in self.service_manager.get_states().items()
+ }
+
+ # build state dict from internal state and merge into it the service states
+ result = dict(self.state)
+ result = merge_recursive({"services": services}, result)
+ result["edition"] = get_localstack_edition()
+ result["version"] = constants.VERSION
+ return result
+
+ def on_head(self, request: Request):
+ return Response("ok", 200)
+
+ def on_put(self, request: Request):
+ data = request.get_json(True, True) or {}
+
+ # keys like "features:initScripts" should be interpreted as ['features']['initScripts']
+ state = defaultdict(dict)
+ for k, v in data.items():
+ if ":" in k:
+ path = k.split(":")
+ else:
+ path = [k]
+
+ d = state
+ for p in path[:-1]:
+ d = state[p]
+ d[path[-1]] = v
+
+ self.state = merge_recursive(state, self.state, overwrite=True)
+ return {"status": "OK"}
+
+
+class InfoResource:
+ """
+ Resource exposed at /_localstack/info, used to get general information about the current
+ localstack instance.
+ """
+
+ def on_get(self, request):
+ return self.get_info_data()
+
+ @staticmethod
+ def get_info_data() -> dict:
+ client_metadata = get_client_metadata()
+ uptime = int(time.time() - config.load_start_time)
+
+ return {
+ "version": client_metadata.version,
+ "edition": get_localstack_edition(),
+ "is_license_activated": is_license_activated(),
+ "session_id": client_metadata.session_id,
+ "machine_id": client_metadata.machine_id,
+ "system": client_metadata.system,
+ "is_docker": client_metadata.is_docker,
+ "server_time_utc": datetime.utcnow().isoformat(timespec="seconds"),
+ "uptime": uptime,
+ }
+
+
+class UsageResource:
+ def on_get(self, request):
+ from localstack.utils import diagnose
+
+ return call_safe(diagnose.get_usage) or {}
+
+
+class DiagnoseResource:
+ def on_get(self, request):
+ from localstack.utils import diagnose
+
+ return {
+ "version": {
+ "image-version": call_safe(diagnose.get_docker_image_details),
+ "localstack-version": call_safe(diagnose.get_localstack_version),
+ "host": {
+ "kernel": call_safe(diagnose.get_host_kernel_version),
+ },
+ },
+ "info": call_safe(InfoResource.get_info_data),
+ "services": call_safe(diagnose.get_service_stats),
+ "config": call_safe(diagnose.get_localstack_config),
+ "docker-inspect": call_safe(diagnose.inspect_main_container),
+ "docker-dependent-image-hashes": call_safe(diagnose.get_important_image_hashes),
+ "file-tree": call_safe(diagnose.get_file_tree),
+ "important-endpoints": call_safe(diagnose.resolve_endpoints),
+ "logs": call_safe(diagnose.get_localstack_logs),
+ "usage": call_safe(diagnose.get_usage),
+ }
+
+
+class PluginsResource:
+ """
+ Resource to list information about plux plugins.
+ """
+
+ plugin_managers: list[PluginManager] = []
+
+ def __init__(self):
+ # defer imports here to lazy-load code
+ from localstack.runtime import hooks, init
+ from localstack.services.plugins import SERVICE_PLUGINS
+
+ # service providers
+ PluginsResource.plugin_managers.append(SERVICE_PLUGINS.plugin_manager)
+ # init script runners
+ PluginsResource.plugin_managers.append(init.init_script_manager().runner_manager)
+ # init hooks
+ PluginsResource.plugin_managers.append(hooks.configure_localstack_container.manager)
+ PluginsResource.plugin_managers.append(hooks.prepare_host.manager)
+ PluginsResource.plugin_managers.append(hooks.on_infra_ready.manager)
+ PluginsResource.plugin_managers.append(hooks.on_infra_start.manager)
+ PluginsResource.plugin_managers.append(hooks.on_infra_shutdown.manager)
+
+ def on_get(self, request):
+ return {
+ manager.namespace: [
+ self._get_plugin_details(manager, name) for name in manager.list_names()
+ ]
+ for manager in self.plugin_managers
+ }
+
+ def _get_plugin_details(self, manager: PluginManager, plugin_name: str) -> dict:
+ container = manager.get_container(plugin_name)
+
+ details = {
+ "name": plugin_name,
+ "is_initialized": container.is_init,
+ "is_loaded": container.is_loaded,
+ }
+
+ # optionally add requires_license information if the plugin provides it
+ requires_license = None
+ if container.plugin:
+ try:
+ requires_license = container.plugin.requires_license
+ except AttributeError:
+ pass
+ if requires_license is not None:
+ details["requires_license"] = requires_license
+
+ return details
+
+
+class InitScriptsResource:
+ def on_get(self, request):
+ from localstack.runtime.init import init_script_manager
+
+ manager = init_script_manager()
+
+ return {
+ "completed": {
+ stage.name: completed for stage, completed in manager.stage_completed.items()
+ },
+ "scripts": [
+ {
+ "stage": script.stage.name,
+ "name": os.path.basename(script.path),
+ "state": script.state.name,
+ }
+ for scripts in manager.scripts.values()
+ for script in scripts
+ ],
+ }
+
+
+class InitScriptsStageResource:
+ def on_get(self, request, stage: str):
+ from localstack.runtime.init import Stage, init_script_manager
+
+ manager = init_script_manager()
+
+ try:
+ stage = Stage[stage.upper()]
+ except KeyError as e:
+ raise NotFound(f"no such stage {stage}") from e
+
+ return {
+ "completed": manager.stage_completed.get(stage),
+ "scripts": [
+ {
+ "stage": script.stage.name,
+ "name": os.path.basename(script.path),
+ "state": script.state.name,
+ }
+ for script in manager.scripts.get(stage)
+ ],
+ }
+
+
+class ConfigResource:
+ def on_get(self, request):
+ from localstack.utils import diagnose
+
+ return call_safe(diagnose.get_localstack_config)
+
+ def on_post(self, request: Request):
+ from localstack.utils.config_listener import update_config_variable
+
+ data = request.get_json(force=True)
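+ # expected JSON body (illustrative): {"variable": "DEBUG", "value": "1"}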
+ variable = data.get("variable", "")
+ if not re.match(r"^[_a-zA-Z0-9]+$", variable):
+ return Response("{}", mimetype="application/json", status=400)
+ new_value = data.get("value")
+ if is_number(new_value):
+ new_value = float(new_value)
+ update_config_variable(variable, new_value)
+ value = getattr(config, variable, None)
+ return {
+ "variable": variable,
+ "value": value,
+ }
+
+
+class LocalstackResources(Router):
+ """
+ Router for localstack-internal HTTP resources.
+ """
+
+ def __init__(self):
+ super().__init__(dispatcher=handler_dispatcher())
+ self.add_default_routes()
+ # TODO: load routes as plugins
+
+ def add_default_routes(self):
+ from localstack.services.plugins import SERVICE_PLUGINS
+
+ health_resource = HealthResource(SERVICE_PLUGINS)
+ self.add(Resource("/_localstack/health", health_resource))
+ self.add(Resource("/_localstack/info", InfoResource()))
+ self.add(Resource("/_localstack/plugins", PluginsResource()))
+ self.add(Resource("/_localstack/init", InitScriptsResource()))
+ self.add(Resource("/_localstack/init/", InitScriptsStageResource()))
+
+ if config.ENABLE_CONFIG_UPDATES:
+ LOG.warning(
+ "Enabling config endpoint, "
+ "please be aware that this can expose sensitive information via your network."
+ )
+ self.add(Resource("/_localstack/config", ConfigResource()))
+
+ if config.DEBUG:
+ LOG.warning(
+ "Enabling diagnose endpoint, "
+ "please be aware that this can expose sensitive information via your network."
+ )
+ self.add(Resource("/_localstack/diagnose", DiagnoseResource()))
+ self.add(Resource("/_localstack/usage", UsageResource()))
+
+
+@singleton_factory
+def get_internal_apis() -> LocalstackResources:
+ """
+ Get the LocalstackResources singleton.
+ """
+ return LocalstackResources()
diff --git a/localstack-core/localstack/services/kinesis/__init__.py b/localstack-core/localstack/services/kinesis/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kinesis/kinesis_mock_server.py b/localstack-core/localstack/services/kinesis/kinesis_mock_server.py
new file mode 100644
index 0000000000000..af23e3940ef24
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/kinesis_mock_server.py
@@ -0,0 +1,170 @@
+import logging
+import os
+import threading
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from localstack import config
+from localstack.services.kinesis.packages import kinesismock_package
+from localstack.utils.common import TMP_THREADS, ShellCommandThread, get_free_tcp_port, mkdir
+from localstack.utils.run import FuncThread
+from localstack.utils.serving import Server
+
+LOG = logging.getLogger(__name__)
+
+
+class KinesisMockServer(Server):
+ """
+ Server abstraction for controlling Kinesis Mock in a separate thread
+ """
+
+ def __init__(
+ self,
+ port: int,
+ js_path: Path,
+ latency: str,
+ account_id: str,
+ host: str = "localhost",
+ log_level: str = "INFO",
+ data_dir: Optional[str] = None,
+ ) -> None:
+ self._account_id = account_id
+ self._latency = latency
+ self._data_dir = data_dir
+ self._data_filename = f"{self._account_id}.json"
+ self._js_path = js_path
+ self._log_level = log_level
+ super().__init__(port, host)
+
+ def do_start_thread(self) -> FuncThread:
+ cmd, env_vars = self._create_shell_command()
+ LOG.debug("starting kinesis process %s with env vars %s", cmd, env_vars)
+ t = ShellCommandThread(
+ cmd,
+ strip_color=True,
+ env_vars=env_vars,
+ log_listener=self._log_listener,
+ auto_restart=True,
+ name="kinesis-mock",
+ )
+ TMP_THREADS.append(t)
+ t.start()
+ return t
+
+ def _create_shell_command(self) -> Tuple[List, Dict]:
+ """
+ Helper method for creating kinesis mock invocation command
+ :return: a tuple containing the command list and a dictionary of environment variables
+ """
+
+ env_vars = {
+ # Use the `server.json` packaged next to the main.js
+ "KINESIS_MOCK_CERT_PATH": str((self._js_path.parent / "server.json").absolute()),
+ "KINESIS_MOCK_PLAIN_PORT": self.port,
+ # Each kinesis-mock instance listens on two ports - a secure (TLS) one and an insecure one.
+ # LocalStack uses only the insecure one; the TLS listener is pointed at a free port so it does not conflict.
+ "KINESIS_MOCK_TLS_PORT": get_free_tcp_port(),
+ "SHARD_LIMIT": config.KINESIS_SHARD_LIMIT,
+ "ON_DEMAND_STREAM_COUNT_LIMIT": config.KINESIS_ON_DEMAND_STREAM_COUNT_LIMIT,
+ "AWS_ACCOUNT_ID": self._account_id,
+ }
+
+ latency_params = [
+ "CREATE_STREAM_DURATION",
+ "DELETE_STREAM_DURATION",
+ "REGISTER_STREAM_CONSUMER_DURATION",
+ "START_STREAM_ENCRYPTION_DURATION",
+ "STOP_STREAM_ENCRYPTION_DURATION",
+ "DEREGISTER_STREAM_CONSUMER_DURATION",
+ "MERGE_SHARDS_DURATION",
+ "SPLIT_SHARD_DURATION",
+ "UPDATE_SHARD_COUNT_DURATION",
+ "UPDATE_STREAM_MODE_DURATION",
+ ]
+ for param in latency_params:
+ env_vars[param] = self._latency
+
+ if self._data_dir and config.KINESIS_PERSISTENCE:
+ env_vars["SHOULD_PERSIST_DATA"] = "true"
+ env_vars["PERSIST_PATH"] = self._data_dir
+ env_vars["PERSIST_FILE_NAME"] = self._data_filename
+ env_vars["PERSIST_INTERVAL"] = config.KINESIS_MOCK_PERSIST_INTERVAL
+
+ env_vars["LOG_LEVEL"] = self._log_level
+ cmd = ["node", self._js_path]
+ return cmd, env_vars
+
+ def _log_listener(self, line, **_kwargs):
+ LOG.info(line.rstrip())
+
+
+class KinesisServerManager:
+ default_startup_timeout = 60
+
+ def __init__(self):
+ self._lock = threading.RLock()
+ self._servers: dict[str, KinesisMockServer] = {}
+
+ def get_server_for_account(self, account_id: str) -> KinesisMockServer:
+ if account_id in self._servers:
+ return self._servers[account_id]
+
+ with self._lock:
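+ # double-check: another thread may have created the server while we waited for the lock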
+ if account_id in self._servers:
+ return self._servers[account_id]
+
+ LOG.info("Creating kinesis backend for account %s", account_id)
+ self._servers[account_id] = self._create_kinesis_mock_server(account_id)
+ self._servers[account_id].start()
+ if not self._servers[account_id].wait_is_up(timeout=self.default_startup_timeout):
+ raise TimeoutError("gave up waiting for kinesis backend to start up")
+ return self._servers[account_id]
+
+ def shutdown_all(self):
+ with self._lock:
+ while self._servers:
+ account_id, server = self._servers.popitem()
+ LOG.info("Shutting down kinesis backend for account %s", account_id)
+ server.shutdown()
+
+ def _create_kinesis_mock_server(self, account_id: str) -> KinesisMockServer:
+ """
+ Creates a new Kinesis Mock server instance. Installs Kinesis Mock on the host first if necessary.
+ Introspects on the host config to determine server configuration:
+ config.dirs.data -> if set, the server runs with persistence using the path to store data
+ config.LS_LOG -> configure kinesis mock log level (defaults to INFO)
+ config.KINESIS_LATENCY -> configure stream latency (in milliseconds)
+ """
+ port = get_free_tcp_port()
+ kinesismock_package.install()
+ kinesis_mock_js_path = Path(kinesismock_package.get_installer().get_executable_path())
+
+ # kinesis-mock stores its state in .json files, so we can dump everything into `kinesis/`
+ persist_path = os.path.join(config.dirs.data, "kinesis")
+ mkdir(persist_path)
+ if config.KINESIS_MOCK_LOG_LEVEL:
+ log_level = config.KINESIS_MOCK_LOG_LEVEL.upper()
+ elif config.LS_LOG:
+ ls_log_level = config.LS_LOG.upper()
+ if ls_log_level == "WARNING":
+ log_level = "WARN"
+ elif ls_log_level == "TRACE-INTERNAL":
+ log_level = "TRACE"
+ elif ls_log_level not in ("ERROR", "WARN", "INFO", "DEBUG", "TRACE"):
+ # to protect from cases where the log level will be rejected from kinesis-mock
+ log_level = "INFO"
+ else:
+ log_level = ls_log_level
+ else:
+ log_level = "INFO"
+ latency = config.KINESIS_LATENCY + "ms"
+
+ server = KinesisMockServer(
+ port=port,
+ js_path=kinesis_mock_js_path,
+ log_level=log_level,
+ latency=latency,
+ data_dir=persist_path,
+ account_id=account_id,
+ )
+ return server
diff --git a/localstack-core/localstack/services/kinesis/models.py b/localstack-core/localstack/services/kinesis/models.py
new file mode 100644
index 0000000000000..3247ac060fbb0
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/models.py
@@ -0,0 +1,18 @@
+from collections import defaultdict
+from typing import Dict, List, Set
+
+from localstack.aws.api.kinesis import ConsumerDescription, MetricsName, StreamName
+from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute
+
+
+class KinesisStore(BaseStore):
+ # list of stream consumer details
+ stream_consumers: List[ConsumerDescription] = LocalAttribute(default=list)
+
+ # maps stream name to list of enhanced monitoring metrics
+ enhanced_metrics: Dict[StreamName, Set[MetricsName]] = LocalAttribute(
+ default=lambda: defaultdict(set)
+ )
+
+
+kinesis_stores = AccountRegionBundle("kinesis", KinesisStore)
diff --git a/localstack-core/localstack/services/kinesis/packages.py b/localstack-core/localstack/services/kinesis/packages.py
new file mode 100644
index 0000000000000..0094b64058f64
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/packages.py
@@ -0,0 +1,28 @@
+import os
+from functools import lru_cache
+from typing import List
+
+from localstack.packages import Package, PackageInstaller
+from localstack.packages.core import NodePackageInstaller
+
+_KINESIS_MOCK_VERSION = os.environ.get("KINESIS_MOCK_VERSION") or "0.4.8"
+
+
+class KinesisMockPackage(Package):
+ def __init__(self, default_version: str = _KINESIS_MOCK_VERSION):
+ super().__init__(name="Kinesis Mock", default_version=default_version)
+
+ @lru_cache
+ def _get_installer(self, version: str) -> PackageInstaller:
+ return KinesisMockPackageInstaller(version)
+
+ def get_versions(self) -> List[str]:
+ return [_KINESIS_MOCK_VERSION]
+
+
+class KinesisMockPackageInstaller(NodePackageInstaller):
+ def __init__(self, version: str):
+ super().__init__(package_name="kinesis-local", version=version)
+
+
+kinesismock_package = KinesisMockPackage()
diff --git a/localstack-core/localstack/services/kinesis/plugins.py b/localstack-core/localstack/services/kinesis/plugins.py
new file mode 100644
index 0000000000000..13f06b3e630ca
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/plugins.py
@@ -0,0 +1,8 @@
+from localstack.packages import Package, package
+
+
+@package(name="kinesis-mock")
+def kinesismock_package() -> Package:
+ from localstack.services.kinesis.packages import kinesismock_package
+
+ return kinesismock_package
diff --git a/localstack-core/localstack/services/kinesis/provider.py b/localstack-core/localstack/services/kinesis/provider.py
new file mode 100644
index 0000000000000..7f080e35fc122
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/provider.py
@@ -0,0 +1,185 @@
+import logging
+import os
+import time
+from random import random
+
+from localstack import config
+from localstack.aws.api import RequestContext
+from localstack.aws.api.kinesis import (
+ ConsumerARN,
+ Data,
+ HashKey,
+ KinesisApi,
+ PartitionKey,
+ ProvisionedThroughputExceededException,
+ PutRecordOutput,
+ PutRecordsOutput,
+ PutRecordsRequestEntryList,
+ PutRecordsResultEntry,
+ SequenceNumber,
+ ShardId,
+ StartingPosition,
+ StreamARN,
+ StreamName,
+ SubscribeToShardEvent,
+ SubscribeToShardEventStream,
+ SubscribeToShardOutput,
+)
+from localstack.aws.connect import connect_to
+from localstack.constants import LOCALHOST
+from localstack.services.kinesis.kinesis_mock_server import KinesisServerManager
+from localstack.services.kinesis.models import KinesisStore, kinesis_stores
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.state import AssetDirectory, StateVisitor
+from localstack.utils.aws import arns
+from localstack.utils.aws.arns import extract_account_id_from_arn, extract_region_from_arn
+from localstack.utils.time import now_utc
+
+LOG = logging.getLogger(__name__)
+MAX_SUBSCRIPTION_SECONDS = 300
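+# AWS keeps a SubscribeToShard subscription open for at most 5 minutes; clients are expected to re-subscribe afterwards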
+SERVER_STARTUP_TIMEOUT = 120
+
+
+def find_stream_for_consumer(consumer_arn):
+ account_id = extract_account_id_from_arn(consumer_arn)
+ region_name = extract_region_from_arn(consumer_arn)
+ kinesis = connect_to(aws_access_key_id=account_id, region_name=region_name).kinesis
+ for stream_name in kinesis.list_streams()["StreamNames"]:
+ stream_arn = arns.kinesis_stream_arn(stream_name, account_id, region_name)
+ for cons in kinesis.list_stream_consumers(StreamARN=stream_arn)["Consumers"]:
+ if cons["ConsumerARN"] == consumer_arn:
+ return stream_name
+ raise Exception("Unable to find stream for stream consumer %s" % consumer_arn)
+
+
+class KinesisProvider(KinesisApi, ServiceLifecycleHook):
+ server_manager: KinesisServerManager
+
+ def __init__(self):
+ self.server_manager = KinesisServerManager()
+
+ def accept_state_visitor(self, visitor: StateVisitor):
+ visitor.visit(kinesis_stores)
+ visitor.visit(AssetDirectory(self.service, os.path.join(config.dirs.data, "kinesis")))
+
+ def on_before_state_load(self):
+ # no need to restart servers, since that happens lazily in `server_manager.get_server_for_account`.
+ self.server_manager.shutdown_all()
+
+ def on_before_state_reset(self):
+ self.server_manager.shutdown_all()
+
+ def on_before_stop(self):
+ self.server_manager.shutdown_all()
+
+ def get_forward_url(self, account_id: str, region_name: str) -> str:
+ """Return the URL of the backend Kinesis server to forward requests to"""
+ server = self.server_manager.get_server_for_account(account_id)
+ return f"http://{LOCALHOST}:{server.port}"
+
+ @staticmethod
+ def get_store(account_id: str, region_name: str) -> KinesisStore:
+ return kinesis_stores[account_id][region_name]
+
+ def subscribe_to_shard(
+ self,
+ context: RequestContext,
+ consumer_arn: ConsumerARN,
+ shard_id: ShardId,
+ starting_position: StartingPosition,
+ **kwargs,
+ ) -> SubscribeToShardOutput:
+ kinesis = connect_to(
+ aws_access_key_id=context.account_id, region_name=context.region
+ ).kinesis
+ stream_name = find_stream_for_consumer(consumer_arn)
+ iter_type = starting_position["Type"]
+ kwargs = {}
+ starting_sequence_number = starting_position.get("SequenceNumber") or "0"
+ if iter_type in ["AT_SEQUENCE_NUMBER", "AFTER_SEQUENCE_NUMBER"]:
+ kwargs["StartingSequenceNumber"] = starting_sequence_number
+ elif iter_type in ["AT_TIMESTAMP"]:
+ # the fallback value is just an example timestamp taken from the AWS docs
+ timestamp = starting_position.get("Timestamp") or 1459799926.480
+ kwargs["Timestamp"] = timestamp
+ initial_shard_iterator = kinesis.get_shard_iterator(
+ StreamName=stream_name, ShardId=shard_id, ShardIteratorType=iter_type, **kwargs
+ )["ShardIterator"]
+
+ def event_generator():
+ shard_iterator = initial_shard_iterator
+ last_sequence_number = starting_sequence_number
+
+ maximum_duration_subscription_timestamp = now_utc() + MAX_SUBSCRIPTION_SECONDS
+
+ while now_utc() < maximum_duration_subscription_timestamp:
+ try:
+ result = kinesis.get_records(ShardIterator=shard_iterator)
+ except Exception as e:
+ if "ResourceNotFoundException" in str(e):
+ LOG.debug(
+ 'Kinesis stream "%s" has been deleted, closing shard subscriber',
+ stream_name,
+ )
+ return
+ raise
+ shard_iterator = result.get("NextShardIterator")
+ records = result.get("Records", [])
+ if not records:
+ # On AWS, an event is emitted at least once every 5 seconds,
+ # but that is not feasible with this polling structure.
+ # To avoid a 5-second blocking call, we compromise on a 3-second sleep.
+ time.sleep(3)
+
+ yield SubscribeToShardEventStream(
+ SubscribeToShardEvent=SubscribeToShardEvent(
+ Records=records,
+ ContinuationSequenceNumber=str(last_sequence_number),
+ MillisBehindLatest=0,
+ ChildShards=[],
+ )
+ )
+
+ return SubscribeToShardOutput(EventStream=event_generator())
+
+ def put_record(
+ self,
+ context: RequestContext,
+ data: Data,
+ partition_key: PartitionKey,
+ stream_name: StreamName = None,
+ explicit_hash_key: HashKey = None,
+ sequence_number_for_ordering: SequenceNumber = None,
+ stream_arn: StreamARN = None,
+ **kwargs,
+ ) -> PutRecordOutput:
+ # TODO: Ensure use of `stream_arn` works. Currently kinesis-mock only works with the request context's account ID and region
+ if random() < config.KINESIS_ERROR_PROBABILITY:
+ raise ProvisionedThroughputExceededException(
+ "Rate exceeded for shard X in stream Y under account Z."
+ )
+ # If "we were lucky" and the error probability didn't hit, we raise a NotImplementedError in order to
+ # trigger the fallback to kinesis-mock
+ raise NotImplementedError
+
+ def put_records(
+ self,
+ context: RequestContext,
+ records: PutRecordsRequestEntryList,
+ stream_name: StreamName = None,
+ stream_arn: StreamARN = None,
+ **kwargs,
+ ) -> PutRecordsOutput:
+ # TODO: Ensure use of `stream_arn` works. Currently kinesis-mock only works with the request context's account ID and region
+ if random() < config.KINESIS_ERROR_PROBABILITY:
+ records_count = len(records) if records is not None else 0
+ records = [
+ PutRecordsResultEntry(
+ ErrorCode="ProvisionedThroughputExceededException",
+ ErrorMessage="Rate exceeded for shard X in stream Y under account Z.",
+ )
+ ] * records_count
+ return PutRecordsOutput(FailedRecordCount=1, Records=records)
+ # If "we were lucky" and the error probability didn't hit, we raise a NotImplementedError in order to
+ # trigger the fallback to kinesis-mock
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/__init__.py b/localstack-core/localstack/services/kinesis/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.py b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.py
new file mode 100644
index 0000000000000..27d18c1ff3fe3
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.py
@@ -0,0 +1,181 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class KinesisStreamProperties(TypedDict):
+ Arn: Optional[str]
+ Name: Optional[str]
+ RetentionPeriodHours: Optional[int]
+ ShardCount: Optional[int]
+ StreamEncryption: Optional[StreamEncryption]
+ StreamModeDetails: Optional[StreamModeDetails]
+ Tags: Optional[list[Tag]]
+
+
+class StreamModeDetails(TypedDict):
+ StreamMode: Optional[str]
+
+
+class StreamEncryption(TypedDict):
+ EncryptionType: Optional[str]
+ KeyId: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class KinesisStreamProvider(ResourceProvider[KinesisStreamProperties]):
+ TYPE = "AWS::Kinesis::Stream" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[KinesisStreamProperties],
+ ) -> ProgressEvent[KinesisStreamProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Name
+
+
+
+ Create-only properties:
+ - /properties/Name
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - kinesis:EnableEnhancedMonitoring
+ - kinesis:DescribeStreamSummary
+ - kinesis:CreateStream
+ - kinesis:IncreaseStreamRetentionPeriod
+ - kinesis:StartStreamEncryption
+ - kinesis:AddTagsToStream
+ - kinesis:ListTagsForStream
+
+ """
+ model = request.desired_state
+ kinesis = request.aws_client_factory.kinesis
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ if not model.get("Name"):
+ model["Name"] = util.generate_default_name(
+ stack_name=request.stack_name, logical_resource_id=request.logical_resource_id
+ )
+ if not model.get("ShardCount"):
+ model["ShardCount"] = 1
+
+ if not model.get("StreamModeDetails"):
+ model["StreamModeDetails"] = StreamModeDetails(StreamMode="ON_DEMAND")
+
+ kinesis.create_stream(
+ StreamName=model["Name"],
+ ShardCount=model["ShardCount"],
+ StreamModeDetails=model["StreamModeDetails"],
+ )
+
+ stream_data = kinesis.describe_stream(StreamName=model["Name"])["StreamDescription"]
+ model["Arn"] = stream_data["StreamARN"]
+ request.custom_context[REPEATED_INVOCATION] = True
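+ # returning IN_PROGRESS makes the framework re-invoke create() with the same custom_context,
+ # so the flag above routes the next invocation into the status-polling branch below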
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ stream_data = kinesis.describe_stream(StreamARN=model["Arn"])["StreamDescription"]
+ if stream_data["StreamStatus"] != "ACTIVE":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[KinesisStreamProperties],
+ ) -> ProgressEvent[KinesisStreamProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - kinesis:DescribeStreamSummary
+ - kinesis:ListTagsForStream
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[KinesisStreamProperties],
+ ) -> ProgressEvent[KinesisStreamProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - kinesis:DescribeStreamSummary
+ - kinesis:DeleteStream
+ - kinesis:RemoveTagsFromStream
+ """
+ model = request.previous_state
+ client = request.aws_client_factory.kinesis
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ client.delete_stream(StreamARN=model["Arn"], EnforceConsumerDeletion=True)
+ request.custom_context[REPEATED_INVOCATION] = True
+
+ try:
+ client.describe_stream(StreamARN=model["Arn"])
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model={},
+ )
+ except client.exceptions.ResourceNotFoundException:
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model={},
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[KinesisStreamProperties],
+ ) -> ProgressEvent[KinesisStreamProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - kinesis:EnableEnhancedMonitoring
+ - kinesis:DisableEnhancedMonitoring
+ - kinesis:DescribeStreamSummary
+ - kinesis:UpdateShardCount
+ - kinesis:UpdateStreamMode
+ - kinesis:IncreaseStreamRetentionPeriod
+ - kinesis:DecreaseStreamRetentionPeriod
+ - kinesis:StartStreamEncryption
+ - kinesis:StopStreamEncryption
+ - kinesis:AddTagsToStream
+ - kinesis:RemoveTagsFromStream
+ - kinesis:ListTagsForStream
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.schema.json b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.schema.json
new file mode 100644
index 0000000000000..69b6d10cfd89d
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream.schema.json
@@ -0,0 +1,173 @@
+{
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-kinesis.git",
+ "handlers": {
+ "read": {
+ "permissions": [
+ "kinesis:DescribeStreamSummary",
+ "kinesis:ListTagsForStream"
+ ]
+ },
+ "create": {
+ "permissions": [
+ "kinesis:EnableEnhancedMonitoring",
+ "kinesis:DescribeStreamSummary",
+ "kinesis:CreateStream",
+ "kinesis:IncreaseStreamRetentionPeriod",
+ "kinesis:StartStreamEncryption",
+ "kinesis:AddTagsToStream",
+ "kinesis:ListTagsForStream"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "kinesis:EnableEnhancedMonitoring",
+ "kinesis:DisableEnhancedMonitoring",
+ "kinesis:DescribeStreamSummary",
+ "kinesis:UpdateShardCount",
+ "kinesis:UpdateStreamMode",
+ "kinesis:IncreaseStreamRetentionPeriod",
+ "kinesis:DecreaseStreamRetentionPeriod",
+ "kinesis:StartStreamEncryption",
+ "kinesis:StopStreamEncryption",
+ "kinesis:AddTagsToStream",
+ "kinesis:RemoveTagsFromStream",
+ "kinesis:ListTagsForStream"
+ ],
+ "timeoutInMinutes": 240
+ },
+ "list": {
+ "permissions": [
+ "kinesis:ListStreams"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "kinesis:DescribeStreamSummary",
+ "kinesis:DeleteStream",
+ "kinesis:RemoveTagsFromStream"
+ ]
+ }
+ },
+ "typeName": "AWS::Kinesis::Stream",
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "description": "Resource Type definition for AWS::Kinesis::Stream",
+ "createOnlyProperties": [
+ "/properties/Name"
+ ],
+ "additionalProperties": false,
+ "primaryIdentifier": [
+ "/properties/Name"
+ ],
+ "definitions": {
+ "StreamModeDetails": {
+ "description": "When specified, enables or updates the mode of stream. Default is PROVISIONED.",
+ "additionalProperties": false,
+ "type": "object",
+ "properties": {
+ "StreamMode": {
+ "description": "The mode of the stream",
+ "type": "string",
+ "enum": [
+ "ON_DEMAND",
+ "PROVISIONED"
+ ]
+ }
+ },
+ "required": [
+ "StreamMode"
+ ]
+ },
+ "StreamEncryption": {
+ "description": "When specified, enables or updates server-side encryption using an AWS KMS key for a specified stream. Removing this property from your stack template and updating your stack disables encryption.",
+ "additionalProperties": false,
+ "type": "object",
+ "properties": {
+ "EncryptionType": {
+ "description": "The encryption type to use. The only valid value is KMS. ",
+ "type": "string",
+ "enum": [
+ "KMS"
+ ]
+ },
+ "KeyId": {
+ "minLength": 1,
+ "description": "The GUID for the customer-managed AWS KMS key to use for encryption. This value can be a globally unique identifier, a fully specified Amazon Resource Name (ARN) to either an alias or a key, or an alias name prefixed by \"alias/\".You can also use a master key owned by Kinesis Data Streams by specifying the alias aws/kinesis.",
+ "type": "string",
+ "maxLength": 2048
+ }
+ },
+ "required": [
+ "EncryptionType",
+ "KeyId"
+ ]
+ },
+ "Tag": {
+ "description": "An arbitrary set of tags (key-value pairs) to associate with the Kinesis stream.",
+ "additionalProperties": false,
+ "type": "object",
+ "properties": {
+ "Value": {
+ "minLength": 0,
+ "description": "The value for the tag. You can specify a value that is 0 to 255 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "type": "string",
+ "maxLength": 255
+ },
+ "Key": {
+ "minLength": 1,
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "type": "string",
+ "maxLength": 128
+ }
+ },
+ "required": [
+ "Key",
+ "Value"
+ ]
+ }
+ },
+ "properties": {
+ "StreamModeDetails": {
+ "default": {
+ "StreamMode": "PROVISIONED"
+ },
+ "description": "The mode in which the stream is running.",
+ "$ref": "#/definitions/StreamModeDetails"
+ },
+ "StreamEncryption": {
+ "description": "When specified, enables or updates server-side encryption using an AWS KMS key for a specified stream.",
+ "$ref": "#/definitions/StreamEncryption"
+ },
+ "Arn": {
+ "description": "The Amazon resource name (ARN) of the Kinesis stream",
+ "type": "string"
+ },
+ "RetentionPeriodHours": {
+ "description": "The number of hours for the data records that are stored in shards to remain accessible.",
+ "type": "integer",
+ "minimum": 24
+ },
+ "Tags": {
+ "uniqueItems": false,
+ "description": "An arbitrary set of tags (key\u2013value pairs) to associate with the Kinesis stream.",
+ "insertionOrder": false,
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Name": {
+ "minLength": 1,
+ "pattern": "^[a-zA-Z0-9_.-]+$",
+ "description": "The name of the Kinesis stream.",
+ "type": "string",
+ "maxLength": 128
+ },
+ "ShardCount": {
+ "description": "The number of shards that the stream uses. Required when StreamMode = PROVISIONED is passed.",
+ "type": "integer",
+ "minimum": 1
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream_plugin.py b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream_plugin.py
new file mode 100644
index 0000000000000..d7e834e7bb0bf
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_stream_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class KinesisStreamProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Kinesis::Stream"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.kinesis.resource_providers.aws_kinesis_stream import (
+ KinesisStreamProvider,
+ )
+
+ self.factory = KinesisStreamProvider
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.py b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.py
new file mode 100644
index 0000000000000..3f0faee08ffda
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.py
@@ -0,0 +1,131 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class KinesisStreamConsumerProperties(TypedDict):
+ ConsumerName: Optional[str]
+ StreamARN: Optional[str]
+ ConsumerARN: Optional[str]
+ ConsumerCreationTimestamp: Optional[str]
+ ConsumerStatus: Optional[str]
+ Id: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class KinesisStreamConsumerProvider(ResourceProvider[KinesisStreamConsumerProperties]):
+ TYPE = "AWS::Kinesis::StreamConsumer" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[KinesisStreamConsumerProperties],
+ ) -> ProgressEvent[KinesisStreamConsumerProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/Id
+
+ Required properties:
+ - ConsumerName
+ - StreamARN
+
+ Create-only properties:
+ - /properties/ConsumerName
+ - /properties/StreamARN
+
+ Read-only properties:
+ - /properties/ConsumerStatus
+ - /properties/ConsumerARN
+ - /properties/ConsumerCreationTimestamp
+ - /properties/Id
+
+
+
+ """
+ model = request.desired_state
+ kinesis = request.aws_client_factory.kinesis
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ # this is the first time this callback is invoked
+ # TODO: idempotency
+
+ response = kinesis.register_stream_consumer(
+ StreamARN=model["StreamARN"], ConsumerName=model["ConsumerName"]
+ )
+ model["ConsumerARN"] = response["Consumer"]["ConsumerARN"]
+ model["ConsumerStatus"] = response["Consumer"]["ConsumerStatus"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ response = kinesis.describe_stream_consumer(ConsumerARN=model["ConsumerARN"])
+ model["ConsumerStatus"] = response["ConsumerDescription"]["ConsumerStatus"]
+ if model["ConsumerStatus"] == "CREATING":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[KinesisStreamConsumerProperties],
+ ) -> ProgressEvent[KinesisStreamConsumerProperties]:
+ """
+ Fetch resource information
+
+
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[KinesisStreamConsumerProperties],
+ ) -> ProgressEvent[KinesisStreamConsumerProperties]:
+ """
+ Delete a resource
+
+
+ """
+ model = request.desired_state
+ kinesis = request.aws_client_factory.kinesis
+ kinesis.deregister_stream_consumer(ConsumerARN=model["ConsumerARN"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[KinesisStreamConsumerProperties],
+ ) -> ProgressEvent[KinesisStreamConsumerProperties]:
+ """
+ Update a resource
+
+
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.schema.json b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.schema.json
new file mode 100644
index 0000000000000..635fb10017540
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer.schema.json
@@ -0,0 +1,42 @@
+{
+ "typeName": "AWS::Kinesis::StreamConsumer",
+ "description": "Resource Type definition for AWS::Kinesis::StreamConsumer",
+ "additionalProperties": false,
+ "properties": {
+ "Id": {
+ "type": "string"
+ },
+ "ConsumerCreationTimestamp": {
+ "type": "string"
+ },
+ "ConsumerName": {
+ "type": "string"
+ },
+ "ConsumerARN": {
+ "type": "string"
+ },
+ "ConsumerStatus": {
+ "type": "string"
+ },
+ "StreamARN": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "ConsumerName",
+ "StreamARN"
+ ],
+ "readOnlyProperties": [
+ "/properties/ConsumerStatus",
+ "/properties/ConsumerARN",
+ "/properties/ConsumerCreationTimestamp",
+ "/properties/Id"
+ ],
+ "createOnlyProperties": [
+ "/properties/ConsumerName",
+ "/properties/StreamARN"
+ ],
+ "primaryIdentifier": [
+ "/properties/Id"
+ ]
+}
diff --git a/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer_plugin.py b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer_plugin.py
new file mode 100644
index 0000000000000..b1f2cab38423d
--- /dev/null
+++ b/localstack-core/localstack/services/kinesis/resource_providers/aws_kinesis_streamconsumer_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class KinesisStreamConsumerProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::Kinesis::StreamConsumer"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.kinesis.resource_providers.aws_kinesis_streamconsumer import (
+ KinesisStreamConsumerProvider,
+ )
+
+ self.factory = KinesisStreamConsumerProvider
diff --git a/localstack-core/localstack/services/kinesisfirehose/__init__.py b/localstack-core/localstack/services/kinesisfirehose/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kinesisfirehose/resource_providers/__init__.py b/localstack-core/localstack/services/kinesisfirehose/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.py b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.py
new file mode 100644
index 0000000000000..6764a783667f0
--- /dev/null
+++ b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.py
@@ -0,0 +1,496 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class KinesisFirehoseDeliveryStreamProperties(TypedDict):
+ AmazonOpenSearchServerlessDestinationConfiguration: Optional[
+ AmazonOpenSearchServerlessDestinationConfiguration
+ ]
+ AmazonopensearchserviceDestinationConfiguration: Optional[
+ AmazonopensearchserviceDestinationConfiguration
+ ]
+ Arn: Optional[str]
+ DeliveryStreamEncryptionConfigurationInput: Optional[DeliveryStreamEncryptionConfigurationInput]
+ DeliveryStreamName: Optional[str]
+ DeliveryStreamType: Optional[str]
+ ElasticsearchDestinationConfiguration: Optional[ElasticsearchDestinationConfiguration]
+ ExtendedS3DestinationConfiguration: Optional[ExtendedS3DestinationConfiguration]
+ HttpEndpointDestinationConfiguration: Optional[HttpEndpointDestinationConfiguration]
+ KinesisStreamSourceConfiguration: Optional[KinesisStreamSourceConfiguration]
+ RedshiftDestinationConfiguration: Optional[RedshiftDestinationConfiguration]
+ S3DestinationConfiguration: Optional[S3DestinationConfiguration]
+ SplunkDestinationConfiguration: Optional[SplunkDestinationConfiguration]
+ Tags: Optional[list[Tag]]
+
+
+class DeliveryStreamEncryptionConfigurationInput(TypedDict):
+ KeyType: Optional[str]
+ KeyARN: Optional[str]
+
+
+class ElasticsearchBufferingHints(TypedDict):
+ IntervalInSeconds: Optional[int]
+ SizeInMBs: Optional[int]
+
+
+class CloudWatchLoggingOptions(TypedDict):
+ Enabled: Optional[bool]
+ LogGroupName: Optional[str]
+ LogStreamName: Optional[str]
+
+
+class ProcessorParameter(TypedDict):
+ ParameterName: Optional[str]
+ ParameterValue: Optional[str]
+
+
+class Processor(TypedDict):
+ Type: Optional[str]
+ Parameters: Optional[list[ProcessorParameter]]
+
+
+class ProcessingConfiguration(TypedDict):
+ Enabled: Optional[bool]
+ Processors: Optional[list[Processor]]
+
+
+class ElasticsearchRetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class BufferingHints(TypedDict):
+ IntervalInSeconds: Optional[int]
+ SizeInMBs: Optional[int]
+
+
+class KMSEncryptionConfig(TypedDict):
+ AWSKMSKeyARN: Optional[str]
+
+
+class EncryptionConfiguration(TypedDict):
+ KMSEncryptionConfig: Optional[KMSEncryptionConfig]
+ NoEncryptionConfig: Optional[str]
+
+
+class S3DestinationConfiguration(TypedDict):
+ BucketARN: Optional[str]
+ RoleARN: Optional[str]
+ BufferingHints: Optional[BufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ CompressionFormat: Optional[str]
+ EncryptionConfiguration: Optional[EncryptionConfiguration]
+ ErrorOutputPrefix: Optional[str]
+ Prefix: Optional[str]
+
+
+class VpcConfiguration(TypedDict):
+ RoleARN: Optional[str]
+ SecurityGroupIds: Optional[list[str]]
+ SubnetIds: Optional[list[str]]
+
+
+class DocumentIdOptions(TypedDict):
+ DefaultDocumentIdFormat: Optional[str]
+
+
+class ElasticsearchDestinationConfiguration(TypedDict):
+ IndexName: Optional[str]
+ RoleARN: Optional[str]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ BufferingHints: Optional[ElasticsearchBufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ ClusterEndpoint: Optional[str]
+ DocumentIdOptions: Optional[DocumentIdOptions]
+ DomainARN: Optional[str]
+ IndexRotationPeriod: Optional[str]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RetryOptions: Optional[ElasticsearchRetryOptions]
+ S3BackupMode: Optional[str]
+ TypeName: Optional[str]
+ VpcConfiguration: Optional[VpcConfiguration]
+
+
+class AmazonopensearchserviceBufferingHints(TypedDict):
+ IntervalInSeconds: Optional[int]
+ SizeInMBs: Optional[int]
+
+
+class AmazonopensearchserviceRetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class AmazonopensearchserviceDestinationConfiguration(TypedDict):
+ IndexName: Optional[str]
+ RoleARN: Optional[str]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ BufferingHints: Optional[AmazonopensearchserviceBufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ ClusterEndpoint: Optional[str]
+ DocumentIdOptions: Optional[DocumentIdOptions]
+ DomainARN: Optional[str]
+ IndexRotationPeriod: Optional[str]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RetryOptions: Optional[AmazonopensearchserviceRetryOptions]
+ S3BackupMode: Optional[str]
+ TypeName: Optional[str]
+ VpcConfiguration: Optional[VpcConfiguration]
+
+
+class AmazonOpenSearchServerlessBufferingHints(TypedDict):
+ IntervalInSeconds: Optional[int]
+ SizeInMBs: Optional[int]
+
+
+class AmazonOpenSearchServerlessRetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class AmazonOpenSearchServerlessDestinationConfiguration(TypedDict):
+ IndexName: Optional[str]
+ RoleARN: Optional[str]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ BufferingHints: Optional[AmazonOpenSearchServerlessBufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ CollectionEndpoint: Optional[str]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RetryOptions: Optional[AmazonOpenSearchServerlessRetryOptions]
+ S3BackupMode: Optional[str]
+ VpcConfiguration: Optional[VpcConfiguration]
+
+
+class HiveJsonSerDe(TypedDict):
+ TimestampFormats: Optional[list[str]]
+
+
+class OpenXJsonSerDe(TypedDict):
+ CaseInsensitive: Optional[bool]
+ ColumnToJsonKeyMappings: Optional[dict]
+ ConvertDotsInJsonKeysToUnderscores: Optional[bool]
+
+
+class Deserializer(TypedDict):
+ HiveJsonSerDe: Optional[HiveJsonSerDe]
+ OpenXJsonSerDe: Optional[OpenXJsonSerDe]
+
+
+class InputFormatConfiguration(TypedDict):
+ Deserializer: Optional[Deserializer]
+
+
+class OrcSerDe(TypedDict):
+ BlockSizeBytes: Optional[int]
+ BloomFilterColumns: Optional[list[str]]
+ BloomFilterFalsePositiveProbability: Optional[float]
+ Compression: Optional[str]
+ DictionaryKeyThreshold: Optional[float]
+ EnablePadding: Optional[bool]
+ FormatVersion: Optional[str]
+ PaddingTolerance: Optional[float]
+ RowIndexStride: Optional[int]
+ StripeSizeBytes: Optional[int]
+
+
+class ParquetSerDe(TypedDict):
+ BlockSizeBytes: Optional[int]
+ Compression: Optional[str]
+ EnableDictionaryCompression: Optional[bool]
+ MaxPaddingBytes: Optional[int]
+ PageSizeBytes: Optional[int]
+ WriterVersion: Optional[str]
+
+
+class Serializer(TypedDict):
+ OrcSerDe: Optional[OrcSerDe]
+ ParquetSerDe: Optional[ParquetSerDe]
+
+
+class OutputFormatConfiguration(TypedDict):
+ Serializer: Optional[Serializer]
+
+
+class SchemaConfiguration(TypedDict):
+ CatalogId: Optional[str]
+ DatabaseName: Optional[str]
+ Region: Optional[str]
+ RoleARN: Optional[str]
+ TableName: Optional[str]
+ VersionId: Optional[str]
+
+
+class DataFormatConversionConfiguration(TypedDict):
+ Enabled: Optional[bool]
+ InputFormatConfiguration: Optional[InputFormatConfiguration]
+ OutputFormatConfiguration: Optional[OutputFormatConfiguration]
+ SchemaConfiguration: Optional[SchemaConfiguration]
+
+
+class RetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class DynamicPartitioningConfiguration(TypedDict):
+ Enabled: Optional[bool]
+ RetryOptions: Optional[RetryOptions]
+
+
+class ExtendedS3DestinationConfiguration(TypedDict):
+ BucketARN: Optional[str]
+ RoleARN: Optional[str]
+ BufferingHints: Optional[BufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ CompressionFormat: Optional[str]
+ DataFormatConversionConfiguration: Optional[DataFormatConversionConfiguration]
+ DynamicPartitioningConfiguration: Optional[DynamicPartitioningConfiguration]
+ EncryptionConfiguration: Optional[EncryptionConfiguration]
+ ErrorOutputPrefix: Optional[str]
+ Prefix: Optional[str]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ S3BackupConfiguration: Optional[S3DestinationConfiguration]
+ S3BackupMode: Optional[str]
+
+
+class KinesisStreamSourceConfiguration(TypedDict):
+ KinesisStreamARN: Optional[str]
+ RoleARN: Optional[str]
+
+
+class CopyCommand(TypedDict):
+ DataTableName: Optional[str]
+ CopyOptions: Optional[str]
+ DataTableColumns: Optional[str]
+
+
+class RedshiftRetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class RedshiftDestinationConfiguration(TypedDict):
+ ClusterJDBCURL: Optional[str]
+ CopyCommand: Optional[CopyCommand]
+ Password: Optional[str]
+ RoleARN: Optional[str]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ Username: Optional[str]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RetryOptions: Optional[RedshiftRetryOptions]
+ S3BackupConfiguration: Optional[S3DestinationConfiguration]
+ S3BackupMode: Optional[str]
+
+
+class SplunkRetryOptions(TypedDict):
+ DurationInSeconds: Optional[int]
+
+
+class SplunkDestinationConfiguration(TypedDict):
+ HECEndpoint: Optional[str]
+ HECEndpointType: Optional[str]
+ HECToken: Optional[str]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ HECAcknowledgmentTimeoutInSeconds: Optional[int]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RetryOptions: Optional[SplunkRetryOptions]
+ S3BackupMode: Optional[str]
+
+
+class HttpEndpointConfiguration(TypedDict):
+ Url: Optional[str]
+ AccessKey: Optional[str]
+ Name: Optional[str]
+
+
+class HttpEndpointCommonAttribute(TypedDict):
+ AttributeName: Optional[str]
+ AttributeValue: Optional[str]
+
+
+class HttpEndpointRequestConfiguration(TypedDict):
+ CommonAttributes: Optional[list[HttpEndpointCommonAttribute]]
+ ContentEncoding: Optional[str]
+
+
+class HttpEndpointDestinationConfiguration(TypedDict):
+ EndpointConfiguration: Optional[HttpEndpointConfiguration]
+ S3Configuration: Optional[S3DestinationConfiguration]
+ BufferingHints: Optional[BufferingHints]
+ CloudWatchLoggingOptions: Optional[CloudWatchLoggingOptions]
+ ProcessingConfiguration: Optional[ProcessingConfiguration]
+ RequestConfiguration: Optional[HttpEndpointRequestConfiguration]
+ RetryOptions: Optional[RetryOptions]
+ RoleARN: Optional[str]
+ S3BackupMode: Optional[str]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class KinesisFirehoseDeliveryStreamProvider(
+ ResourceProvider[KinesisFirehoseDeliveryStreamProperties]
+):
+ TYPE = "AWS::KinesisFirehose::DeliveryStream" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[KinesisFirehoseDeliveryStreamProperties],
+ ) -> ProgressEvent[KinesisFirehoseDeliveryStreamProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/DeliveryStreamName
+
+ Create-only properties:
+ - /properties/DeliveryStreamName
+ - /properties/DeliveryStreamType
+ - /properties/ElasticsearchDestinationConfiguration/VpcConfiguration
+ - /properties/AmazonopensearchserviceDestinationConfiguration/VpcConfiguration
+ - /properties/AmazonOpenSearchServerlessDestinationConfiguration/VpcConfiguration
+ - /properties/KinesisStreamSourceConfiguration
+
+ Read-only properties:
+ - /properties/Arn
+
+ IAM permissions required:
+ - firehose:CreateDeliveryStream
+ - firehose:DescribeDeliveryStream
+ - iam:GetRole
+ - iam:PassRole
+ - kms:CreateGrant
+ - kms:DescribeKey
+
+ """
+ model = request.desired_state
+ firehose = request.aws_client_factory.firehose
+ parameters = [
+ "DeliveryStreamName",
+ "DeliveryStreamType",
+ "S3DestinationConfiguration",
+ "ElasticsearchDestinationConfiguration",
+ "AmazonopensearchserviceDestinationConfiguration",
+ "DeliveryStreamEncryptionConfigurationInput",
+ "ExtendedS3DestinationConfiguration",
+ "HttpEndpointDestinationConfiguration",
+ "KinesisStreamSourceConfiguration",
+ "RedshiftDestinationConfiguration",
+ "SplunkDestinationConfiguration",
+ "Tags",
+ ]
+ attrs = util.select_attributes(model, params=parameters)
+ if not attrs.get("DeliveryStreamName"):
+ attrs["DeliveryStreamName"] = util.generate_default_name(
+ request.stack_name, request.logical_resource_id
+ )
+
+ if not request.custom_context.get(REPEATED_INVOCATION):
+ response = firehose.create_delivery_stream(**attrs)
+ # TODO: defaults
+ # TODO: idempotency
+ model["Arn"] = response["DeliveryStreamARN"]
+ request.custom_context[REPEATED_INVOCATION] = True
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+ # TODO add handler for CREATE FAILED state
+ stream = firehose.describe_delivery_stream(DeliveryStreamName=model["DeliveryStreamName"])
+ if stream["DeliveryStreamDescription"]["DeliveryStreamStatus"] != "ACTIVE":
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
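+ # Informal note on the create() flow above (derived from the handler logic, not from AWS docs): the
+ # resource provider framework keeps re-invoking create() with the same custom_context while
+ # IN_PROGRESS is returned. The first invocation issues CreateDeliveryStream and sets the
+ # REPEATED_INVOCATION flag; subsequent invocations only poll DescribeDeliveryStream until the
+ # stream reports ACTIVE, at which point SUCCESS is returned.
+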
+ def read(
+ self,
+ request: ResourceRequest[KinesisFirehoseDeliveryStreamProperties],
+ ) -> ProgressEvent[KinesisFirehoseDeliveryStreamProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - firehose:DescribeDeliveryStream
+ - firehose:ListTagsForDeliveryStream
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[KinesisFirehoseDeliveryStreamProperties],
+ ) -> ProgressEvent[KinesisFirehoseDeliveryStreamProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - firehose:DeleteDeliveryStream
+ - firehose:DescribeDeliveryStream
+ - kms:RevokeGrant
+ - kms:DescribeKey
+ """
+ model = request.desired_state
+ firehose = request.aws_client_factory.firehose
+ try:
+ stream = firehose.describe_delivery_stream(
+ DeliveryStreamName=model["DeliveryStreamName"]
+ )
+ except request.aws_client_factory.firehose.exceptions.ResourceNotFoundException:
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ if stream["DeliveryStreamDescription"]["DeliveryStreamStatus"] != "DELETING":
+ firehose.delete_delivery_stream(DeliveryStreamName=model["DeliveryStreamName"])
+ return ProgressEvent(
+ status=OperationStatus.IN_PROGRESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[KinesisFirehoseDeliveryStreamProperties],
+ ) -> ProgressEvent[KinesisFirehoseDeliveryStreamProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - firehose:UpdateDestination
+ - firehose:DescribeDeliveryStream
+ - firehose:StartDeliveryStreamEncryption
+ - firehose:StopDeliveryStreamEncryption
+ - firehose:ListTagsForDeliveryStream
+ - firehose:TagDeliveryStream
+ - firehose:UntagDeliveryStream
+ - kms:CreateGrant
+ - kms:RevokeGrant
+ - kms:DescribeKey
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.schema.json b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.schema.json
new file mode 100644
index 0000000000000..939b5c7bd35d2
--- /dev/null
+++ b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream.schema.json
@@ -0,0 +1,1205 @@
+{
+ "typeName": "AWS::KinesisFirehose::DeliveryStream",
+ "description": "Resource Type definition for AWS::KinesisFirehose::DeliveryStream",
+ "additionalProperties": false,
+ "properties": {
+ "Arn": {
+ "type": "string"
+ },
+ "DeliveryStreamEncryptionConfigurationInput": {
+ "$ref": "#/definitions/DeliveryStreamEncryptionConfigurationInput"
+ },
+ "DeliveryStreamName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 64,
+ "pattern": "[a-zA-Z0-9._-]+"
+ },
+ "DeliveryStreamType": {
+ "type": "string",
+ "enum": [
+ "DirectPut",
+ "KinesisStreamAsSource"
+ ]
+ },
+ "ElasticsearchDestinationConfiguration": {
+ "$ref": "#/definitions/ElasticsearchDestinationConfiguration"
+ },
+ "AmazonopensearchserviceDestinationConfiguration": {
+ "$ref": "#/definitions/AmazonopensearchserviceDestinationConfiguration"
+ },
+ "AmazonOpenSearchServerlessDestinationConfiguration": {
+ "$ref": "#/definitions/AmazonOpenSearchServerlessDestinationConfiguration"
+ },
+ "ExtendedS3DestinationConfiguration": {
+ "$ref": "#/definitions/ExtendedS3DestinationConfiguration"
+ },
+ "KinesisStreamSourceConfiguration": {
+ "$ref": "#/definitions/KinesisStreamSourceConfiguration"
+ },
+ "RedshiftDestinationConfiguration": {
+ "$ref": "#/definitions/RedshiftDestinationConfiguration"
+ },
+ "S3DestinationConfiguration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "SplunkDestinationConfiguration": {
+ "$ref": "#/definitions/SplunkDestinationConfiguration"
+ },
+ "HttpEndpointDestinationConfiguration": {
+ "$ref": "#/definitions/HttpEndpointDestinationConfiguration"
+ },
+ "Tags": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/Tag"
+ },
+ "minItems": 1,
+ "maxItems": 50
+ }
+ },
+ "definitions": {
+ "DeliveryStreamEncryptionConfigurationInput": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "KeyARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "KeyType": {
+ "type": "string",
+ "enum": [
+ "AWS_OWNED_CMK",
+ "CUSTOMER_MANAGED_CMK"
+ ]
+ }
+ },
+ "required": [
+ "KeyType"
+ ]
+ },
+ "SplunkDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "HECAcknowledgmentTimeoutInSeconds": {
+ "type": "integer",
+ "minimum": 180,
+ "maximum": 600
+ },
+ "HECEndpoint": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 2048
+ },
+ "HECEndpointType": {
+ "type": "string",
+ "enum": [
+ "Raw",
+ "Event"
+ ]
+ },
+ "HECToken": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 2048
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/SplunkRetryOptions"
+ },
+ "S3BackupMode": {
+ "type": "string"
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ }
+ },
+ "required": [
+ "HECEndpoint",
+ "S3Configuration",
+ "HECToken",
+ "HECEndpointType"
+ ]
+ },
+ "HttpEndpointDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "EndpointConfiguration": {
+ "$ref": "#/definitions/HttpEndpointConfiguration"
+ },
+ "RequestConfiguration": {
+ "$ref": "#/definitions/HttpEndpointRequestConfiguration"
+ },
+ "BufferingHints": {
+ "$ref": "#/definitions/BufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/RetryOptions"
+ },
+ "S3BackupMode": {
+ "type": "string"
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ }
+ },
+ "required": [
+ "EndpointConfiguration",
+ "S3Configuration"
+ ]
+ },
+ "KinesisStreamSourceConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "KinesisStreamARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ }
+ },
+ "required": [
+ "RoleARN",
+ "KinesisStreamARN"
+ ]
+ },
+ "VpcConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "SubnetIds": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 1024
+ },
+ "minItems": 1,
+ "maxItems": 16
+ },
+ "SecurityGroupIds": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 1024
+ },
+ "minItems": 1,
+ "maxItems": 5
+ }
+ },
+ "required": [
+ "RoleARN",
+ "SubnetIds",
+ "SecurityGroupIds"
+ ]
+ },
+ "DocumentIdOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DefaultDocumentIdFormat": {
+ "type": "string",
+ "enum": [
+ "FIREHOSE_DEFAULT",
+ "NO_DOCUMENT_ID"
+ ]
+ }
+ },
+ "required": [
+ "DefaultDocumentIdFormat"
+ ]
+ },
+ "ExtendedS3DestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BucketARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 2048,
+ "pattern": "arn:.*"
+ },
+ "BufferingHints": {
+ "$ref": "#/definitions/BufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "CompressionFormat": {
+ "type": "string",
+ "enum": [
+ "UNCOMPRESSED",
+ "GZIP",
+ "ZIP",
+ "Snappy",
+ "HADOOP_SNAPPY"
+ ]
+ },
+ "DataFormatConversionConfiguration": {
+ "$ref": "#/definitions/DataFormatConversionConfiguration"
+ },
+ "DynamicPartitioningConfiguration": {
+ "$ref": "#/definitions/DynamicPartitioningConfiguration"
+ },
+ "EncryptionConfiguration": {
+ "$ref": "#/definitions/EncryptionConfiguration"
+ },
+ "ErrorOutputPrefix": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 1024
+ },
+ "Prefix": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 1024
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "S3BackupConfiguration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "S3BackupMode": {
+ "type": "string",
+ "enum": [
+ "Disabled",
+ "Enabled"
+ ]
+ }
+ },
+ "required": [
+ "BucketARN",
+ "RoleARN"
+ ]
+ },
+ "S3DestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BucketARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 2048,
+ "pattern": "arn:.*"
+ },
+ "BufferingHints": {
+ "$ref": "#/definitions/BufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "CompressionFormat": {
+ "type": "string",
+ "enum": [
+ "UNCOMPRESSED",
+ "GZIP",
+ "ZIP",
+ "Snappy",
+ "HADOOP_SNAPPY"
+ ]
+ },
+ "EncryptionConfiguration": {
+ "$ref": "#/definitions/EncryptionConfiguration"
+ },
+ "ErrorOutputPrefix": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 1024
+ },
+ "Prefix": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 1024
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ }
+ },
+ "required": [
+ "BucketARN",
+ "RoleARN"
+ ]
+ },
+ "RedshiftDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "ClusterJDBCURL": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512
+ },
+ "CopyCommand": {
+ "$ref": "#/definitions/CopyCommand"
+ },
+ "Password": {
+ "type": "string",
+ "minLength": 6,
+ "maxLength": 512
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/RedshiftRetryOptions"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "S3BackupConfiguration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "S3BackupMode": {
+ "type": "string",
+ "enum": [
+ "Disabled",
+ "Enabled"
+ ]
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "Username": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512
+ }
+ },
+ "required": [
+ "S3Configuration",
+ "Username",
+ "ClusterJDBCURL",
+ "CopyCommand",
+ "RoleARN",
+ "Password"
+ ]
+ },
+ "ElasticsearchDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BufferingHints": {
+ "$ref": "#/definitions/ElasticsearchBufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "DomainARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "IndexName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 80
+ },
+ "IndexRotationPeriod": {
+ "type": "string",
+ "enum": [
+ "NoRotation",
+ "OneHour",
+ "OneDay",
+ "OneWeek",
+ "OneMonth"
+ ]
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/ElasticsearchRetryOptions"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "S3BackupMode": {
+ "type": "string",
+ "enum": [
+ "FailedDocumentsOnly",
+ "AllDocuments"
+ ]
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "ClusterEndpoint": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "https:.*"
+ },
+ "TypeName": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 100
+ },
+ "VpcConfiguration": {
+ "$ref": "#/definitions/VpcConfiguration"
+ },
+ "DocumentIdOptions": {
+ "$ref": "#/definitions/DocumentIdOptions"
+ }
+ },
+ "required": [
+ "IndexName",
+ "S3Configuration",
+ "RoleARN"
+ ]
+ },
+ "AmazonopensearchserviceDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BufferingHints": {
+ "$ref": "#/definitions/AmazonopensearchserviceBufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "DomainARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "IndexName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 80
+ },
+ "IndexRotationPeriod": {
+ "type": "string",
+ "enum": [
+ "NoRotation",
+ "OneHour",
+ "OneDay",
+ "OneWeek",
+ "OneMonth"
+ ]
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/AmazonopensearchserviceRetryOptions"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "S3BackupMode": {
+ "type": "string",
+ "enum": [
+ "FailedDocumentsOnly",
+ "AllDocuments"
+ ]
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "ClusterEndpoint": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "https:.*"
+ },
+ "TypeName": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 100
+ },
+ "VpcConfiguration": {
+ "$ref": "#/definitions/VpcConfiguration"
+ },
+ "DocumentIdOptions": {
+ "$ref": "#/definitions/DocumentIdOptions"
+ }
+ },
+ "required": [
+ "IndexName",
+ "S3Configuration",
+ "RoleARN"
+ ]
+ },
+ "AmazonOpenSearchServerlessDestinationConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BufferingHints": {
+ "$ref": "#/definitions/AmazonOpenSearchServerlessBufferingHints"
+ },
+ "CloudWatchLoggingOptions": {
+ "$ref": "#/definitions/CloudWatchLoggingOptions"
+ },
+ "IndexName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 80
+ },
+ "ProcessingConfiguration": {
+ "$ref": "#/definitions/ProcessingConfiguration"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/AmazonOpenSearchServerlessRetryOptions"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "S3BackupMode": {
+ "type": "string",
+ "enum": [
+ "FailedDocumentsOnly",
+ "AllDocuments"
+ ]
+ },
+ "S3Configuration": {
+ "$ref": "#/definitions/S3DestinationConfiguration"
+ },
+ "CollectionEndpoint": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "https:.*"
+ },
+ "VpcConfiguration": {
+ "$ref": "#/definitions/VpcConfiguration"
+ }
+ },
+ "required": [
+ "IndexName",
+ "S3Configuration",
+ "RoleARN"
+ ]
+ },
+ "BufferingHints": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IntervalInSeconds": {
+ "type": "integer"
+ },
+ "SizeInMBs": {
+ "type": "integer"
+ }
+ }
+ },
+ "ProcessingConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ },
+ "Processors": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/Processor"
+ }
+ }
+ }
+ },
+ "SplunkRetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "ElasticsearchRetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "AmazonopensearchserviceRetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "AmazonOpenSearchServerlessRetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "RedshiftRetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "RetryOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "DurationInSeconds": {
+ "type": "integer"
+ }
+ }
+ },
+ "DataFormatConversionConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ },
+ "InputFormatConfiguration": {
+ "$ref": "#/definitions/InputFormatConfiguration"
+ },
+ "OutputFormatConfiguration": {
+ "$ref": "#/definitions/OutputFormatConfiguration"
+ },
+ "SchemaConfiguration": {
+ "$ref": "#/definitions/SchemaConfiguration"
+ }
+ }
+ },
+ "DynamicPartitioningConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ },
+ "RetryOptions": {
+ "$ref": "#/definitions/RetryOptions"
+ }
+ }
+ },
+ "CopyCommand": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CopyOptions": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 204800
+ },
+ "DataTableColumns": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 204800
+ },
+ "DataTableName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512
+ }
+ },
+ "required": [
+ "DataTableName"
+ ]
+ },
+ "EncryptionConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "KMSEncryptionConfig": {
+ "$ref": "#/definitions/KMSEncryptionConfig"
+ },
+ "NoEncryptionConfig": {
+ "type": "string",
+ "enum": [
+ "NoEncryption"
+ ]
+ }
+ }
+ },
+ "ElasticsearchBufferingHints": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IntervalInSeconds": {
+ "type": "integer"
+ },
+ "SizeInMBs": {
+ "type": "integer"
+ }
+ }
+ },
+ "AmazonopensearchserviceBufferingHints": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IntervalInSeconds": {
+ "type": "integer"
+ },
+ "SizeInMBs": {
+ "type": "integer"
+ }
+ }
+ },
+ "AmazonOpenSearchServerlessBufferingHints": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "IntervalInSeconds": {
+ "type": "integer"
+ },
+ "SizeInMBs": {
+ "type": "integer"
+ }
+ }
+ },
+ "CloudWatchLoggingOptions": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Enabled": {
+ "type": "boolean"
+ },
+ "LogGroupName": {
+ "type": "string"
+ },
+ "LogStreamName": {
+ "type": "string"
+ }
+ }
+ },
+ "OutputFormatConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Serializer": {
+ "$ref": "#/definitions/Serializer"
+ }
+ }
+ },
+ "Processor": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Parameters": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/ProcessorParameter"
+ }
+ },
+ "Type": {
+ "type": "string",
+ "enum": [
+ "RecordDeAggregation",
+ "Lambda",
+ "MetadataExtraction",
+ "AppendDelimiterToRecord"
+ ]
+ }
+ },
+ "required": [
+ "Type"
+ ]
+ },
+ "KMSEncryptionConfig": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AWSKMSKeyARN": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "AWSKMSKeyARN"
+ ]
+ },
+ "InputFormatConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Deserializer": {
+ "$ref": "#/definitions/Deserializer"
+ }
+ }
+ },
+ "SchemaConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CatalogId": {
+ "type": "string"
+ },
+ "DatabaseName": {
+ "type": "string"
+ },
+ "Region": {
+ "type": "string"
+ },
+ "RoleARN": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 512,
+ "pattern": "arn:.*"
+ },
+ "TableName": {
+ "type": "string"
+ },
+ "VersionId": {
+ "type": "string"
+ }
+ }
+ },
+ "Serializer": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "OrcSerDe": {
+ "$ref": "#/definitions/OrcSerDe"
+ },
+ "ParquetSerDe": {
+ "$ref": "#/definitions/ParquetSerDe"
+ }
+ }
+ },
+ "ProcessorParameter": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ParameterName": {
+ "type": "string"
+ },
+ "ParameterValue": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "ParameterValue",
+ "ParameterName"
+ ]
+ },
+ "Deserializer": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "HiveJsonSerDe": {
+ "$ref": "#/definitions/HiveJsonSerDe"
+ },
+ "OpenXJsonSerDe": {
+ "$ref": "#/definitions/OpenXJsonSerDe"
+ }
+ }
+ },
+ "HiveJsonSerDe": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "TimestampFormats": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+ },
+ "OrcSerDe": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BlockSizeBytes": {
+ "type": "integer"
+ },
+ "BloomFilterColumns": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "type": "string"
+ }
+ },
+ "BloomFilterFalsePositiveProbability": {
+ "type": "number"
+ },
+ "Compression": {
+ "type": "string"
+ },
+ "DictionaryKeyThreshold": {
+ "type": "number"
+ },
+ "EnablePadding": {
+ "type": "boolean"
+ },
+ "FormatVersion": {
+ "type": "string"
+ },
+ "PaddingTolerance": {
+ "type": "number"
+ },
+ "RowIndexStride": {
+ "type": "integer"
+ },
+ "StripeSizeBytes": {
+ "type": "integer"
+ }
+ }
+ },
+ "ParquetSerDe": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "BlockSizeBytes": {
+ "type": "integer"
+ },
+ "Compression": {
+ "type": "string"
+ },
+ "EnableDictionaryCompression": {
+ "type": "boolean"
+ },
+ "MaxPaddingBytes": {
+ "type": "integer"
+ },
+ "PageSizeBytes": {
+ "type": "integer"
+ },
+ "WriterVersion": {
+ "type": "string"
+ }
+ }
+ },
+ "OpenXJsonSerDe": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "CaseInsensitive": {
+ "type": "boolean"
+ },
+ "ColumnToJsonKeyMappings": {
+ "type": "object",
+ "patternProperties": {
+ "[a-zA-Z0-9]+": {
+ "type": "string"
+ }
+ }
+ },
+ "ConvertDotsInJsonKeysToUnderscores": {
+ "type": "boolean"
+ }
+ }
+ },
+ "HttpEndpointRequestConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "ContentEncoding": {
+ "type": "string",
+ "enum": [
+ "NONE",
+ "GZIP"
+ ]
+ },
+ "CommonAttributes": {
+ "type": "array",
+ "uniqueItems": true,
+ "items": {
+ "$ref": "#/definitions/HttpEndpointCommonAttribute"
+ },
+ "minItems": 0,
+ "maxItems": 50
+ }
+ }
+ },
+ "HttpEndpointCommonAttribute": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "AttributeName": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ },
+ "AttributeValue": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 1024
+ }
+ },
+ "required": [
+ "AttributeName",
+ "AttributeValue"
+ ]
+ },
+ "HttpEndpointConfiguration": {
+ "type": "object",
+ "additionalProperties": false,
+ "properties": {
+ "Url": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 1000
+ },
+ "AccessKey": {
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 4096
+ },
+ "Name": {
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Url"
+ ]
+ },
+ "Tag": {
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "pattern": "^(?!aws:)[\\p{L}\\p{Z}\\p{N}_.:\\/=+\\-@%]*$",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "pattern": "^[\\p{L}\\p{Z}\\p{N}_.:\\/=+\\-@%]*$",
+ "minLength": 0,
+ "maxLength": 256
+ }
+ },
+ "required": [
+ "Key"
+ ]
+ }
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "firehose:CreateDeliveryStream",
+ "firehose:DescribeDeliveryStream",
+ "iam:GetRole",
+ "iam:PassRole",
+ "kms:CreateGrant",
+ "kms:DescribeKey"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "firehose:DescribeDeliveryStream",
+ "firehose:ListTagsForDeliveryStream"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "firehose:UpdateDestination",
+ "firehose:DescribeDeliveryStream",
+ "firehose:StartDeliveryStreamEncryption",
+ "firehose:StopDeliveryStreamEncryption",
+ "firehose:ListTagsForDeliveryStream",
+ "firehose:TagDeliveryStream",
+ "firehose:UntagDeliveryStream",
+ "kms:CreateGrant",
+ "kms:RevokeGrant",
+ "kms:DescribeKey"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "firehose:DeleteDeliveryStream",
+ "firehose:DescribeDeliveryStream",
+ "kms:RevokeGrant",
+ "kms:DescribeKey"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "firehose:ListDeliveryStreams"
+ ]
+ }
+ },
+ "readOnlyProperties": [
+ "/properties/Arn"
+ ],
+ "createOnlyProperties": [
+ "/properties/DeliveryStreamName",
+ "/properties/DeliveryStreamType",
+ "/properties/ElasticsearchDestinationConfiguration/VpcConfiguration",
+ "/properties/AmazonopensearchserviceDestinationConfiguration/VpcConfiguration",
+ "/properties/AmazonOpenSearchServerlessDestinationConfiguration/VpcConfiguration",
+ "/properties/KinesisStreamSourceConfiguration"
+ ],
+ "primaryIdentifier": [
+ "/properties/DeliveryStreamName"
+ ]
+}
diff --git a/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream_plugin.py b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream_plugin.py
new file mode 100644
index 0000000000000..772007e6ce18d
--- /dev/null
+++ b/localstack-core/localstack/services/kinesisfirehose/resource_providers/aws_kinesisfirehose_deliverystream_plugin.py
@@ -0,0 +1,20 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class KinesisFirehoseDeliveryStreamProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::KinesisFirehose::DeliveryStream"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.kinesisfirehose.resource_providers.aws_kinesisfirehose_deliverystream import (
+ KinesisFirehoseDeliveryStreamProvider,
+ )
+
+ self.factory = KinesisFirehoseDeliveryStreamProvider
diff --git a/localstack-core/localstack/services/kms/__init__.py b/localstack-core/localstack/services/kms/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kms/exceptions.py b/localstack-core/localstack/services/kms/exceptions.py
new file mode 100644
index 0000000000000..ad157c5d85c4a
--- /dev/null
+++ b/localstack-core/localstack/services/kms/exceptions.py
@@ -0,0 +1,16 @@
+from localstack.aws.api import CommonServiceException
+
+
+class ValidationException(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__("ValidationException", message, 400, True)
+
+
+class AccessDeniedException(CommonServiceException):
+ def __init__(self, message: str):
+ super().__init__("AccessDeniedException", message, 400, True)
+
+
+class TagException(CommonServiceException):
+ def __init__(self, message=None):
+ super().__init__("TagException", status_code=400, message=message)
diff --git a/localstack-core/localstack/services/kms/models.py b/localstack-core/localstack/services/kms/models.py
new file mode 100644
index 0000000000000..e39f435f77660
--- /dev/null
+++ b/localstack-core/localstack/services/kms/models.py
@@ -0,0 +1,788 @@
+import base64
+import datetime
+import io
+import json
+import logging
+import os
+import random
+import re
+import struct
+import uuid
+from collections import namedtuple
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, hmac
+from cryptography.hazmat.primitives import serialization as crypto_serialization
+from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa, utils
+from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
+from cryptography.hazmat.primitives.asymmetric.padding import PSS, PKCS1v15
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from cryptography.hazmat.primitives.serialization import load_der_public_key
+
+from localstack.aws.api.kms import (
+ CreateAliasRequest,
+ CreateGrantRequest,
+ CreateKeyRequest,
+ EncryptionContextType,
+ InvalidKeyUsageException,
+ KeyMetadata,
+ KeySpec,
+ KeyState,
+ KeyUsageType,
+ KMSInvalidMacException,
+ KMSInvalidSignatureException,
+ MacAlgorithmSpec,
+ MessageType,
+ MultiRegionConfiguration,
+ MultiRegionKey,
+ MultiRegionKeyType,
+ OriginType,
+ ReplicateKeyRequest,
+ SigningAlgorithmSpec,
+ TagList,
+ UnsupportedOperationException,
+)
+from localstack.constants import TAG_KEY_CUSTOM_ID
+from localstack.services.kms.exceptions import TagException, ValidationException
+from localstack.services.kms.utils import is_valid_key_arn, validate_tag
+from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute
+from localstack.utils.aws.arns import get_partition, kms_alias_arn, kms_key_arn
+from localstack.utils.crypto import decrypt, encrypt
+from localstack.utils.strings import long_uid, to_bytes, to_str
+
+LOG = logging.getLogger(__name__)
+
+PATTERN_UUID = re.compile(
+ r"^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
+)
+MULTI_REGION_PATTERN = re.compile(r"^mrk-[a-fA-F0-9]{32}$")
+
+SYMMETRIC_DEFAULT_MATERIAL_LENGTH = 32
+
+RSA_CRYPTO_KEY_LENGTHS = {
+ "RSA_2048": 2048,
+ "RSA_3072": 3072,
+ "RSA_4096": 4096,
+}
+
+ECC_CURVES = {
+ "ECC_NIST_P256": ec.SECP256R1(),
+ "ECC_NIST_P384": ec.SECP384R1(),
+ "ECC_NIST_P521": ec.SECP521R1(),
+ "ECC_SECG_P256K1": ec.SECP256K1(),
+}
+
+HMAC_RANGE_KEY_LENGTHS = {
+ "HMAC_224": (28, 64),
+ "HMAC_256": (32, 64),
+ "HMAC_384": (48, 128),
+ "HMAC_512": (64, 128),
+}
+
+KEY_ID_LEN = 36
+# Moto uses IV_LEN of 12, as it is fine for GCM encryption mode, but we use CBC, so have to set it to 16.
+IV_LEN = 16
+TAG_LEN = 16
+CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format(
+ key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN
+)
+HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN
+Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag"))
+
+RESERVED_ALIASES = [
+ "alias/aws/acm",
+ "alias/aws/dynamodb",
+ "alias/aws/ebs",
+ "alias/aws/elasticfilesystem",
+ "alias/aws/es",
+ "alias/aws/glue",
+ "alias/aws/kinesisvideo",
+ "alias/aws/lambda",
+ "alias/aws/rds",
+ "alias/aws/redshift",
+ "alias/aws/s3",
+ "alias/aws/secretsmanager",
+ "alias/aws/ssm",
+ "alias/aws/xray",
+]
+
+# list of key names that should be skipped when serializing the encryption context
+IGNORED_CONTEXT_KEYS = ["aws-crypto-public-key"]
+
+# special tag name to allow specifying a custom key material for created keys
+TAG_KEY_CUSTOM_KEY_MATERIAL = "_custom_key_material_"
+
+
+def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
+ header = struct.pack(
+ CIPHERTEXT_HEADER_FORMAT,
+ ciphertext.key_id.encode("utf-8"),
+ ciphertext.iv,
+ ciphertext.tag,
+ )
+ return header + ciphertext.ciphertext
+
+
+def deserialize_ciphertext_blob(ciphertext_blob: bytes) -> Ciphertext:
+ header = ciphertext_blob[:HEADER_LEN]
+ ciphertext = ciphertext_blob[HEADER_LEN:]
+ key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header)
+ return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag)
+
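+# Illustrative round trip (a sketch, not used by the service code itself): the blob is just the
+# fixed-size header followed by the raw ciphertext, so serialization and deserialization are inverses:
+#
+#   blob = _serialize_ciphertext_blob(
+#       Ciphertext(key_id="k" * KEY_ID_LEN, iv=b"\x00" * IV_LEN, ciphertext=b"payload", tag=b"\x01" * TAG_LEN)
+#   )
+#   assert deserialize_ciphertext_blob(blob).ciphertext == b"payload"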
+
+def _serialize_encryption_context(encryption_context: Optional[EncryptionContextType]) -> bytes:
+ if encryption_context:
+ aad = io.BytesIO()
+ for key, value in sorted(encryption_context.items(), key=lambda x: x[0]):
+ # remove the reserved key-value pair from additional authentication data
+ if key not in IGNORED_CONTEXT_KEYS:
+ aad.write(key.encode("utf-8"))
+ aad.write(value.encode("utf-8"))
+ return aad.getvalue()
+ else:
+ return b""
+
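+# For illustration only (not called anywhere): sorting makes the additional authenticated data
+# deterministic, and the reserved key is dropped, e.g.
+#   _serialize_encryption_context({"b": "2", "a": "1", "aws-crypto-public-key": "x"}) == b"a1b2"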
+
+# Confusion alert!
+# In KMS, there are two things that can be called "keys":
+# 1. A cryptographic key, i.e. the raw key material - a private/public/symmetric key used for cryptographic
+# encryption, decryption, signing etc. It is modeled here by the KmsCryptoKey class.
+# 2. An AWS object that stores both a cryptographic key and some relevant metadata, e.g. creation time, a unique ID,
+# some state. It is modeled by the KmsKey class.
+#
+# While KmsKeys always contain KmsCryptoKeys, sometimes KmsCryptoKeys exist without corresponding KmsKeys,
+# e.g. GenerateDataKeyPair API call returns contents of a new KmsCryptoKey that is not associated with any KmsKey,
+# but is partially encrypted by some pre-existing KmsKey.
+
+
+class KmsCryptoKey:
+ """
+ KmsCryptoKeys are used to model the two cases where AWS generates keys:
+ 1. Keys that are created to be used inside of AWS. For such a key, its key material / private key are not to
+ leave AWS unencrypted. If they have to leave AWS, a different KmsCryptoKey is used to encrypt the data first.
+ 2. Keys that AWS creates for customers for some external use. Such a key might be returned to a customer with its
+ key material or public key unencrypted - see KMS GenerateDataKey / GenerateDataKeyPair. But such a key is not stored
+ by AWS and is not used by AWS.
+ """
+
+ public_key: Optional[bytes]
+ private_key: Optional[bytes]
+ key_material: bytes
+ key_spec: str
+
+ def __init__(self, key_spec: str, key_material: Optional[bytes] = None):
+ self.private_key = None
+ self.public_key = None
+ # Technically, key_material, being a symmetric encryption key, is only relevant for
+ # key_spec == SYMMETRIC_DEFAULT.
+ # But LocalStack uses symmetric encryption with this key_material even for other specs. Asymmetric keys are
+ # generated, but are not actually used for encryption. Signing, however, does use the generated asymmetric keys.
+ self.key_material = key_material or os.urandom(SYMMETRIC_DEFAULT_MATERIAL_LENGTH)
+ self.key_spec = key_spec
+
+ if key_spec == "SYMMETRIC_DEFAULT":
+ return
+
+ if key_spec.startswith("RSA"):
+ key_size = RSA_CRYPTO_KEY_LENGTHS.get(key_spec)
+ key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
+ elif key_spec.startswith("ECC"):
+ curve = ECC_CURVES.get(key_spec)
+ key = ec.generate_private_key(curve)
+ elif key_spec.startswith("HMAC"):
+ if key_spec not in HMAC_RANGE_KEY_LENGTHS:
+ raise ValidationException(
+ f"1 validation error detected: Value '{key_spec}' at 'keySpec' "
+ f"failed to satisfy constraint: Member must satisfy enum value set: "
+ f"[RSA_2048, ECC_NIST_P384, ECC_NIST_P256, ECC_NIST_P521, HMAC_384, RSA_3072, "
+ f"ECC_SECG_P256K1, RSA_4096, SYMMETRIC_DEFAULT, HMAC_256, HMAC_224, HMAC_512]"
+ )
+ minimum_length, maximum_length = HMAC_RANGE_KEY_LENGTHS.get(key_spec)
+ self.key_material = key_material or os.urandom(
+ random.randint(minimum_length, maximum_length)
+ )
+ return
+ else:
+ # We do not support SM2 - asymmetric keys both suitable for ENCRYPT_DECRYPT and SIGN_VERIFY,
+ # but only used in China AWS regions.
+ raise UnsupportedOperationException(f"KeySpec {key_spec} is not supported")
+
+ self._serialize_key(key)
+
+ def load_key_material(self, material: bytes):
+ if self.key_spec == "SYMMETRIC_DEFAULT":
+ self.key_material = material
+ else:
+ key = crypto_serialization.load_der_private_key(material, password=None)
+ self._serialize_key(key)
+
+ def _serialize_key(self, key: ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey):
+ self.public_key = key.public_key().public_bytes(
+ crypto_serialization.Encoding.DER,
+ crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+ self.private_key = key.private_bytes(
+ crypto_serialization.Encoding.DER,
+ crypto_serialization.PrivateFormat.PKCS8,
+ crypto_serialization.NoEncryption(),
+ )
+
+ @property
+ def key(self) -> RSAPrivateKey | EllipticCurvePrivateKey:
+ return crypto_serialization.load_der_private_key(
+ self.private_key,
+ password=None,
+ backend=default_backend(),
+ )
+
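+# A minimal usage sketch of KmsCryptoKey (illustrative, not part of the provider flow):
+#   crypto_key = KmsCryptoKey("ECC_NIST_P256")
+#   crypto_key.public_key   # DER-encoded SubjectPublicKeyInfo bytes
+#   crypto_key.key          # deserialized private key object, usable e.g. for signing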
+
+class KmsKey:
+ metadata: KeyMetadata
+ crypto_key: KmsCryptoKey
+ tags: Dict[str, str]
+ policy: str
+ is_key_rotation_enabled: bool
+
+ def __init__(
+ self,
+ create_key_request: CreateKeyRequest = None,
+ account_id: str = None,
+ region: str = None,
+ ):
+ create_key_request = create_key_request or CreateKeyRequest()
+
+ # Please keep in mind that tags of a key could be present in the request, but they are not a part of the metadata,
+ # at least in the sense of DescribeKey not returning them with the rest of the metadata. Instead, tags are more
+ # like aliases:
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html
+ # "DescribeKey does not return the following information: ... Tags on the KMS key."
+ self.tags = {}
+ self.add_tags(create_key_request.get("Tags"))
+ # Same goes for the policy. It is in the request, but not in the metadata.
+ self.policy = create_key_request.get("Policy") or self._get_default_key_policy(
+ account_id, region
+ )
+ # https://docs.aws.amazon.com/kms/latest/developerguide/rotate-keys.html
+ # "Automatic key rotation is disabled by default on customer managed keys but authorized users can enable and
+ # disable it."
+ self.is_key_rotation_enabled = False
+
+ self._populate_metadata(create_key_request, account_id, region)
+ custom_key_material = None
+ if TAG_KEY_CUSTOM_KEY_MATERIAL in self.tags:
+ # check if the _custom_key_material_ tag is specified, to use a custom key material for this key
+ custom_key_material = base64.b64decode(self.tags[TAG_KEY_CUSTOM_KEY_MATERIAL])
+ # remove the _custom_key_material_ tag from the tags to not readily expose the custom key material
+ del self.tags[TAG_KEY_CUSTOM_KEY_MATERIAL]
+ self.crypto_key = KmsCryptoKey(self.metadata.get("KeySpec"), custom_key_material)
+
+ def calculate_and_set_arn(self, account_id, region):
+ self.metadata["Arn"] = kms_key_arn(self.metadata.get("KeyId"), account_id, region)
+
+ def generate_mac(self, msg: bytes, mac_algorithm: MacAlgorithmSpec) -> bytes:
+ h = self._get_hmac_context(mac_algorithm)
+ h.update(msg)
+ return h.finalize()
+
+ def verify_mac(self, msg: bytes, mac: bytes, mac_algorithm: MacAlgorithmSpec) -> bool:
+ h = self._get_hmac_context(mac_algorithm)
+ h.update(msg)
+ try:
+ h.verify(mac)
+ return True
+ except InvalidSignature:
+ raise KMSInvalidMacException()
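+
+ # Illustrative sketch of the MAC helpers above (assumes a key created with KeySpec "HMAC_256"
+ # and KeyUsage "GENERATE_VERIFY_MAC"; not exercised by the provider itself):
+ #   mac = key.generate_mac(b"message", "HMAC_SHA_256")
+ #   assert key.verify_mac(b"message", mac, "HMAC_SHA_256")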
+
+ # Encrypt is a method of KmsKey and not of KmsCryptoKey only because it requires KeyId, and KmsCryptoKeys do not
+ # hold KeyIds. Maybe it would be possible to remodel this better.
+ def encrypt(self, plaintext: bytes, encryption_context: EncryptionContextType = None) -> bytes:
+ iv = os.urandom(IV_LEN)
+ aad = _serialize_encryption_context(encryption_context=encryption_context)
+ ciphertext, tag = encrypt(self.crypto_key.key_material, plaintext, iv, aad)
+ return _serialize_ciphertext_blob(
+ ciphertext=Ciphertext(
+ key_id=self.metadata.get("KeyId"), iv=iv, ciphertext=ciphertext, tag=tag
+ )
+ )
+
+ # The ciphertext has to be deserialized before this call.
+ def decrypt(
+ self, ciphertext: Ciphertext, encryption_context: EncryptionContextType = None
+ ) -> bytes:
+ aad = _serialize_encryption_context(encryption_context=encryption_context)
+ return decrypt(
+ self.crypto_key.key_material, ciphertext.ciphertext, ciphertext.iv, ciphertext.tag, aad
+ )
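+
+ # Illustrative round trip through encrypt()/decrypt() (a sketch, assuming a SYMMETRIC_DEFAULT key):
+ #   blob = key.encrypt(b"secret", {"purpose": "test"})
+ #   plaintext = key.decrypt(deserialize_ciphertext_blob(blob), {"purpose": "test"})
+ #   assert plaintext == b"secret"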
+
+ def decrypt_rsa(self, encrypted: bytes) -> bytes:
+ private_key = crypto_serialization.load_der_private_key(
+ self.crypto_key.private_key, password=None, backend=default_backend()
+ )
+ decrypted = private_key.decrypt(
+ encrypted,
+ padding.OAEP(
+ mgf=padding.MGF1(algorithm=hashes.SHA256()),
+ algorithm=hashes.SHA256(),
+ label=None,
+ ),
+ )
+ return decrypted
+
+ def sign(
+ self, data: bytes, message_type: MessageType, signing_algorithm: SigningAlgorithmSpec
+ ) -> bytes:
+ hasher, wrapped_hasher = self._construct_sign_verify_hasher(signing_algorithm, message_type)
+ try:
+ if signing_algorithm.startswith("ECDSA"):
+ return self.crypto_key.key.sign(data, ec.ECDSA(wrapped_hasher))
+ else:
+ padding = self._construct_sign_verify_padding(signing_algorithm, hasher)
+ return self.crypto_key.key.sign(data, padding, wrapped_hasher)
+ except ValueError as exc:
+ raise ValidationException(str(exc))
+
+ def verify(
+ self,
+ data: bytes,
+ message_type: MessageType,
+ signing_algorithm: SigningAlgorithmSpec,
+ signature: bytes,
+ ) -> bool:
+ hasher, wrapped_hasher = self._construct_sign_verify_hasher(signing_algorithm, message_type)
+ try:
+ if signing_algorithm.startswith("ECDSA"):
+ self.crypto_key.key.public_key().verify(signature, data, ec.ECDSA(wrapped_hasher))
+ else:
+ padding = self._construct_sign_verify_padding(signing_algorithm, hasher)
+ self.crypto_key.key.public_key().verify(signature, data, padding, wrapped_hasher)
+ return True
+ except ValueError as exc:
+ raise ValidationException(str(exc))
+ except InvalidSignature:
+ # AWS itself raises this exception without any additional message.
+ raise KMSInvalidSignatureException()
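+
+ # Illustrative sketch of sign()/verify() (assumes a key created with KeySpec "RSA_2048" and
+ # KeyUsage "SIGN_VERIFY"; not exercised by the provider itself):
+ #   signature = key.sign(b"payload", MessageType.RAW, "RSASSA_PSS_SHA_256")
+ #   assert key.verify(b"payload", MessageType.RAW, "RSASSA_PSS_SHA_256", signature)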
+
+ def derive_shared_secret(self, public_key: bytes) -> bytes:
+ key_spec = self.metadata.get("KeySpec")
+ match key_spec:
+ case KeySpec.ECC_NIST_P256 | KeySpec.ECC_SECG_P256K1:
+ algorithm = hashes.SHA256()
+ case KeySpec.ECC_NIST_P384:
+ algorithm = hashes.SHA384()
+ case KeySpec.ECC_NIST_P521:
+ algorithm = hashes.SHA512()
+ case _:
+ raise InvalidKeyUsageException(
+ f"{self.metadata['Arn']} key usage is {self.metadata['KeyUsage']} which is not valid for DeriveSharedSecret."
+ )
+
+ # Deserialize public key from DER encoded data to EllipticCurvePublicKey.
+ try:
+ pub_key = load_der_public_key(public_key)
+ except (UnsupportedAlgorithm, ValueError):
+ raise ValidationException("")
+ shared_secret = self.crypto_key.key.exchange(ec.ECDH(), pub_key)
+ # Perform shared secret derivation.
+ return HKDF(
+ algorithm=algorithm,
+ salt=None,
+ info=b"",
+ length=algorithm.digest_size,
+ backend=default_backend(),
+ ).derive(shared_secret)
+
+ # This method gets called when a key is replicated to another region. It's meant to populate the required metadata
+ # fields in a new replica key.
+ def replicate_metadata(
+ self, replicate_key_request: ReplicateKeyRequest, account_id: str, replica_region: str
+ ) -> None:
+ self.metadata["Description"] = replicate_key_request.get("Description") or ""
+ primary_key_arn = self.metadata["Arn"]
+ # Multi-region keys have the same key ID for all replicas, but their ARNs differ, as they include the actual
+ # regions of the replicas.
+ self.calculate_and_set_arn(account_id, replica_region)
+
+ current_replica_keys = self.metadata.get("MultiRegionConfiguration", {}).get(
+ "ReplicaKeys", []
+ )
+ current_replica_keys.append(MultiRegionKey(Arn=self.metadata["Arn"], Region=replica_region))
+ primary_key_region = (
+ self.metadata.get("MultiRegionConfiguration", {}).get("PrimaryKey", {}).get("Region")
+ )
+
+ self.metadata["MultiRegionConfiguration"] = MultiRegionConfiguration(
+ MultiRegionKeyType=MultiRegionKeyType.REPLICA,
+ PrimaryKey=MultiRegionKey(
+ Arn=primary_key_arn,
+ Region=primary_key_region,
+ ),
+ ReplicaKeys=current_replica_keys,
+ )
+
+ def _get_hmac_context(self, mac_algorithm: MacAlgorithmSpec) -> hmac.HMAC:
+ if mac_algorithm == "HMAC_SHA_224":
+ h = hmac.HMAC(self.crypto_key.key_material, hashes.SHA224())
+ elif mac_algorithm == "HMAC_SHA_256":
+ h = hmac.HMAC(self.crypto_key.key_material, hashes.SHA256())
+ elif mac_algorithm == "HMAC_SHA_384":
+ h = hmac.HMAC(self.crypto_key.key_material, hashes.SHA384())
+ elif mac_algorithm == "HMAC_SHA_512":
+ h = hmac.HMAC(self.crypto_key.key_material, hashes.SHA512())
+ else:
+ raise ValidationException(
+ f"1 validation error detected: Value '{mac_algorithm}' at 'macAlgorithm' "
+ f"failed to satisfy constraint: Member must satisfy enum value set: "
+ f"[HMAC_SHA_384, HMAC_SHA_256, HMAC_SHA_224, HMAC_SHA_512]"
+ )
+ return h
+
+ def _construct_sign_verify_hasher(
+ self, signing_algorithm: SigningAlgorithmSpec, message_type: MessageType
+ ) -> (
+ Prehashed | hashes.SHA256 | hashes.SHA384 | hashes.SHA512,
+ Prehashed | hashes.SHA256 | hashes.SHA384 | hashes.SHA512,
+ ):
+ if "SHA_256" in signing_algorithm:
+ hasher = hashes.SHA256()
+ elif "SHA_384" in signing_algorithm:
+ hasher = hashes.SHA384()
+ elif "SHA_512" in signing_algorithm:
+ hasher = hashes.SHA512()
+ else:
+ raise ValidationException(
+ f"Unsupported hash type in SigningAlgorithm '{signing_algorithm}'"
+ )
+
+ wrapped_hasher = hasher
+ if message_type == MessageType.DIGEST:
+ wrapped_hasher = utils.Prehashed(hasher)
+ return hasher, wrapped_hasher
+
+ def _construct_sign_verify_padding(
+ self,
+ signing_algorithm: SigningAlgorithmSpec,
+ hasher: Prehashed | hashes.SHA256 | hashes.SHA384 | hashes.SHA512,
+ ) -> PKCS1v15 | PSS:
+ if signing_algorithm.startswith("RSA"):
+ if "PKCS" in signing_algorithm:
+ return padding.PKCS1v15()
+ elif "PSS" in signing_algorithm:
+ return padding.PSS(mgf=padding.MGF1(hasher), salt_length=padding.PSS.MAX_LENGTH)
+ else:
+ LOG.warning("Unsupported padding in SigningAlgorithm '%s'", signing_algorithm)
+
+ # Not a comment, rather some possibly relevant links for the future.
+ # https://docs.aws.amazon.com/kms/latest/developerguide/asymm-create-key.html
+ # "You cannot create an elliptic curve key pair for encryption and decryption."
+ # https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#asymmetric-keys-concept
+ # "You can create asymmetric KMS keys that represent RSA key pairs for public key encryption or signing and
+ # verification, or elliptic curve key pairs for signing and verification."
+ #
+ # A useful link with a cheat-sheet of what operations are supported by what types of keys:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/symm-asymm-compare.html
+ #
+ # https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#data-keys
+ # "AWS KMS generates the data key. Then it encrypts a copy of the data key under a symmetric encryption KMS key that
+ # you specify."
+ #
+ # Data keys are symmetric, data key pairs are asymmetric.
+ def _populate_metadata(
+ self, create_key_request: CreateKeyRequest, account_id: str, region: str
+ ) -> None:
+ self.metadata = KeyMetadata()
+ # Metadata fields coming from a creation request
+ #
+ # We do not include tags into the metadata. Tags might be present in a key creation request, but our metadata
+ # only contains data displayed by DescribeKey. And tags are not there:
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html
+ # "DescribeKey does not return the following information: ... Tags on the KMS key."
+
+ self.metadata["Description"] = create_key_request.get("Description") or ""
+ self.metadata["MultiRegion"] = create_key_request.get("MultiRegion") or False
+ self.metadata["Origin"] = create_key_request.get("Origin") or "AWS_KMS"
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html#KMS-CreateKey-request-CustomerMasterKeySpec
+ # CustomerMasterKeySpec has been deprecated, but is still accepted for compatibility; it has been replaced by
+ # KeySpec. The meaning is the same, just the name differs.
+ self.metadata["KeySpec"] = (
+ create_key_request.get("KeySpec")
+ or create_key_request.get("CustomerMasterKeySpec")
+ or "SYMMETRIC_DEFAULT"
+ )
+ self.metadata["CustomerMasterKeySpec"] = self.metadata.get("KeySpec")
+ self.metadata["KeyUsage"] = self._get_key_usage(
+ create_key_request.get("KeyUsage"), self.metadata.get("KeySpec")
+ )
+
+ # Metadata fields AWS introduces automatically
+ self.metadata["AWSAccountId"] = account_id
+ self.metadata["CreationDate"] = datetime.datetime.now()
+ self.metadata["Enabled"] = create_key_request.get("Origin") != OriginType.EXTERNAL
+ self.metadata["KeyManager"] = "CUSTOMER"
+ self.metadata["KeyState"] = (
+ KeyState.Enabled
+ if create_key_request.get("Origin") != OriginType.EXTERNAL
+ else KeyState.PendingImport
+ )
+
+ if TAG_KEY_CUSTOM_ID in self.tags:
+ # check if the _custom_id_ tag is specified, to set a user-defined KeyId for this key
+ self.metadata["KeyId"] = self.tags[TAG_KEY_CUSTOM_ID].strip()
+ elif self.metadata.get("MultiRegion"):
+ # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
+ # "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
+ # identify MRKs programmatically."
+ # The IDs of MultiRegion keys also do not contain dashes.
+ self.metadata["KeyId"] = "mrk-" + str(uuid.uuid4().hex)
+ else:
+ self.metadata["KeyId"] = str(uuid.uuid4())
+ self.calculate_and_set_arn(account_id, region)
+
+ self._populate_encryption_algorithms(
+ self.metadata.get("KeyUsage"), self.metadata.get("KeySpec")
+ )
+ self._populate_signing_algorithms(
+ self.metadata.get("KeyUsage"), self.metadata.get("KeySpec")
+ )
+ self._populate_mac_algorithms(self.metadata.get("KeyUsage"), self.metadata.get("KeySpec"))
+
+ if self.metadata["MultiRegion"]:
+ self.metadata["MultiRegionConfiguration"] = MultiRegionConfiguration(
+ MultiRegionKeyType=MultiRegionKeyType.PRIMARY,
+ PrimaryKey=MultiRegionKey(Arn=self.metadata["Arn"], Region=region),
+ ReplicaKeys=[],
+ )
+
+ def add_tags(self, tags: TagList) -> None:
+ # Just in case we get None from somewhere.
+ if not tags:
+ return
+
+ unique_tag_keys = {tag["TagKey"] for tag in tags}
+ if len(unique_tag_keys) < len(tags):
+ raise TagException("Duplicate tag keys")
+
+ if len(tags) > 50:
+ raise TagException("Too many tags")
+
+ # Do not care if we overwrite an existing tag:
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html
+ # "To edit a tag, specify an existing tag key and a new tag value."
+ for i, tag in enumerate(tags, start=1):
+ validate_tag(i, tag)
+ self.tags[tag.get("TagKey")] = tag.get("TagValue")
+
+ def schedule_key_deletion(self, pending_window_in_days: int) -> None:
+ self.metadata["Enabled"] = False
+ # TODO For MultiRegion keys, the status of replicas gets set to "PendingDeletion", while the primary key
+ # becomes "PendingReplicaDeletion". Here we just set all keys to "PendingDeletion", as we do not have any
+ # notion of a primary key in LocalStack. Might be useful to improve it.
+ # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-delete.html#primary-delete
+ self.metadata["KeyState"] = "PendingDeletion"
+ self.metadata["DeletionDate"] = datetime.datetime.now() + datetime.timedelta(
+ days=pending_window_in_days
+ )
+
+ # An example of what the whole policy should look like:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-overview.html
+ # The default statement is here:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html#key-policy-default-allow-root-enable-iam
+ def _get_default_key_policy(self, account_id: str, region: str) -> str:
+ return json.dumps(
+ {
+ "Version": "2012-10-17",
+ "Id": "key-default-1",
+ "Statement": [
+ {
+ "Sid": "Enable IAM User Permissions",
+ "Effect": "Allow",
+ "Principal": {"AWS": f"arn:{get_partition(region)}:iam::{account_id}:root"},
+ "Action": "kms:*",
+ "Resource": "*",
+ }
+ ],
+ }
+ )
+
+ def _populate_encryption_algorithms(self, key_usage: str, key_spec: str) -> None:
+ # The two main usages for KMS keys are encryption/decryption and signing/verification.
+ # Doesn't make sense to populate fields related to encryption/decryption unless the key is created with that
+ # goal in mind.
+ if key_usage != "ENCRYPT_DECRYPT":
+ return
+ if key_spec == "SYMMETRIC_DEFAULT":
+ self.metadata["EncryptionAlgorithms"] = ["SYMMETRIC_DEFAULT"]
+ else:
+ self.metadata["EncryptionAlgorithms"] = ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
+
+ def _populate_signing_algorithms(self, key_usage: str, key_spec: str) -> None:
+ # The two main usages for KMS keys are encryption/decryption and signing/verification.
+ # Doesn't make sense to populate fields related to signing/verification unless the key is created with that
+ # goal in mind.
+ if key_usage != "SIGN_VERIFY":
+ return
+ if key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
+ self.metadata["SigningAlgorithms"] = ["ECDSA_SHA_256"]
+ elif key_spec == "ECC_NIST_P384":
+ self.metadata["SigningAlgorithms"] = ["ECDSA_SHA_384"]
+ elif key_spec == "ECC_NIST_P521":
+ self.metadata["SigningAlgorithms"] = ["ECDSA_SHA_512"]
+ else:
+ self.metadata["SigningAlgorithms"] = [
+ "RSASSA_PKCS1_V1_5_SHA_256",
+ "RSASSA_PKCS1_V1_5_SHA_384",
+ "RSASSA_PKCS1_V1_5_SHA_512",
+ "RSASSA_PSS_SHA_256",
+ "RSASSA_PSS_SHA_384",
+ "RSASSA_PSS_SHA_512",
+ ]
+
+ def _populate_mac_algorithms(self, key_usage: str, key_spec: str) -> None:
+ if key_usage != "GENERATE_VERIFY_MAC":
+ return
+ if key_spec == "HMAC_224":
+ self.metadata["MacAlgorithms"] = ["HMAC_SHA_224"]
+ elif key_spec == "HMAC_256":
+ self.metadata["MacAlgorithms"] = ["HMAC_SHA_256"]
+ elif key_spec == "HMAC_384":
+ self.metadata["MacAlgorithms"] = ["HMAC_SHA_384"]
+ elif key_spec == "HMAC_512":
+ self.metadata["MacAlgorithms"] = ["HMAC_SHA_512"]
+
+ def _get_key_usage(self, request_key_usage: str, key_spec: str) -> str:
+ if key_spec in HMAC_RANGE_KEY_LENGTHS:
+ if request_key_usage is None:
+ raise ValidationException(
+ "You must specify a KeyUsage value for all KMS keys except for symmetric encryption keys."
+ )
+ elif request_key_usage != KeyUsageType.GENERATE_VERIFY_MAC:
+ raise ValidationException(
+ f"1 validation error detected: Value '{request_key_usage}' at 'keyUsage' "
+ f"failed to satisfy constraint: Member must satisfy enum value set: "
+ f"[ENCRYPT_DECRYPT, SIGN_VERIFY, GENERATE_VERIFY_MAC]"
+ )
+ else:
+ return KeyUsageType.GENERATE_VERIFY_MAC
+ elif request_key_usage == KeyUsageType.KEY_AGREEMENT:
+ if key_spec not in [
+ KeySpec.ECC_NIST_P256,
+ KeySpec.ECC_NIST_P384,
+ KeySpec.ECC_NIST_P521,
+ KeySpec.ECC_SECG_P256K1,
+ KeySpec.SM2,
+ ]:
+ raise ValidationException(
+ f"KeyUsage {request_key_usage} is not compatible with KeySpec {key_spec}"
+ )
+ else:
+ return request_key_usage
+ else:
+ return request_key_usage or "ENCRYPT_DECRYPT"
+
+
+class KmsGrant:
+ # AWS documentation doesn't seem to mention any metadata object for grants like it does mention KeyMetadata for
+ # keys. But, based on our understanding of AWS documentation for CreateGrant, ListGrants operations etc,
+ # AWS has some set of fields for grants like it has for keys. So we are going to call them `metadata` here for
+ # consistency.
+ metadata: Dict
+ # Tokens are not a part of metadata, as their use is more limited and specific than for the rest of the
+ # metadata: https://docs.aws.amazon.com/kms/latest/developerguide/grant-manage.html#using-grant-token
+ # Tokens are used to refer to a grant in a short period right after the grant gets created. Normally it might
+ # take KMS up to 5 minutes to make a new grant available. In that time window referring to a grant by its
+ # GrantId might not work, so tokens are supposed to be used. The tokens could possibly be used even
+ # afterwards. But since the only way to get a token is through a CreateGrant operation (see below), the chances
+ # of someone storing a token and using it later are slim.
+ #
+ # https://docs.aws.amazon.com/kms/latest/developerguide/grants.html#grant_token
+ # "CreateGrant is the only operation that returns a grant token. You cannot get a grant token from any other
+ # AWS KMS operation or from the CloudTrail log event for the CreateGrant operation. The ListGrants and
+ # ListRetirableGrants operations return the grant ID, but not a grant token."
+ #
+ # Usually a grant might have multiple unique tokens. But here we just model it with a single token for
+ # simplicity.
+ token: str
+
+ def __init__(self, create_grant_request: CreateGrantRequest, account_id: str, region_name: str):
+ self.metadata = dict(create_grant_request)
+
+ if is_valid_key_arn(self.metadata["KeyId"]):
+ self.metadata["KeyArn"] = self.metadata["KeyId"]
+ else:
+ self.metadata["KeyArn"] = kms_key_arn(self.metadata["KeyId"], account_id, region_name)
+
+ self.metadata["GrantId"] = long_uid()
+ self.metadata["CreationDate"] = datetime.datetime.now()
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_GrantListEntry.html
+ # "If a name was provided in the CreateGrant request, that name is returned. Otherwise this value is null."
+ # According to the examples in AWS docs
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListGrants.html#API_ListGrants_Examples
+ # The Name field is present with just an empty string value.
+ self.metadata.setdefault("Name", "")
+
+ # Encode account ID and region in grant token.
+ # This way the grant can be located when it is being retired by the grant principal.
+ # The token consists of account ID, region name and a UUID concatenated with ':' and encoded with base64.
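+ # For illustration (hypothetical values): a decoded token "000000000000:us-east-1:<uuid>" is what gets
+ # base64-encoded below.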
+ decoded_token = account_id + ":" + region_name + ":" + long_uid()
+ self.token = to_str(base64.b64encode(to_bytes(decoded_token)))
+
+
+class KmsAlias:
+ # Like with grants (see comment for KmsGrant), there is no mention of a specific object modeling metadata
+ # for KMS aliases. But there is data that amounts to metadata, so we model it in a way similar to KeyMetadata for keys.
+ metadata: Dict
+
+ def __init__(
+ self,
+ create_alias_request: CreateAliasRequest = None,
+ account_id: str = None,
+ region: str = None,
+ ):
+ create_alias_request = create_alias_request or CreateAliasRequest()
+ self.metadata = {}
+ self.metadata["AliasName"] = create_alias_request.get("AliasName")
+ self.metadata["TargetKeyId"] = create_alias_request.get("TargetKeyId")
+ self.update_date_of_last_update()
+ self.metadata["CreationDate"] = self.metadata["LastUpdateDate"]
+ self.metadata["AliasArn"] = kms_alias_arn(self.metadata["AliasName"], account_id, region)
+
+ def update_date_of_last_update(self):
+ self.metadata["LastUpdateDate"] = datetime.datetime.now()
+
+
+@dataclass
+class KeyImportState:
+ key_id: str
+ import_token: str
+ wrapping_algo: str
+ key: KmsKey
+
+
+class KmsStore(BaseStore):
+ # maps key ids to keys
+ keys: Dict[str, KmsKey] = LocalAttribute(default=dict)
+
+ # According to AWS documentation on grants https://docs.aws.amazon.com/kms/latest/APIReference/API_RetireGrant.html
+ # "Cross-account use: Yes. You can retire a grant on a KMS key in a different AWS account."
+
+ # maps grant ids to grants
+ grants: Dict[str, KmsGrant] = LocalAttribute(default=dict)
+
+ # maps (grant name, key id) pairs to grant ids; grant names are used for idempotency
+ grant_names: Dict[Tuple[str, str], str] = LocalAttribute(default=dict)
+
+ # maps grant tokens to grant ids
+ grant_tokens: Dict[str, str] = LocalAttribute(default=dict)
+
+ # maps key alias names to aliases
+ aliases: Dict[str, KmsAlias] = LocalAttribute(default=dict)
+
+ # maps import tokens to import data
+ imports: Dict[str, KeyImportState] = LocalAttribute(default=dict)
+
+
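+# Access pattern (illustrative): kms_stores[account_id][region_name] yields the KmsStore for that account and
+# region, e.g. kms_stores["000000000000"]["us-east-1"].keys.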
+kms_stores = AccountRegionBundle("kms", KmsStore)
diff --git a/localstack-core/localstack/services/kms/provider.py b/localstack-core/localstack/services/kms/provider.py
new file mode 100644
index 0000000000000..3ccd54c359c30
--- /dev/null
+++ b/localstack-core/localstack/services/kms/provider.py
@@ -0,0 +1,1522 @@
+import base64
+import copy
+import datetime
+import logging
+import os
+from typing import Dict, Tuple
+
+from cryptography.exceptions import InvalidTag
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import padding
+
+from localstack.aws.api import CommonServiceException, RequestContext, handler
+from localstack.aws.api.kms import (
+ AlgorithmSpec,
+ AlreadyExistsException,
+ CancelKeyDeletionRequest,
+ CancelKeyDeletionResponse,
+ CiphertextType,
+ CreateAliasRequest,
+ CreateGrantRequest,
+ CreateGrantResponse,
+ CreateKeyRequest,
+ CreateKeyResponse,
+ DataKeyPairSpec,
+ DateType,
+ DecryptResponse,
+ DeleteAliasRequest,
+ DeriveSharedSecretResponse,
+ DescribeKeyRequest,
+ DescribeKeyResponse,
+ DisabledException,
+ DisableKeyRequest,
+ DisableKeyRotationRequest,
+ EnableKeyRequest,
+ EncryptionAlgorithmSpec,
+ EncryptionContextType,
+ EncryptResponse,
+ ExpirationModelType,
+ GenerateDataKeyPairResponse,
+ GenerateDataKeyPairWithoutPlaintextResponse,
+ GenerateDataKeyRequest,
+ GenerateDataKeyResponse,
+ GenerateDataKeyWithoutPlaintextRequest,
+ GenerateDataKeyWithoutPlaintextResponse,
+ GenerateMacRequest,
+ GenerateMacResponse,
+ GenerateRandomRequest,
+ GenerateRandomResponse,
+ GetKeyPolicyRequest,
+ GetKeyPolicyResponse,
+ GetKeyRotationStatusRequest,
+ GetKeyRotationStatusResponse,
+ GetParametersForImportResponse,
+ GetPublicKeyResponse,
+ GrantIdType,
+ GrantTokenList,
+ GrantTokenType,
+ ImportKeyMaterialResponse,
+ IncorrectKeyException,
+ InvalidCiphertextException,
+ InvalidGrantIdException,
+ InvalidKeyUsageException,
+ KeyAgreementAlgorithmSpec,
+ KeyIdType,
+ KeySpec,
+ KeyState,
+ KeyUsageType,
+ KmsApi,
+ KMSInvalidStateException,
+ LimitType,
+ ListAliasesResponse,
+ ListGrantsRequest,
+ ListGrantsResponse,
+ ListKeyPoliciesRequest,
+ ListKeyPoliciesResponse,
+ ListKeysRequest,
+ ListKeysResponse,
+ ListResourceTagsRequest,
+ ListResourceTagsResponse,
+ MacAlgorithmSpec,
+ MarkerType,
+ MultiRegionKey,
+ NotFoundException,
+ NullableBooleanType,
+ OriginType,
+ PlaintextType,
+ PrincipalIdType,
+ PublicKeyType,
+ PutKeyPolicyRequest,
+ RecipientInfo,
+ ReEncryptResponse,
+ ReplicateKeyRequest,
+ ReplicateKeyResponse,
+ ScheduleKeyDeletionRequest,
+ ScheduleKeyDeletionResponse,
+ SignRequest,
+ SignResponse,
+ TagResourceRequest,
+ UnsupportedOperationException,
+ UntagResourceRequest,
+ UpdateAliasRequest,
+ UpdateKeyDescriptionRequest,
+ VerifyMacRequest,
+ VerifyMacResponse,
+ VerifyRequest,
+ VerifyResponse,
+ WrappingKeySpec,
+)
+from localstack.services.kms.exceptions import ValidationException
+from localstack.services.kms.models import (
+ MULTI_REGION_PATTERN,
+ PATTERN_UUID,
+ RESERVED_ALIASES,
+ KeyImportState,
+ KmsAlias,
+ KmsCryptoKey,
+ KmsGrant,
+ KmsKey,
+ KmsStore,
+ deserialize_ciphertext_blob,
+ kms_stores,
+)
+from localstack.services.kms.utils import is_valid_key_arn, parse_key_arn, validate_alias_name
+from localstack.services.plugins import ServiceLifecycleHook
+from localstack.utils.aws.arns import get_partition, kms_alias_arn, parse_arn
+from localstack.utils.collections import PaginatedList
+from localstack.utils.common import select_attributes
+from localstack.utils.strings import short_uid, to_bytes, to_str
+
+LOG = logging.getLogger(__name__)
+
+# valid operations
+VALID_OPERATIONS = [
+ "CreateKey",
+ "Decrypt",
+ "Encrypt",
+ "GenerateDataKey",
+ "GenerateDataKeyWithoutPlaintext",
+ "ReEncryptFrom",
+ "ReEncryptTo",
+ "Sign",
+ "Verify",
+ "GetPublicKey",
+ "CreateGrant",
+ "RetireGrant",
+ "DescribeKey",
+ "GenerateDataKeyPair",
+ "GenerateDataKeyPairWithoutPlaintext",
+]
+
+
+class ValidationError(CommonServiceException):
+ """General validation error type (defined in the AWS docs, but not part of the botocore spec)"""
+
+ def __init__(self, message=None):
+ super().__init__("ValidationError", message=message)
+
+
+# For all operations, constraints on key states are based on
+# https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+class KmsProvider(KmsApi, ServiceLifecycleHook):
+ """
+ The LocalStack Key Management Service (KMS) provider.
+
+ Cross-account access is supported for the following operations, where a key belonging
+ to another account can be referenced by its key ARN.
+ - CreateGrant
+ - DescribeKey
+ - GetKeyRotationStatus
+ - GetPublicKey
+ - ListGrants
+ - RetireGrant
+ - RevokeGrant
+ - Decrypt
+ - Encrypt
+ - GenerateDataKey
+ - GenerateDataKeyPair
+ - GenerateDataKeyPairWithoutPlaintext
+ - GenerateDataKeyWithoutPlaintext
+ - GenerateMac
+ - ReEncrypt
+ - Sign
+ - Verify
+ - VerifyMac
+ """
+
+ #
+ # Helpers
+ #
+
+ @staticmethod
+ def _get_store(account_id: str, region_name: str) -> KmsStore:
+ return kms_stores[account_id][region_name]
+
+ @staticmethod
+ def _create_kms_alias(account_id: str, region_name: str, request: CreateAliasRequest):
+ store = kms_stores[account_id][region_name]
+ alias = KmsAlias(request, account_id, region_name)
+ alias_name = request.get("AliasName")
+ store.aliases[alias_name] = alias
+
+ @staticmethod
+ def _create_kms_key(
+ account_id: str, region_name: str, request: CreateKeyRequest = None
+ ) -> KmsKey:
+ store = kms_stores[account_id][region_name]
+ key = KmsKey(request, account_id, region_name)
+ key_id = key.metadata["KeyId"]
+ store.keys[key_id] = key
+ return key
+
+ @staticmethod
+ def _get_key_id_from_any_id(account_id: str, region_name: str, some_id: str) -> str:
+ """
+ Resolve a KMS key ID by using one of the following identifiers:
+ - key ID
+ - key ARN
+ - key alias
+ - key alias ARN
+ """
+ alias_name = None
+ key_id = None
+ key_arn = None
+
+ if some_id.startswith("arn:"):
+ if ":alias/" in some_id:
+ alias_arn = some_id
+ alias_name = "alias/" + alias_arn.split(":alias/")[1]
+ elif ":key/" in some_id:
+ key_arn = some_id
+ key_id = key_arn.split(":key/")[1]
+ parsed_arn = parse_arn(key_arn)
+ if parsed_arn["region"] != region_name:
+ raise NotFoundException(f"Invalid arn {parsed_arn['region']}")
+ else:
+ raise ValueError(
+ f"Supplied value of {some_id} is an ARN, but neither of a KMS key nor of a KMS key "
+ f"alias"
+ )
+ elif some_id.startswith("alias/"):
+ alias_name = some_id
+ else:
+ key_id = some_id
+
+ store = kms_stores[account_id][region_name]
+
+ if alias_name:
+ KmsProvider._create_alias_if_reserved_and_not_exists(
+ account_id,
+ region_name,
+ alias_name,
+ )
+ if alias_name not in store.aliases:
+ raise NotFoundException(f"Unable to find KMS alias with name {alias_name}")
+ key_id = store.aliases[alias_name].metadata["TargetKeyId"]
+
+ # regular KeyIds are UUIDs, while MultiRegion key IDs start with 'mrk-' followed by 32 hex chars
+ if not PATTERN_UUID.match(key_id) and not MULTI_REGION_PATTERN.match(key_id):
+ raise NotFoundException(f"Invalid keyId '{key_id}'")
+
+ if key_id not in store.keys:
+ if not key_arn:
+ key_arn = (
+ f"arn:{get_partition(region_name)}:kms:{region_name}:{account_id}:key/{key_id}"
+ )
+ raise NotFoundException(f"Key '{key_arn}' does not exist")
+
+ return key_id
+
+ @staticmethod
+ def _create_alias_if_reserved_and_not_exists(
+ account_id: str, region_name: str, alias_name: str
+ ):
+ store = kms_stores[account_id][region_name]
+ if alias_name not in RESERVED_ALIASES or alias_name in store.aliases:
+ return
+ create_key_request = {}
+ key_id = KmsProvider._create_kms_key(
+ account_id,
+ region_name,
+ create_key_request,
+ ).metadata.get("KeyId")
+ create_alias_request = CreateAliasRequest(AliasName=alias_name, TargetKeyId=key_id)
+ KmsProvider._create_kms_alias(account_id, region_name, create_alias_request)
+
+ # While in AWS keys have more than Enabled, Disabled and PendingDeletion states, we currently only model these 3
+ # in LocalStack, so this function is limited to them.
+ #
+ # The current default values reflect that most operations in AWS work with enabled keys, but fail for
+ # disabled keys and keys pending deletion.
+ #
+ # If we decide to model the other states as well, we might want to come up with a better key state validation per
+ # operation. This page lists which states are supported by the various operations:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+ @staticmethod
+ def _get_kms_key(
+ account_id: str,
+ region_name: str,
+ any_type_of_key_id: str,
+ any_key_state_allowed: bool = False,
+ enabled_key_allowed: bool = True,
+ disabled_key_allowed: bool = False,
+ pending_deletion_key_allowed: bool = False,
+ ) -> KmsKey:
+ store = kms_stores[account_id][region_name]
+
+ if any_key_state_allowed:
+ enabled_key_allowed = True
+ disabled_key_allowed = True
+ pending_deletion_key_allowed = True
+ if not (enabled_key_allowed or disabled_key_allowed or pending_deletion_key_allowed):
+ raise ValueError("A key is requested, but all possible key states are prohibited")
+
+ key_id = KmsProvider._get_key_id_from_any_id(account_id, region_name, any_type_of_key_id)
+ key = store.keys[key_id]
+
+ if not disabled_key_allowed and key.metadata.get("KeyState") == "Disabled":
+ raise DisabledException(f"{key.metadata.get('Arn')} is disabled.")
+ if not pending_deletion_key_allowed and key.metadata.get("KeyState") == "PendingDeletion":
+ raise KMSInvalidStateException(f"{key.metadata.get('Arn')} is pending deletion.")
+ if not enabled_key_allowed and key.metadata.get("KeyState") == "Enabled":
+ raise KMSInvalidStateException(
+ f"{key.metadata.get('Arn')} is enabled, but the operation doesn't support "
+ f"such a state"
+ )
+ return store.keys[key_id]
+
+ @staticmethod
+ def _get_kms_alias(account_id: str, region_name: str, alias_name_or_arn: str) -> KmsAlias:
+ store = kms_stores[account_id][region_name]
+
+ if not alias_name_or_arn.startswith("arn:"):
+ alias_name = alias_name_or_arn
+ else:
+ if ":alias/" not in alias_name_or_arn:
+ raise ValidationException(f"{alias_name_or_arn} is not a valid alias ARN")
+ alias_name = "alias/" + alias_name_or_arn.split(":alias/")[1]
+
+ validate_alias_name(alias_name)
+
+ if alias_name not in store.aliases:
+ alias_arn = kms_alias_arn(alias_name, account_id, region_name)
+ # AWS itself uses AliasArn instead of AliasName in this exception.
+ raise NotFoundException(f"Alias {alias_arn} is not found.")
+
+ return store.aliases.get(alias_name)
+
+ @staticmethod
+ def _parse_key_id(key_id_or_arn: str, context: RequestContext) -> Tuple[str, str, str]:
+ """
+ Return locator attributes (account ID, region_name, key ID) of a given KMS key.
+
+ If an ARN is provided, this is extracted from it. Otherwise, context data is used.
+
+ :param key_id_or_arn: KMS key ID or ARN
+ :param context: request context
+ :return: Tuple of account ID, region name and key ID
+ """
+ if is_valid_key_arn(key_id_or_arn):
+ account_id, region_name, key_id = parse_key_arn(key_id_or_arn)
+ if region_name != context.region:
+ raise NotFoundException(f"Invalid arn {region_name}")
+ return account_id, region_name, key_id
+
+ return context.account_id, context.region, key_id_or_arn
+
+ @staticmethod
+ def _is_rsa_spec(key_spec: str) -> bool:
+ return key_spec in [KeySpec.RSA_2048, KeySpec.RSA_3072, KeySpec.RSA_4096]
+
+ #
+ # Operation Handlers
+ #
+
+ @handler("CreateKey", expand=False)
+ def create_key(
+ self,
+ context: RequestContext,
+ request: CreateKeyRequest = None,
+ ) -> CreateKeyResponse:
+ key = self._create_kms_key(context.account_id, context.region, request)
+ return CreateKeyResponse(KeyMetadata=key.metadata)
+
+ @handler("ScheduleKeyDeletion", expand=False)
+ def schedule_key_deletion(
+ self, context: RequestContext, request: ScheduleKeyDeletionRequest
+ ) -> ScheduleKeyDeletionResponse:
+ pending_window = int(request.get("PendingWindowInDays", 30))
+ if pending_window < 7 or pending_window > 30:
+ raise ValidationException(
+ f"PendingWindowInDays should be between 7 and 30, but it is {pending_window}"
+ )
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.schedule_key_deletion(pending_window)
+ attrs = ["DeletionDate", "KeyId", "KeyState"]
+ result = select_attributes(key.metadata, attrs)
+ result["PendingWindowInDays"] = pending_window
+ return ScheduleKeyDeletionResponse(**result)
+
+ @handler("CancelKeyDeletion", expand=False)
+ def cancel_key_deletion(
+ self, context: RequestContext, request: CancelKeyDeletionRequest
+ ) -> CancelKeyDeletionResponse:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=False,
+ pending_deletion_key_allowed=True,
+ )
+ key.metadata["KeyState"] = KeyState.Disabled
+ key.metadata["DeletionDate"] = None
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html#API_CancelKeyDeletion_ResponseElements
+ # "The Amazon Resource Name (key ARN) of the KMS key whose deletion is canceled."
+ return CancelKeyDeletionResponse(KeyId=key.metadata.get("Arn"))
+
+ @handler("DisableKey", expand=False)
+ def disable_key(self, context: RequestContext, request: DisableKeyRequest) -> None:
+ # Technically, AWS allows DisableKey for keys that are already disabled.
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.metadata["KeyState"] = KeyState.Disabled
+ key.metadata["Enabled"] = False
+
+ @handler("EnableKey", expand=False)
+ def enable_key(self, context: RequestContext, request: EnableKeyRequest) -> None:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.metadata["KeyState"] = KeyState.Enabled
+ key.metadata["Enabled"] = True
+
+ @handler("ListKeys", expand=False)
+ def list_keys(self, context: RequestContext, request: ListKeysRequest) -> ListKeysResponse:
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html#API_ListKeys_ResponseSyntax
+ # Out of the whole KeyMetadata, only two fields are present in the response.
+ keys_list = PaginatedList(
+ [
+ {"KeyId": key.metadata["KeyId"], "KeyArn": key.metadata["Arn"]}
+ for key in self._get_store(context.account_id, context.region).keys.values()
+ ]
+ )
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html#API_ListKeys_RequestParameters
+ # Regarding the default value of Limit: "If you do not include a value, it defaults to 100."
+ page, next_token = keys_list.get_page(
+ lambda key_data: key_data.get("KeyId"),
+ next_token=request.get("Marker"),
+ page_size=request.get("Limit", 100),
+ )
+ kwargs = {"NextMarker": next_token, "Truncated": True} if next_token else {}
+ return ListKeysResponse(Keys=page, **kwargs)
+
+ @handler("DescribeKey", expand=False)
+ def describe_key(
+ self, context: RequestContext, request: DescribeKeyRequest
+ ) -> DescribeKeyResponse:
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id, any_key_state_allowed=True)
+ return DescribeKeyResponse(KeyMetadata=key.metadata)
+
+ @handler("ReplicateKey", expand=False)
+ def replicate_key(
+ self, context: RequestContext, request: ReplicateKeyRequest
+ ) -> ReplicateKeyResponse:
+ account_id = context.account_id
+ key = self._get_kms_key(account_id, context.region, request.get("KeyId"))
+ key_id = key.metadata.get("KeyId")
+ if not key.metadata.get("MultiRegion"):
+ raise UnsupportedOperationException(
+ f"Unable to replicate a non-MultiRegion key {key_id}"
+ )
+ replica_region = request.get("ReplicaRegion")
+ replicate_to_store = kms_stores[account_id][replica_region]
+ if key_id in replicate_to_store.keys:
+ raise AlreadyExistsException(
+ f"Unable to replicate key {key_id} to region {replica_region}, as the key "
+ f"already exist there"
+ )
+ replica_key = copy.deepcopy(key)
+ replica_key.replicate_metadata(request, account_id, replica_region)
+ replicate_to_store.keys[key_id] = replica_key
+
+ self.update_primary_key_with_replica_keys(key, replica_key, replica_region)
+
+ return ReplicateKeyResponse(ReplicaKeyMetadata=replica_key.metadata)
+
+ # Adds the new multi-region replica key to the primary key's metadata.
+ @staticmethod
+ def update_primary_key_with_replica_keys(key: KmsKey, replica_key: KmsKey, region: str):
+ key.metadata["MultiRegionConfiguration"]["ReplicaKeys"].append(
+ MultiRegionKey(
+ Arn=replica_key.metadata["Arn"],
+ Region=region,
+ )
+ )
+
+ @handler("UpdateKeyDescription", expand=False)
+ def update_key_description(
+ self, context: RequestContext, request: UpdateKeyDescriptionRequest
+ ) -> None:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.metadata["Description"] = request.get("Description")
+
+ @handler("CreateGrant", expand=False)
+ def create_grant(
+ self, context: RequestContext, request: CreateGrantRequest
+ ) -> CreateGrantResponse:
+ key_account_id, key_region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(key_account_id, key_region_name, key_id)
+
+ # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no
+ # matter which type of id is used.
+ key_id = key.metadata.get("KeyId")
+ request["KeyId"] = key_id
+ self._validate_grant_request(request)
+ grant_name = request.get("Name")
+
+ store = self._get_store(context.account_id, context.region)
+ if grant_name and (grant_name, key_id) in store.grant_names:
+ grant = store.grants[store.grant_names[(grant_name, key_id)]]
+ else:
+ grant = KmsGrant(request, context.account_id, context.region)
+ grant_id = grant.metadata["GrantId"]
+ store.grants[grant_id] = grant
+ if grant_name:
+ store.grant_names[(grant_name, key_id)] = grant_id
+ store.grant_tokens[grant.token] = grant_id
+
+ # At the moment we do not support multiple GrantTokens per grant creation request. Instead, we always use
+ # the same token. For reference, AWS documentation says:
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateGrant.html#API_CreateGrant_RequestParameters
+ # "The returned grant token is unique with every CreateGrant request, even when a duplicate GrantId is
+ # returned". "A duplicate GrantId" refers to the idempotency of grant creation requests - if a request has
+ # "Name" field, and if such name already belongs to a previously created grant, no new grant gets created
+ # and the existing grant with the name is returned.
+ return CreateGrantResponse(GrantId=grant.metadata["GrantId"], GrantToken=grant.token)
+
+ @handler("ListGrants", expand=False)
+ def list_grants(
+ self, context: RequestContext, request: ListGrantsRequest
+ ) -> ListGrantsResponse:
+ if not request.get("KeyId"):
+ raise ValidationError("Required input parameter KeyId not specified")
+ key_account_id, key_region_name, _ = self._parse_key_id(request["KeyId"], context)
+ # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no
+ # matter which type of id is used.
+ key = self._get_kms_key(
+ key_account_id, key_region_name, request.get("KeyId"), any_key_state_allowed=True
+ )
+ key_id = key.metadata.get("KeyId")
+
+ store = self._get_store(context.account_id, context.region)
+ grant_id = request.get("GrantId")
+ if grant_id:
+ if grant_id not in store.grants:
+ raise InvalidGrantIdException()
+ return ListGrantsResponse(Grants=[store.grants[grant_id].metadata])
+
+ matching_grants = []
+ grantee_principal = request.get("GranteePrincipal")
+ for grant in store.grants.values():
+ # KeyId is a mandatory field of the ListGrants request, so it is going to be present.
+ _, _, grant_key_id = parse_key_arn(grant.metadata["KeyArn"])
+ if grant_key_id != key_id:
+ continue
+ # GranteePrincipal is a mandatory field for CreateGrant, so it should be present in grants. But it is an
+ # optional field for ListGrants, so it might not be there.
+ if grantee_principal and grant.metadata["GranteePrincipal"] != grantee_principal:
+ continue
+ matching_grants.append(grant.metadata)
+
+ grants_list = PaginatedList(matching_grants)
+ page, next_token = grants_list.get_page(
+ lambda grant_data: grant_data.get("GrantId"),
+ next_token=request.get("Marker"),
+ page_size=request.get("Limit", 50),
+ )
+ kwargs = {"NextMarker": next_token, "Truncated": True} if next_token else {}
+
+ return ListGrantsResponse(Grants=page, **kwargs)
+
+ @staticmethod
+ def _delete_grant(store: KmsStore, grant_id: str, key_id: str):
+ grant = store.grants[grant_id]
+
+ _, _, grant_key_id = parse_key_arn(grant.metadata.get("KeyArn"))
+ if key_id != grant_key_id:
+ raise ValidationError(f"Invalid KeyId={key_id} specified for grant {grant_id}")
+
+ store.grant_tokens.pop(grant.token)
+ store.grant_names.pop((grant.metadata.get("Name"), key_id), None)
+ store.grants.pop(grant_id)
+
+ def revoke_grant(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ grant_id: GrantIdType,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> None:
+ # TODO add support for "dry_run"
+ key_account_id, key_region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(key_account_id, key_region_name, key_id, any_key_state_allowed=True)
+ key_id = key.metadata.get("KeyId")
+
+ store = self._get_store(context.account_id, context.region)
+
+ if grant_id not in store.grants:
+ raise InvalidGrantIdException()
+
+ self._delete_grant(store, grant_id, key_id)
+
+ def retire_grant(
+ self,
+ context: RequestContext,
+ grant_token: GrantTokenType = None,
+ key_id: KeyIdType = None,
+ grant_id: GrantIdType = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> None:
+ # TODO add support for "dry_run"
+ if not grant_token and (not grant_id or not key_id):
+ raise ValidationException("Grant token OR (grant ID, key ID) must be specified")
+
+ if grant_token:
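+ # The token format mirrors the encoding in KmsGrant.__init__: base64 of "<account_id>:<region>:<uuid>".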
+ decoded_token = to_str(base64.b64decode(grant_token))
+ grant_account_id, grant_region_name, _ = decoded_token.split(":")
+ grant_store = self._get_store(grant_account_id, grant_region_name)
+
+ if grant_token not in grant_store.grant_tokens:
+ raise NotFoundException(f"Unable to find grant token {grant_token}")
+
+ grant_id = grant_store.grant_tokens[grant_token]
+ else:
+ grant_store = self._get_store(context.account_id, context.region)
+
+ if key_id:
+ key_account_id, key_region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(
+ key_account_id, key_region_name, key_id, any_key_state_allowed=True
+ )
+ key_id = key.metadata.get("KeyId")
+ else:
+ _, _, key_id = parse_key_arn(grant_store.grants[grant_id].metadata.get("KeyArn"))
+
+ self._delete_grant(grant_store, grant_id, key_id)
+
+ def list_retirable_grants(
+ self,
+ context: RequestContext,
+ retiring_principal: PrincipalIdType,
+ limit: LimitType = None,
+ marker: MarkerType = None,
+ **kwargs,
+ ) -> ListGrantsResponse:
+ if not retiring_principal:
+ raise ValidationError("Required input parameter 'RetiringPrincipal' not specified")
+
+ matching_grants = [
+ grant.metadata
+ for grant in self._get_store(context.account_id, context.region).grants.values()
+ if grant.metadata.get("RetiringPrincipal") == retiring_principal
+ ]
+ grants_list = PaginatedList(matching_grants)
+ limit = limit or 50
+ page, next_token = grants_list.get_page(
+ lambda grant_data: grant_data.get("GrantId"),
+ next_token=marker,
+ page_size=limit,
+ )
+ kwargs = {"NextMarker": next_token, "Truncated": True} if next_token else {}
+
+ return ListGrantsResponse(Grants=page, **kwargs)
+
+ def get_public_key(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ grant_tokens: GrantTokenList = None,
+ **kwargs,
+ ) -> GetPublicKeyResponse:
+ # According to https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html, GetPublicKey is supposed
+ # to fail for disabled keys. But it actually doesn't fail in AWS.
+ account_id, region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(
+ account_id,
+ region_name,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ attrs = [
+ "KeySpec",
+ "KeyUsage",
+ "EncryptionAlgorithms",
+ "SigningAlgorithms",
+ ]
+ result = select_attributes(key.metadata, attrs)
+ result["PublicKey"] = key.crypto_key.public_key
+ result["KeyId"] = key.metadata["Arn"]
+ return GetPublicKeyResponse(**result)
+
+ def _generate_data_key_pair(
+ self,
+ context: RequestContext,
+ key_id: str,
+ key_pair_spec: str,
+ encryption_context: EncryptionContextType = None,
+ ):
+ account_id, region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+ self._validate_key_for_encryption_decryption(context, key)
+ crypto_key = KmsCryptoKey(key_pair_spec)
+ return {
+ "KeyId": key.metadata["Arn"],
+ "KeyPairSpec": key_pair_spec,
+ "PrivateKeyCiphertextBlob": key.encrypt(crypto_key.private_key, encryption_context),
+ "PrivateKeyPlaintext": crypto_key.private_key,
+ "PublicKey": crypto_key.public_key,
+ }
+
+ @handler("GenerateDataKeyPair")
+ def generate_data_key_pair(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ key_pair_spec: DataKeyPairSpec,
+ encryption_context: EncryptionContextType = None,
+ grant_tokens: GrantTokenList = None,
+ recipient: RecipientInfo = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> GenerateDataKeyPairResponse:
+ # TODO add support for "dry_run"
+ result = self._generate_data_key_pair(context, key_id, key_pair_spec, encryption_context)
+ return GenerateDataKeyPairResponse(**result)
+
+ @handler("GenerateRandom", expand=False)
+ def generate_random(
+ self, context: RequestContext, request: GenerateRandomRequest
+ ) -> GenerateRandomResponse:
+ number_of_bytes = request.get("NumberOfBytes")
+ if number_of_bytes is None:
+ raise ValidationException("NumberOfBytes is required.")
+ if number_of_bytes > 1024:
+ raise ValidationException(
+ f"1 validation error detected: Value '{number_of_bytes}' at 'numberOfBytes' failed "
+ "to satisfy constraint: Member must have value less than or equal to 1024"
+ )
+ if number_of_bytes < 1:
+ raise ValidationException(
+ f"1 validation error detected: Value '{number_of_bytes}' at 'numberOfBytes' failed "
+ "to satisfy constraint: Member must have value greater than or equal to 1"
+ )
+
+ byte_string = os.urandom(number_of_bytes)
+
+ return GenerateRandomResponse(Plaintext=byte_string)
+
+ @handler("GenerateDataKeyPairWithoutPlaintext")
+ def generate_data_key_pair_without_plaintext(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ key_pair_spec: DataKeyPairSpec,
+ encryption_context: EncryptionContextType = None,
+ grant_tokens: GrantTokenList = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> GenerateDataKeyPairWithoutPlaintextResponse:
+ # TODO add support for "dry_run"
+ result = self._generate_data_key_pair(context, key_id, key_pair_spec, encryption_context)
+ result.pop("PrivateKeyPlaintext")
+ return GenerateDataKeyPairWithoutPlaintextResponse(**result)
+
+ # We currently act neither on the KeySpec setting (which is different from, and holds different values than, the
+ # KeySpec for CreateKey) nor on NumberOfBytes. Instead, we generate a key with a key length that is "standard" in
+ # LocalStack.
+ def _generate_data_key(
+ self, context: RequestContext, key_id: str, encryption_context: EncryptionContextType = None
+ ):
+ account_id, region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+ # TODO Should also have a validation for the key being a symmetric one.
+ self._validate_key_for_encryption_decryption(context, key)
+ crypto_key = KmsCryptoKey("SYMMETRIC_DEFAULT")
+ return {
+ "KeyId": key.metadata["Arn"],
+ "Plaintext": crypto_key.key_material,
+ "CiphertextBlob": key.encrypt(crypto_key.key_material, encryption_context),
+ }
+
+ @handler("GenerateDataKey", expand=False)
+ def generate_data_key(
+ self, context: RequestContext, request: GenerateDataKeyRequest
+ ) -> GenerateDataKeyResponse:
+ result = self._generate_data_key(
+ context, request.get("KeyId"), request.get("EncryptionContext")
+ )
+ return GenerateDataKeyResponse(**result)
+
+ @handler("GenerateDataKeyWithoutPlaintext", expand=False)
+ def generate_data_key_without_plaintext(
+ self, context: RequestContext, request: GenerateDataKeyWithoutPlaintextRequest
+ ) -> GenerateDataKeyWithoutPlaintextResponse:
+ result = self._generate_data_key(
+ context, request.get("KeyId"), request.get("EncryptionContext")
+ )
+ result.pop("Plaintext")
+ return GenerateDataKeyWithoutPlaintextResponse(**result)
+
+ @handler("GenerateMac", expand=False)
+ def generate_mac(
+ self,
+ context: RequestContext,
+ request: GenerateMacRequest,
+ ) -> GenerateMacResponse:
+ msg = request.get("Message")
+ self._validate_mac_msg_length(msg)
+
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+
+ self._validate_key_for_generate_verify_mac(context, key)
+
+ algorithm = request.get("MacAlgorithm")
+ self._validate_mac_algorithm(key, algorithm)
+
+ mac = key.generate_mac(msg, algorithm)
+
+ return GenerateMacResponse(Mac=mac, MacAlgorithm=algorithm, KeyId=key.metadata.get("Arn"))
+
+ @handler("VerifyMac", expand=False)
+ def verify_mac(
+ self,
+ context: RequestContext,
+ request: VerifyMacRequest,
+ ) -> VerifyMacResponse:
+ msg = request.get("Message")
+ self._validate_mac_msg_length(msg)
+
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+
+ self._validate_key_for_generate_verify_mac(context, key)
+
+ algorithm = request.get("MacAlgorithm")
+ self._validate_mac_algorithm(key, algorithm)
+
+ mac_valid = key.verify_mac(msg, request.get("Mac"), algorithm)
+
+ return VerifyMacResponse(
+ KeyId=key.metadata.get("Arn"), MacValid=mac_valid, MacAlgorithm=algorithm
+ )
+
+ @handler("Sign", expand=False)
+ def sign(self, context: RequestContext, request: SignRequest) -> SignResponse:
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+
+ self._validate_key_for_sign_verify(context, key)
+
+ # TODO Add constraints on KeySpec / SigningAlgorithm pairs:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/asymmetric-key-specs.html#key-spec-ecc
+
+ signing_algorithm = request.get("SigningAlgorithm")
+ signature = key.sign(request.get("Message"), request.get("MessageType"), signing_algorithm)
+
+ result = {
+ "KeyId": key.metadata["Arn"],
+ "Signature": signature,
+ "SigningAlgorithm": signing_algorithm,
+ }
+ return SignResponse(**result)
+
+ # Currently LocalStack only calculates SHA256 digests no matter what the signing algorithm is.
+ @handler("Verify", expand=False)
+ def verify(self, context: RequestContext, request: VerifyRequest) -> VerifyResponse:
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+
+ self._validate_key_for_sign_verify(context, key)
+
+ signing_algorithm = request.get("SigningAlgorithm")
+ is_signature_valid = key.verify(
+ request.get("Message"),
+ request.get("MessageType"),
+ signing_algorithm,
+ request.get("Signature"),
+ )
+
+ result = {
+ "KeyId": key.metadata["Arn"],
+ "SignatureValid": is_signature_valid,
+ "SigningAlgorithm": signing_algorithm,
+ }
+ return VerifyResponse(**result)
+
+ def re_encrypt(
+ self,
+ context: RequestContext,
+ ciphertext_blob: CiphertextType,
+ destination_key_id: KeyIdType,
+ source_encryption_context: EncryptionContextType = None,
+ source_key_id: KeyIdType = None,
+ destination_encryption_context: EncryptionContextType = None,
+ source_encryption_algorithm: EncryptionAlgorithmSpec = None,
+ destination_encryption_algorithm: EncryptionAlgorithmSpec = None,
+ grant_tokens: GrantTokenList = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> ReEncryptResponse:
+ # TODO: when implementing, ensure cross-account support for source_key_id and destination_key_id
+ raise NotImplementedError
+
+ def encrypt(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ plaintext: PlaintextType,
+ encryption_context: EncryptionContextType = None,
+ grant_tokens: GrantTokenList = None,
+ encryption_algorithm: EncryptionAlgorithmSpec = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> EncryptResponse:
+ # TODO add support for "dry_run"
+ account_id, region_name, key_id = self._parse_key_id(key_id, context)
+ key = self._get_kms_key(account_id, region_name, key_id)
+ self._validate_plaintext_length(plaintext)
+ self._validate_plaintext_key_type_based(plaintext, key, encryption_algorithm)
+ self._validate_key_for_encryption_decryption(context, key)
+ self._validate_key_state_not_pending_import(key)
+
+ ciphertext_blob = key.encrypt(plaintext, encryption_context)
+ # For compatibility, we return EncryptionAlgorithm values expected from AWS. But LocalStack currently always
+ # encrypts with symmetric encryption no matter the key settings.
+ return EncryptResponse(
+ CiphertextBlob=ciphertext_blob,
+ KeyId=key.metadata.get("Arn"),
+ EncryptionAlgorithm=encryption_algorithm,
+ )
+
+ # TODO We currently do not even check encryption_context, while moto does. Should add the corresponding logic later.
+ def decrypt(
+ self,
+ context: RequestContext,
+ ciphertext_blob: CiphertextType,
+ encryption_context: EncryptionContextType = None,
+ grant_tokens: GrantTokenList = None,
+ key_id: KeyIdType = None,
+ encryption_algorithm: EncryptionAlgorithmSpec = None,
+ recipient: RecipientInfo = None,
+ dry_run: NullableBooleanType = None,
+ **kwargs,
+ ) -> DecryptResponse:
+ # In AWS, key_id is only supplied for data encrypted with an asymmetric algorithm. For symmetric
+ # encryption, key_id is taken from the encrypted data itself.
+ # Since LocalStack doesn't currently do asymmetric encryption, there is a modeling question here: we
+ # currently expect data to be encrypted with symmetric encryption only, so the key_id is embedded in the
+ # ciphertext blob. That might not always be what customers expect.
+ if key_id:
+ account_id, region_name, key_id = self._parse_key_id(key_id, context)
+ try:
+ ciphertext = deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob)
+ except Exception:
+ ciphertext = None
+ else:
+ try:
+ ciphertext = deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob)
+ account_id, region_name, key_id = self._parse_key_id(ciphertext.key_id, context)
+ except Exception:
+ raise InvalidCiphertextException(
+ "LocalStack is unable to deserialize the ciphertext blob. Perhaps the "
+ "blob didn't come from LocalStack"
+ )
+
+ key = self._get_kms_key(account_id, region_name, key_id)
+ if ciphertext and key.metadata["KeyId"] != ciphertext.key_id:
+ raise IncorrectKeyException(
+ "The key ID in the request does not identify a CMK that can perform this operation."
+ )
+
+ self._validate_key_for_encryption_decryption(context, key)
+ self._validate_key_state_not_pending_import(key)
+
+ try:
+ # TODO: Extend the implementation to handle additional encryption/decryption scenarios
+ # beyond the current support for offline encryption and online decryption with RSA keys, where a key id is
+ # passed in the parameters and `ciphertext_blob` is not deserializable.
+ if self._is_rsa_spec(key.crypto_key.key_spec) and not ciphertext:
+ plaintext = key.decrypt_rsa(ciphertext_blob)
+ else:
+ plaintext = key.decrypt(ciphertext, encryption_context)
+ except InvalidTag:
+ raise InvalidCiphertextException()
+ # For compatibility, we return EncryptionAlgorithm values expected from AWS. But LocalStack currently always
+ # encrypts with symmetric encryption no matter the key settings.
+ #
+ # We return a key ARN instead of KeyId despite the name of the parameter, as this is what AWS does and states
+ # in its docs.
+ # TODO add support for "recipient"
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html#API_Decrypt_RequestSyntax
+ # TODO add support for "dry_run"
+ return DecryptResponse(
+ KeyId=key.metadata.get("Arn"),
+ Plaintext=plaintext,
+ EncryptionAlgorithm=encryption_algorithm,
+ )
+
+ def get_parameters_for_import(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ wrapping_algorithm: AlgorithmSpec,
+ wrapping_key_spec: WrappingKeySpec,
+ **kwargs,
+ ) -> GetParametersForImportResponse:
+ store = self._get_store(context.account_id, context.region)
+ # KeyId can potentially hold one of multiple different types of key identifiers. _get_kms_key finds a key no
+ # matter which type of id is used.
+ key_to_import_material_to = self._get_kms_key(
+ context.account_id,
+ context.region,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key_arn = key_to_import_material_to.metadata["Arn"]
+ key_origin = key_to_import_material_to.metadata.get("Origin")
+
+ if key_origin != "EXTERNAL":
+ raise UnsupportedOperationException(
+ f"{key_arn} origin is {key_origin} which is not valid for this operation."
+ )
+
+ key_id = key_to_import_material_to.metadata["KeyId"]
+
+ key = KmsKey(CreateKeyRequest(KeySpec=wrapping_key_spec))
+ import_token = short_uid()
+ import_state = KeyImportState(
+ key_id=key_id, import_token=import_token, wrapping_algo=wrapping_algorithm, key=key
+ )
+ store.imports[import_token] = import_state
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_GetParametersForImport.html
+ # "To import key material, you must use the public key and import token from the same response. These items
+ # are valid for 24 hours."
+ expiry_date = datetime.datetime.now() + datetime.timedelta(days=100)
+ return GetParametersForImportResponse(
+ KeyId=key_to_import_material_to.metadata["Arn"],
+ ImportToken=to_bytes(import_state.import_token),
+ PublicKey=import_state.key.crypto_key.public_key,
+ ParametersValidTo=expiry_date,
+ )
+
+ def import_key_material(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ import_token: CiphertextType,
+ encrypted_key_material: CiphertextType,
+ valid_to: DateType = None,
+ expiration_model: ExpirationModelType = None,
+ **kwargs,
+ ) -> ImportKeyMaterialResponse:
+ store = self._get_store(context.account_id, context.region)
+ import_token = to_str(import_token)
+ import_state = store.imports.get(import_token)
+ if not import_state:
+ raise NotFoundException(f"Unable to find key import token '{import_token}'")
+ key_to_import_material_to = self._get_kms_key(
+ context.account_id,
+ context.region,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+
+ if import_state.wrapping_algo == AlgorithmSpec.RSAES_PKCS1_V1_5:
+ decrypt_padding = padding.PKCS1v15()
+ elif import_state.wrapping_algo == AlgorithmSpec.RSAES_OAEP_SHA_1:
+ decrypt_padding = padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)
+ elif import_state.wrapping_algo == AlgorithmSpec.RSAES_OAEP_SHA_256:
+ decrypt_padding = padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)
+ else:
+ raise KMSInvalidStateException(
+ f"Unsupported padding, requested wrapping algorithm:'{import_state.wrapping_algo}'"
+ )
+
+ # TODO check if there was already a key imported for this kms key
+ # if so, it has to be identical. We cannot change keys by reimporting after deletion/expiry
+ key_material = import_state.key.crypto_key.key.decrypt(
+ encrypted_key_material, decrypt_padding
+ )
+ if expiration_model:
+ key_to_import_material_to.metadata["ExpirationModel"] = expiration_model
+ else:
+ key_to_import_material_to.metadata["ExpirationModel"] = (
+ ExpirationModelType.KEY_MATERIAL_EXPIRES
+ )
+ if (
+ key_to_import_material_to.metadata["ExpirationModel"]
+ == ExpirationModelType.KEY_MATERIAL_EXPIRES
+ and not valid_to
+ ):
+ raise ValidationException(
+ "A validTo date must be set if the ExpirationModel is KEY_MATERIAL_EXPIRES"
+ )
+ # TODO actually set validTo and make the key expire
+ key_to_import_material_to.metadata["Enabled"] = True
+ key_to_import_material_to.metadata["KeyState"] = KeyState.Enabled
+ key_to_import_material_to.crypto_key.load_key_material(key_material)
+
+ return ImportKeyMaterialResponse()
+
+ def delete_imported_key_material(
+ self, context: RequestContext, key_id: KeyIdType, **kwargs
+ ) -> None:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.crypto_key.key_material = None
+ key.metadata["Enabled"] = False
+ key.metadata["KeyState"] = KeyState.PendingImport
+ key.metadata.pop("ExpirationModel", None)
+
+ @handler("CreateAlias", expand=False)
+ def create_alias(self, context: RequestContext, request: CreateAliasRequest) -> None:
+ store = self._get_store(context.account_id, context.region)
+ alias_name = request["AliasName"]
+ validate_alias_name(alias_name)
+ if alias_name in store.aliases:
+ alias_arn = store.aliases.get(alias_name).metadata["AliasArn"]
+ # AWS itself uses AliasArn instead of AliasName in this exception.
+ raise AlreadyExistsException(f"An alias with the name {alias_arn} already exists")
+ # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no
+ # matter which type of id is used.
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("TargetKeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ request["TargetKeyId"] = key.metadata.get("KeyId")
+ self._create_kms_alias(context.account_id, context.region, request)
+
+ @handler("DeleteAlias", expand=False)
+ def delete_alias(self, context: RequestContext, request: DeleteAliasRequest) -> None:
+ # We do not check the state of the key, as, according to AWS docs, all key states that are possible in
+ # LocalStack are supported by this operation.
+ store = self._get_store(context.account_id, context.region)
+ alias_name = request["AliasName"]
+ if alias_name not in store.aliases:
+ alias_arn = kms_alias_arn(request["AliasName"], context.account_id, context.region)
+ # AWS itself uses AliasArn instead of AliasName in this exception.
+ raise NotFoundException(f"Alias {alias_arn} is not found")
+ store.aliases.pop(alias_name, None)
+
+ @handler("UpdateAlias", expand=False)
+ def update_alias(self, context: RequestContext, request: UpdateAliasRequest) -> None:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+ # "If the source KMS key is pending deletion, the command succeeds. If the destination KMS key is pending
+ # deletion, the command fails with error: KMSInvalidStateException : is pending deletion."
+ # Also disabled keys are accepted for this operation (see the table on that page).
+ #
+ # As such, we do not care about the state of the source key, but check the destination one.
+
+ alias_name = request["AliasName"]
+ # This API, per AWS docs, accepts only names, not ARNs.
+ validate_alias_name(alias_name)
+ alias = self._get_kms_alias(context.account_id, context.region, alias_name)
+ key_id = request["TargetKeyId"]
+ # Don't care about the key itself, just want to validate its state.
+ self._get_kms_key(
+ context.account_id,
+ context.region,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ alias.metadata["TargetKeyId"] = key_id
+ alias.update_date_of_last_update()
+
+ @handler("ListAliases")
+ def list_aliases(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType = None,
+ limit: LimitType = None,
+ marker: MarkerType = None,
+ **kwargs,
+ ) -> ListAliasesResponse:
+ store = self._get_store(context.account_id, context.region)
+ if key_id:
+ # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no
+ # matter which type of id is used.
+ key = self._get_kms_key(
+ context.account_id, context.region, key_id, any_key_state_allowed=True
+ )
+ key_id = key.metadata.get("KeyId")
+
+ matching_aliases = []
+ for alias in store.aliases.values():
+ if key_id and alias.metadata["TargetKeyId"] != key_id:
+ continue
+ matching_aliases.append(alias.metadata)
+ aliases_list = PaginatedList(matching_aliases)
+ limit = limit or 100
+ page, next_token = aliases_list.get_page(
+ lambda alias_metadata: alias_metadata.get("AliasName"),
+ next_token=marker,
+ page_size=limit,
+ )
+ kwargs = {"NextMarker": next_token, "Truncated": True} if next_token else {}
+ return ListAliasesResponse(Aliases=page, **kwargs)
+
+ @handler("GetKeyRotationStatus", expand=False)
+ def get_key_rotation_status(
+ self, context: RequestContext, request: GetKeyRotationStatusRequest
+ ) -> GetKeyRotationStatusResponse:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+ # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException."
+ # We do not model that here, though.
+ account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context)
+ key = self._get_kms_key(account_id, region_name, key_id, any_key_state_allowed=True)
+ return GetKeyRotationStatusResponse(KeyRotationEnabled=key.is_key_rotation_enabled)
+
+ @handler("DisableKeyRotation", expand=False)
+ def disable_key_rotation(
+ self, context: RequestContext, request: DisableKeyRotationRequest
+ ) -> None:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+ # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException."
+ # We do not model that here, though.
+ key = self._get_kms_key(context.account_id, context.region, request.get("KeyId"))
+ key.is_key_rotation_enabled = False
+
+ @handler("EnableKeyRotation", expand=False)
+ def enable_key_rotation(
+ self, context: RequestContext, request: DisableKeyRotationRequest
+ ) -> None:
+ # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html
+ # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException."
+ # We do not model that here, though.
+ key = self._get_kms_key(context.account_id, context.region, request.get("KeyId"))
+ key.is_key_rotation_enabled = True
+
+ @handler("ListKeyPolicies", expand=False)
+ def list_key_policies(
+ self, context: RequestContext, request: ListKeyPoliciesRequest
+ ) -> ListKeyPoliciesResponse:
+ # We just care if the key exists. The response, by AWS specifications, is the same for all keys, as the only
+ # supported policy is "default":
+ # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html#API_ListKeyPolicies_ResponseElements
+ self._get_kms_key(
+ context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True
+ )
+ return ListKeyPoliciesResponse(PolicyNames=["default"], Truncated=False)
+
+ @handler("PutKeyPolicy", expand=False)
+ def put_key_policy(self, context: RequestContext, request: PutKeyPolicyRequest) -> None:
+ key = self._get_kms_key(
+ context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True
+ )
+ if request.get("PolicyName") != "default":
+ raise UnsupportedOperationException("Only default policy is supported")
+ key.policy = request.get("Policy")
+
+ @handler("GetKeyPolicy", expand=False)
+ def get_key_policy(
+ self, context: RequestContext, request: GetKeyPolicyRequest
+ ) -> GetKeyPolicyResponse:
+ key = self._get_kms_key(
+ context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True
+ )
+ if request.get("PolicyName") != "default":
+ raise NotFoundException("No such policy exists")
+ return GetKeyPolicyResponse(Policy=key.policy)
+
+ @handler("ListResourceTags", expand=False)
+ def list_resource_tags(
+ self, context: RequestContext, request: ListResourceTagsRequest
+ ) -> ListResourceTagsResponse:
+ key = self._get_kms_key(
+ context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True
+ )
+ keys_list = PaginatedList(
+ [{"TagKey": tag_key, "TagValue": tag_value} for tag_key, tag_value in key.tags.items()]
+ )
+ page, next_token = keys_list.get_page(
+ lambda tag: tag.get("TagKey"),
+ next_token=request.get("Marker"),
+ page_size=request.get("Limit", 50),
+ )
+ kwargs = {"NextMarker": next_token, "Truncated": True} if next_token else {}
+ return ListResourceTagsResponse(Tags=page, **kwargs)
+
+ @handler("TagResource", expand=False)
+ def tag_resource(self, context: RequestContext, request: TagResourceRequest) -> None:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key.add_tags(request.get("Tags"))
+
+ @handler("UntagResource", expand=False)
+ def untag_resource(self, context: RequestContext, request: UntagResourceRequest) -> None:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ request.get("KeyId"),
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ if not request.get("TagKeys"):
+ return
+ for tag_key in request.get("TagKeys"):
+ # AWS doesn't seem to mind removal of a non-existent tag, so we do not raise any exception.
+ key.tags.pop(tag_key, None)
+
+ def derive_shared_secret(
+ self,
+ context: RequestContext,
+ key_id: KeyIdType,
+ key_agreement_algorithm: KeyAgreementAlgorithmSpec,
+ public_key: PublicKeyType,
+ grant_tokens: GrantTokenList = None,
+ dry_run: NullableBooleanType = None,
+ recipient: RecipientInfo = None,
+ **kwargs,
+ ) -> DeriveSharedSecretResponse:
+ key = self._get_kms_key(
+ context.account_id,
+ context.region,
+ key_id,
+ enabled_key_allowed=True,
+ disabled_key_allowed=True,
+ )
+ key_usage = key.metadata.get("KeyUsage")
+ key_origin = key.metadata.get("Origin")
+
+ if key_usage != KeyUsageType.KEY_AGREEMENT:
+ raise InvalidKeyUsageException(
+ f"{key.metadata['Arn']} key usage is {key_usage} which is not valid for {context.operation.name}."
+ )
+
+ if key_agreement_algorithm != KeyAgreementAlgorithmSpec.ECDH:
+ raise ValidationException(
+ f"1 validation error detected: Value '{key_agreement_algorithm}' at 'keyAgreementAlgorithm' "
+ f"failed to satisfy constraint: Member must satisfy enum value set: [ECDH]"
+ )
+
+ # TODO: Verify the actual error raised
+ if key_origin not in [OriginType.AWS_KMS, OriginType.EXTERNAL]:
+ raise ValueError(f"Key origin: {key_origin} is not valid for {context.operation.name}.")
+
+ shared_secret = key.derive_shared_secret(public_key)
+ return DeriveSharedSecretResponse(
+ KeyId=key_id,
+ SharedSecret=shared_secret,
+ KeyAgreementAlgorithm=key_agreement_algorithm,
+ KeyOrigin=key_origin,
+ )
+
+ def _validate_key_state_not_pending_import(self, key: KmsKey):
+ if key.metadata["KeyState"] == KeyState.PendingImport:
+ raise KMSInvalidStateException(f"{key.metadata['Arn']} is pending import.")
+
+ def _validate_key_for_encryption_decryption(self, context: RequestContext, key: KmsKey):
+ key_usage = key.metadata["KeyUsage"]
+ if key_usage != "ENCRYPT_DECRYPT":
+ raise InvalidKeyUsageException(
+ f"{key.metadata['Arn']} key usage is {key_usage} which is not valid for {context.operation.name}."
+ )
+
+ def _validate_key_for_sign_verify(self, context: RequestContext, key: KmsKey):
+ key_usage = key.metadata["KeyUsage"]
+ if key_usage != "SIGN_VERIFY":
+ raise InvalidKeyUsageException(
+ f"{key.metadata['Arn']} key usage is {key_usage} which is not valid for {context.operation.name}."
+ )
+
+ def _validate_key_for_generate_verify_mac(self, context: RequestContext, key: KmsKey):
+ key_usage = key.metadata["KeyUsage"]
+ if key_usage != "GENERATE_VERIFY_MAC":
+ raise InvalidKeyUsageException(
+ f"{key.metadata['Arn']} key usage is {key_usage} which is not valid for {context.operation.name}."
+ )
+
+ def _validate_mac_msg_length(self, msg: bytes):
+ if len(msg) > 4096:
+ raise ValidationException(
+ "1 validation error detected: Value at 'message' failed to satisfy constraint: "
+ "Member must have length less than or equal to 4096"
+ )
+
+ def _validate_mac_algorithm(self, key: KmsKey, algorithm: str):
+ if not hasattr(MacAlgorithmSpec, algorithm):
+ raise ValidationException(
+ f"1 validation error detected: Value '{algorithm}' at 'macAlgorithm' "
+ f"failed to satisfy constraint: Member must satisfy enum value set: "
+ f"[HMAC_SHA_384, HMAC_SHA_256, HMAC_SHA_224, HMAC_SHA_512]"
+ )
+
+ key_spec = key.metadata["KeySpec"]
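+ # A MAC algorithm must match the key spec, e.g. HMAC_SHA_256 is only valid for an HMAC_256 key:
+ # drop the middle "SHA" token from the algorithm name and compare the result with the key spec.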
+ if x := algorithm.split("_"):
+ if len(x) == 3 and x[0] + "_" + x[2] != key_spec:
+ raise InvalidKeyUsageException(
+ f"Algorithm {algorithm} is incompatible with key spec {key_spec}."
+ )
+
+ def _validate_plaintext_length(self, plaintext: bytes):
+ if len(plaintext) > 4096:
+ raise ValidationException(
+ "1 validation error detected: Value at 'plaintext' failed to satisfy constraint: "
+ "Member must have length less than or equal to 4096"
+ )
+
+ def _validate_grant_request(self, data: Dict):
+ if "KeyId" not in data or "GranteePrincipal" not in data or "Operations" not in data:
+ raise ValidationError("Grant ID, key ID and grantee principal must be specified")
+
+ for operation in data["Operations"]:
+ if operation not in VALID_OPERATIONS:
+ raise ValidationError(
+ f"Value {['Operations']} at 'operations' failed to satisfy constraint: Member must satisfy"
+ f" constraint: [Member must satisfy enum value set: {VALID_OPERATIONS}]"
+ )
+
+ def _validate_plaintext_key_type_based(
+ self,
+ plaintext: PlaintextType,
+ key: KmsKey,
+ encryption_algorithm: EncryptionAlgorithmSpec = None,
+ ):
+ # max size values extracted from AWS boto3 documentation
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kms/client/encrypt.html
+ max_size_bytes = 4096 # max allowed size
+ if (
+ key.metadata["KeySpec"] == KeySpec.RSA_2048
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_1
+ ):
+ max_size_bytes = 214
+ elif (
+ key.metadata["KeySpec"] == KeySpec.RSA_2048
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256
+ ):
+ max_size_bytes = 190
+ elif (
+ key.metadata["KeySpec"] == KeySpec.RSA_3072
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_1
+ ):
+ max_size_bytes = 342
+ elif (
+ key.metadata["KeySpec"] == KeySpec.RSA_3072
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256
+ ):
+ max_size_bytes = 318
+ elif (
+ key.metadata["KeySpec"] == KeySpec.RSA_4096
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_1
+ ):
+ max_size_bytes = 470
+ elif (
+ key.metadata["KeySpec"] == KeySpec.RSA_4096
+ and encryption_algorithm == EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256
+ ):
+ max_size_bytes = 446
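+ # For all other key specs (e.g. the symmetric default), the general 4096-byte limit above applies.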
+
+ if len(plaintext) > max_size_bytes:
+ raise ValidationException(
+ f"Algorithm {encryption_algorithm} and key spec {key.metadata['KeySpec']} cannot encrypt data larger than {max_size_bytes} bytes."
+ )
+
+
+# ---------------
+# UTIL FUNCTIONS
+# ---------------
+
+# Different AWS services have internal integrations with KMS. Some of them create keys that are used to
+# encrypt/decrypt customers' data. Such keys can't be created from the outside for security reasons, so AWS services
+# use internal APIs to do that. The functions here are meant to be used by other LocalStack services to implement
+# similar integrations with KMS in LocalStack. As such, they are supposed to behave like proper APIs (with regard to
+# error and security handling), just with more features.
+
+
+def set_key_managed(key_id: str, account_id: str, region_name: str) -> None:
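+ """Mark the given KMS key as an AWS-managed key by setting its ``KeyManager`` metadata field to AWS."""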
+ key = KmsProvider._get_kms_key(account_id, region_name, key_id)
+ key.metadata["KeyManager"] = "AWS"
diff --git a/localstack-core/localstack/services/kms/resource_providers/__init__.py b/localstack-core/localstack/services/kms/resource_providers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.py b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.py
new file mode 100644
index 0000000000000..81ecef65ca520
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.py
@@ -0,0 +1,105 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class KMSAliasProperties(TypedDict):
+ AliasName: Optional[str]
+ TargetKeyId: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class KMSAliasProvider(ResourceProvider[KMSAliasProperties]):
+ TYPE = "AWS::KMS::Alias" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[KMSAliasProperties],
+ ) -> ProgressEvent[KMSAliasProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/AliasName
+
+ Required properties:
+ - AliasName
+ - TargetKeyId
+
+ Create-only properties:
+ - /properties/AliasName
+
+
+
+ IAM permissions required:
+ - kms:CreateAlias
+
+ """
+ model = request.desired_state
+ kms = request.aws_client_factory.kms
+
+ kms.create_alias(AliasName=model["AliasName"], TargetKeyId=model["TargetKeyId"])
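+ # The alias name is the primary identifier and CreateAlias returns no additional data, so the desired state
+ # already is the full resource model.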
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[KMSAliasProperties],
+ ) -> ProgressEvent[KMSAliasProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - kms:ListAliases
+ """
+ raise NotImplementedError
+
+ def delete(
+ self,
+ request: ResourceRequest[KMSAliasProperties],
+ ) -> ProgressEvent[KMSAliasProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - kms:DeleteAlias
+ """
+ model = request.desired_state
+ kms = request.aws_client_factory.kms
+
+ kms.delete_alias(AliasName=model["AliasName"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[KMSAliasProperties],
+ ) -> ProgressEvent[KMSAliasProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - kms:UpdateAlias
+ """
+ raise NotImplementedError
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.schema.json b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.schema.json
new file mode 100644
index 0000000000000..e3eb5a1591f1d
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias.schema.json
@@ -0,0 +1,61 @@
+{
+ "typeName": "AWS::KMS::Alias",
+ "description": "The AWS::KMS::Alias resource specifies a display name for an AWS KMS key in AWS Key Management Service (AWS KMS). You can use an alias to identify an AWS KMS key in cryptographic operations.",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-rpdk.git",
+ "properties": {
+ "AliasName": {
+ "description": "Specifies the alias name. This value must begin with alias/ followed by a name, such as alias/ExampleAlias. The alias name cannot begin with alias/aws/. The alias/aws/ prefix is reserved for AWS managed keys.",
+ "type": "string",
+ "pattern": "^(alias/)[a-zA-Z0-9:/_-]+$",
+ "minLength": 1,
+ "maxLength": 256
+ },
+ "TargetKeyId": {
+ "description": "Identifies the AWS KMS key to which the alias refers. Specify the key ID or the Amazon Resource Name (ARN) of the AWS KMS key. You cannot specify another alias. For help finding the key ID and ARN, see Finding the Key ID and ARN in the AWS Key Management Service Developer Guide.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "AliasName",
+ "TargetKeyId"
+ ],
+ "createOnlyProperties": [
+ "/properties/AliasName"
+ ],
+ "primaryIdentifier": [
+ "/properties/AliasName"
+ ],
+ "tagging": {
+ "taggable": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "kms:CreateAlias"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "kms:ListAliases"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "kms:UpdateAlias"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "kms:DeleteAlias"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "kms:ListAliases"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias_plugin.py b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias_plugin.py
new file mode 100644
index 0000000000000..172d4915576ce
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_alias_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class KMSAliasProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::KMS::Alias"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.kms.resource_providers.aws_kms_alias import KMSAliasProvider
+
+ self.factory = KMSAliasProvider
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.py b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.py
new file mode 100644
index 0000000000000..6228292ed2953
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.py
@@ -0,0 +1,190 @@
+# LocalStack Resource Provider Scaffolding v2
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional, TypedDict
+
+import localstack.services.cloudformation.provider_utils as util
+from localstack.services.cloudformation.resource_provider import (
+ OperationStatus,
+ ProgressEvent,
+ ResourceProvider,
+ ResourceRequest,
+)
+
+
+class KMSKeyProperties(TypedDict):
+ KeyPolicy: Optional[dict | str]
+ Arn: Optional[str]
+ Description: Optional[str]
+ EnableKeyRotation: Optional[bool]
+ Enabled: Optional[bool]
+ KeyId: Optional[str]
+ KeySpec: Optional[str]
+ KeyUsage: Optional[str]
+ MultiRegion: Optional[bool]
+ PendingWindowInDays: Optional[int]
+ Tags: Optional[list[Tag]]
+
+
+class Tag(TypedDict):
+ Key: Optional[str]
+ Value: Optional[str]
+
+
+REPEATED_INVOCATION = "repeated_invocation"
+
+
+class KMSKeyProvider(ResourceProvider[KMSKeyProperties]):
+ TYPE = "AWS::KMS::Key" # Autogenerated. Don't change
+ SCHEMA = util.get_schema_path(Path(__file__)) # Autogenerated. Don't change
+
+ def create(
+ self,
+ request: ResourceRequest[KMSKeyProperties],
+ ) -> ProgressEvent[KMSKeyProperties]:
+ """
+ Create a new resource.
+
+ Primary identifier fields:
+ - /properties/KeyId
+
+ Required properties:
+ - KeyPolicy
+
+
+
+ Read-only properties:
+ - /properties/Arn
+ - /properties/KeyId
+
+ IAM permissions required:
+ - kms:CreateKey
+ - kms:EnableKeyRotation
+ - kms:DisableKey
+ - kms:TagResource
+
+ """
+ model = request.desired_state
+ kms = request.aws_client_factory.kms
+
+ params = util.select_attributes(model, ["Description", "KeySpec", "KeyUsage"])
+
+ if model.get("KeyPolicy"):
+ params["Policy"] = json.dumps(model["KeyPolicy"])
+
+ if model.get("Tags"):
+ params["Tags"] = [
+ {"TagKey": tag["Key"], "TagValue": tag["Value"]} for tag in model.get("Tags", [])
+ ]
+ response = kms.create_key(**params)
+ model["KeyId"] = response["KeyMetadata"]["KeyId"]
+ model["Arn"] = response["KeyMetadata"]["Arn"]
+
+ # key is created but some fields map to separate api calls
+ if model.get("EnableKeyRotation", False):
+ kms.enable_key_rotation(KeyId=model["KeyId"])
+ else:
+ kms.disable_key_rotation(KeyId=model["KeyId"])
+
+ if model.get("Enabled", True):
+ kms.enable_key(KeyId=model["KeyId"])
+ else:
+ kms.disable_key(KeyId=model["KeyId"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def read(
+ self,
+ request: ResourceRequest[KMSKeyProperties],
+ ) -> ProgressEvent[KMSKeyProperties]:
+ """
+ Fetch resource information
+
+ IAM permissions required:
+ - kms:DescribeKey
+ - kms:GetKeyPolicy
+ - kms:GetKeyRotationStatus
+ - kms:ListResourceTags
+ """
+ kms = request.aws_client_factory.kms
+ key_id = request.desired_state["KeyId"]
+
+ key = kms.describe_key(KeyId=key_id)
+
+ policy = kms.get_key_policy(KeyId=key_id, PolicyName="default")
+ rotation_status = kms.get_key_rotation_status(KeyId=key_id)
+ tags = kms.list_resource_tags(KeyId=key_id)
+
+ model = util.select_attributes(key["KeyMetadata"], self.SCHEMA["properties"])
+ model["KeyPolicy"] = json.loads(policy["Policy"])
+ model["EnableKeyRotation"] = rotation_status["KeyRotationEnabled"]
+ # Super consistent API... the KMS API returns TagKey/TagValue, but the CC API transforms it to Key/Value.
+ # It might be worth checking whether there are more APIs for which CC does this.
+ model["Tags"] = [{"Key": tag["TagKey"], "Value": tag["TagValue"]} for tag in tags["Tags"]]
+
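+ # Make sure Origin is always part of the model; AWS_KMS is the KMS default for keys without imported material.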
+ if "Origin" not in model:
+ model["Origin"] = "AWS_KMS"
+
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_model=model)
+
+ def delete(
+ self,
+ request: ResourceRequest[KMSKeyProperties],
+ ) -> ProgressEvent[KMSKeyProperties]:
+ """
+ Delete a resource
+
+ IAM permissions required:
+ - kms:DescribeKey
+ - kms:ScheduleKeyDeletion
+ """
+ model = request.desired_state
+ kms = request.aws_client_factory.kms
+
+ kms.schedule_key_deletion(KeyId=model["KeyId"])
+
+ return ProgressEvent(
+ status=OperationStatus.SUCCESS,
+ resource_model=model,
+ custom_context=request.custom_context,
+ )
+
+ def update(
+ self,
+ request: ResourceRequest[KMSKeyProperties],
+ ) -> ProgressEvent[KMSKeyProperties]:
+ """
+ Update a resource
+
+ IAM permissions required:
+ - kms:DescribeKey
+ - kms:DisableKey
+ - kms:DisableKeyRotation
+ - kms:EnableKey
+ - kms:EnableKeyRotation
+ - kms:PutKeyPolicy
+ - kms:TagResource
+ - kms:UntagResource
+ - kms:UpdateKeyDescription
+ """
+ raise NotImplementedError
+
+ def list(self, request: ResourceRequest[KMSKeyProperties]) -> ProgressEvent[KMSKeyProperties]:
+ """
+ List a resource
+
+ IAM permissions required:
+ - kms:ListKeys
+ - kms:DescribeKey
+ """
+ kms = request.aws_client_factory.kms
+
+ response = kms.list_keys(Limit=10)
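+ # Note: only the first page of keys is returned here; pagination via the Marker is not implemented.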
+ models = [{"KeyId": key["KeyId"]} for key in response["Keys"]]
+ return ProgressEvent(status=OperationStatus.SUCCESS, resource_models=models)
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.schema.json b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.schema.json
new file mode 100644
index 0000000000000..782d35fa134ac
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key.schema.json
@@ -0,0 +1,172 @@
+{
+ "typeName": "AWS::KMS::Key",
+ "description": "The AWS::KMS::Key resource specifies an AWS KMS key in AWS Key Management Service (AWS KMS). Authorized users can use the AWS KMS key to encrypt and decrypt small amounts of data (up to 4096 bytes), but they are more commonly used to generate data keys. You can also use AWS KMS keys to encrypt data stored in AWS services that are integrated with AWS KMS or within their applications.",
+ "sourceUrl": "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-kms",
+ "definitions": {
+ "Tag": {
+ "description": "A key-value pair to associate with a resource.",
+ "type": "object",
+ "properties": {
+ "Key": {
+ "type": "string",
+ "description": "The key name of the tag. You can specify a value that is 1 to 128 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "minLength": 1,
+ "maxLength": 128
+ },
+ "Value": {
+ "type": "string",
+ "description": "The value for the tag. You can specify a value that is 0 to 256 Unicode characters in length and cannot be prefixed with aws:. You can use any of the following characters: the set of Unicode letters, digits, whitespace, _, ., /, =, +, and -.",
+ "minLength": 0,
+ "maxLength": 256
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "Key",
+ "Value"
+ ]
+ }
+ },
+ "properties": {
+ "Description": {
+ "description": "A description of the AWS KMS key. Use a description that helps you to distinguish this AWS KMS key from others in the account, such as its intended use.",
+ "type": "string",
+ "minLength": 0,
+ "maxLength": 8192
+ },
+ "Enabled": {
+ "description": "Specifies whether the AWS KMS key is enabled. Disabled AWS KMS keys cannot be used in cryptographic operations.",
+ "type": "boolean"
+ },
+ "EnableKeyRotation": {
+ "description": "Enables automatic rotation of the key material for the specified AWS KMS key. By default, automation key rotation is not enabled.",
+ "type": "boolean"
+ },
+ "KeyPolicy": {
+ "description": "The key policy that authorizes use of the AWS KMS key. The key policy must observe the following rules.",
+ "type": [
+ "object",
+ "string"
+ ]
+ },
+ "KeyUsage": {
+ "description": "Determines the cryptographic operations for which you can use the AWS KMS key. The default value is ENCRYPT_DECRYPT. This property is required only for asymmetric AWS KMS keys. You can't change the KeyUsage value after the AWS KMS key is created.",
+ "type": "string",
+ "default": "ENCRYPT_DECRYPT",
+ "enum": [
+ "ENCRYPT_DECRYPT",
+ "SIGN_VERIFY",
+ "GENERATE_VERIFY_MAC"
+ ]
+ },
+ "KeySpec": {
+ "description": "Specifies the type of AWS KMS key to create. The default value is SYMMETRIC_DEFAULT. This property is required only for asymmetric AWS KMS keys. You can't change the KeySpec value after the AWS KMS key is created.",
+ "type": "string",
+ "default": "SYMMETRIC_DEFAULT",
+ "enum": [
+ "SYMMETRIC_DEFAULT",
+ "RSA_2048",
+ "RSA_3072",
+ "RSA_4096",
+ "ECC_NIST_P256",
+ "ECC_NIST_P384",
+ "ECC_NIST_P521",
+ "ECC_SECG_P256K1",
+ "HMAC_224",
+ "HMAC_256",
+ "HMAC_384",
+ "HMAC_512",
+ "SM2"
+ ]
+ },
+ "MultiRegion": {
+ "description": "Specifies whether the AWS KMS key should be Multi-Region. You can't change the MultiRegion value after the AWS KMS key is created.",
+ "type": "boolean",
+ "default": false
+ },
+ "PendingWindowInDays": {
+ "description": "Specifies the number of days in the waiting period before AWS KMS deletes an AWS KMS key that has been removed from a CloudFormation stack. Enter a value between 7 and 30 days. The default value is 30 days.",
+ "type": "integer",
+ "minimum": 7,
+ "maximum": 30
+ },
+ "Tags": {
+ "description": "An array of key-value pairs to apply to this resource.",
+ "type": "array",
+ "uniqueItems": true,
+ "insertionOrder": false,
+ "items": {
+ "$ref": "#/definitions/Tag"
+ }
+ },
+ "Arn": {
+ "type": "string"
+ },
+ "KeyId": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "KeyPolicy"
+ ],
+ "readOnlyProperties": [
+ "/properties/Arn",
+ "/properties/KeyId"
+ ],
+ "primaryIdentifier": [
+ "/properties/KeyId"
+ ],
+ "writeOnlyProperties": [
+ "/properties/PendingWindowInDays"
+ ],
+ "tagging": {
+ "taggable": true,
+ "tagOnCreate": true,
+ "tagUpdatable": true,
+ "cloudFormationSystemTags": false
+ },
+ "handlers": {
+ "create": {
+ "permissions": [
+ "kms:CreateKey",
+ "kms:EnableKeyRotation",
+ "kms:DisableKey",
+ "kms:TagResource"
+ ]
+ },
+ "read": {
+ "permissions": [
+ "kms:DescribeKey",
+ "kms:GetKeyPolicy",
+ "kms:GetKeyRotationStatus",
+ "kms:ListResourceTags"
+ ]
+ },
+ "update": {
+ "permissions": [
+ "kms:DescribeKey",
+ "kms:DisableKey",
+ "kms:DisableKeyRotation",
+ "kms:EnableKey",
+ "kms:EnableKeyRotation",
+ "kms:PutKeyPolicy",
+ "kms:TagResource",
+ "kms:UntagResource",
+ "kms:UpdateKeyDescription"
+ ]
+ },
+ "delete": {
+ "permissions": [
+ "kms:DescribeKey",
+ "kms:ScheduleKeyDeletion"
+ ]
+ },
+ "list": {
+ "permissions": [
+ "kms:ListKeys",
+ "kms:DescribeKey"
+ ]
+ }
+ }
+}
diff --git a/localstack-core/localstack/services/kms/resource_providers/aws_kms_key_plugin.py b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key_plugin.py
new file mode 100644
index 0000000000000..a03c3c714af8c
--- /dev/null
+++ b/localstack-core/localstack/services/kms/resource_providers/aws_kms_key_plugin.py
@@ -0,0 +1,18 @@
+from typing import Optional, Type
+
+from localstack.services.cloudformation.resource_provider import (
+ CloudFormationResourceProviderPlugin,
+ ResourceProvider,
+)
+
+
+class KMSKeyProviderPlugin(CloudFormationResourceProviderPlugin):
+ name = "AWS::KMS::Key"
+
+ def __init__(self):
+ self.factory: Optional[Type[ResourceProvider]] = None
+
+ def load(self):
+ from localstack.services.kms.resource_providers.aws_kms_key import KMSKeyProvider
+
+ self.factory = KMSKeyProvider
diff --git a/localstack-core/localstack/services/kms/utils.py b/localstack-core/localstack/services/kms/utils.py
new file mode 100644
index 0000000000000..ce1a65599e6c8
--- /dev/null
+++ b/localstack-core/localstack/services/kms/utils.py
@@ -0,0 +1,60 @@
+import re
+from typing import Tuple
+
+from localstack.aws.api.kms import Tag, TagException
+from localstack.services.kms.exceptions import ValidationException
+from localstack.utils.aws.arns import ARN_PARTITION_REGEX
+
+KMS_KEY_ARN_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:kms:(?P[^:]+):(?P\d{{12}}):key\/(?P[^:]+)$"
+)
+
+
+def get_hash_algorithm(signing_algorithm: str) -> str:
+ """
+ Return the hashing algorithm for a given signing algorithm.
+ eg. "RSASSA_PSS_SHA_512" -> "SHA_512"
+ """
+ return "_".join(signing_algorithm.rsplit(sep="_", maxsplit=-2)[-2:])
+
+
+def parse_key_arn(key_arn: str) -> Tuple[str, str, str]:
+ """
+ Parse a valid KMS key arn into its constituents.
+
+ :param key_arn: KMS key ARN
+ :return: Tuple of account ID, region name and key ID
+ """
+ return KMS_KEY_ARN_PATTERN.match(key_arn).group("account_id", "region_name", "key_id")
+
+
+def is_valid_key_arn(key_arn: str) -> bool:
+ """
+ Check if a given string is a valid KMS key ARN.
+ """
+ return KMS_KEY_ARN_PATTERN.match(key_arn) is not None
+
+
+def validate_alias_name(alias_name: str) -> None:
+ if not alias_name.startswith("alias/"):
+ raise ValidationException(
+ 'Alias must start with the prefix "alias/". Please see '
+ "https://docs.aws.amazon.com/kms/latest/developerguide/kms-alias.html"
+ )
+
+
+def validate_tag(tag_position: int, tag: Tag) -> None:
+ tag_key = tag.get("TagKey")
+ tag_value = tag.get("TagValue")
+
+ if len(tag_key) > 128:
+ raise ValidationException(
+ f"1 validation error detected: Value '{tag_key}' at 'tags.{tag_position}.member.tagKey' failed to satisfy constraint: Member must have length less than or equal to 128"
+ )
+ if len(tag_value) > 256:
+ raise ValidationException(
+ f"1 validation error detected: Value '{tag_value}' at 'tags.{tag_position}.member.tagValue' failed to satisfy constraint: Member must have length less than or equal to 256"
+ )
+
+ if tag_key.lower().startswith("aws:"):
+ raise TagException("Tags beginning with aws: are reserved")
diff --git a/localstack-core/localstack/services/lambda_/__init__.py b/localstack-core/localstack/services/lambda_/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/localstack-core/localstack/services/lambda_/api_utils.py b/localstack-core/localstack/services/lambda_/api_utils.py
new file mode 100644
index 0000000000000..18b0c7d2d09d6
--- /dev/null
+++ b/localstack-core/localstack/services/lambda_/api_utils.py
@@ -0,0 +1,762 @@
+"""Utilities related to Lambda API operations such as ARN handling, validations, and output formatting.
+Everything related to behavior or implicit functionality goes into `lambda_utils.py`.
+"""
+
+import datetime
+import random
+import re
+import string
+from typing import TYPE_CHECKING, Any, Optional, Tuple
+
+from localstack.aws.api import CommonServiceException, RequestContext
+from localstack.aws.api import lambda_ as api_spec
+from localstack.aws.api.lambda_ import (
+ AliasConfiguration,
+ Architecture,
+ DeadLetterConfig,
+ EnvironmentResponse,
+ EphemeralStorage,
+ FunctionConfiguration,
+ FunctionUrlAuthType,
+ ImageConfig,
+ ImageConfigResponse,
+ InvalidParameterValueException,
+ LayerVersionContentOutput,
+ PublishLayerVersionResponse,
+ ResourceNotFoundException,
+ TracingConfig,
+ VpcConfigResponse,
+)
+from localstack.services.lambda_.invocation import AccessDeniedException
+from localstack.services.lambda_.runtimes import ALL_RUNTIMES, VALID_LAYER_RUNTIMES, VALID_RUNTIMES
+from localstack.utils.aws.arns import ARN_PARTITION_REGEX, get_partition
+from localstack.utils.collections import merge_recursive
+
+if TYPE_CHECKING:
+ from localstack.services.lambda_.invocation.lambda_models import (
+ CodeSigningConfig,
+ Function,
+ FunctionUrlConfig,
+ FunctionVersion,
+ LayerVersion,
+ VersionAlias,
+ )
+ from localstack.services.lambda_.invocation.models import LambdaStore
+
+
+# Pattern for a full (both with and without qualifier) lambda function ARN
+FULL_FN_ARN_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:lambda:(?P[^:]+):(?P\d{{12}}):function:(?P[^:]+)(:(?P.*))?$"
+)
+
+# Pattern for a full (both with and without qualifier) lambda layer ARN
+# TODO: It looks like they added `|(arn:[a-zA-Z0-9-]+:lambda:::awslayer:[a-zA-Z0-9-_]+` in 2024-11
+LAYER_VERSION_ARN_PATTERN = re.compile(
+ rf"{ARN_PARTITION_REGEX}:lambda:(?P[^:]+):(?P\d{{12}}):layer:(?P[^:]+)(:(?P