diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a71a8a16c..b15aa915b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,25 @@ CHANGELOG ========= +Next Release (TBD) +================== + +* Add experimental support for websockets + (`#1017 `__) + +* API Gateway Endpoint Type Configuration + (`#1160 https://github.com/aws/chalice/pull/1160`__) + +* API Gateway Resource Policy Configuration + (`#1160 https://github.com/aws/chalice/pull/1160`__) + +1.9.1 +===== + +* Make MultiDict mutable + (`#1158 `__) + + 1.9.0 ===== diff --git a/Makefile b/Makefile index 20e3ec6d6..761ccd30a 100644 --- a/Makefile +++ b/Makefile @@ -42,8 +42,9 @@ doccheck: ##### DOC8 ###### # Correct rst formatting for documentation # + # TODO: Remove doc8 ## - doc8 docs/source --ignore-path docs/source/topics/multifile.rst + doc8 docs/source --ignore-path docs/source/topics/multifile.rst --ignore-path docs/source/tutorials/websockets.rst # # # Verify we have no broken external links diff --git a/chalice/__init__.py b/chalice/__init__.py index e6b33c65d..7540a3d8b 100644 --- a/chalice/__init__.py +++ b/chalice/__init__.py @@ -3,7 +3,7 @@ ChaliceViewError, BadRequestError, UnauthorizedError, ForbiddenError, NotFoundError, ConflictError, TooManyRequestsError, Response, CORSConfig, CustomAuthorizer, CognitoUserPoolAuthorizer, IAMAuthorizer, - UnprocessableEntityError, + UnprocessableEntityError, WebsocketDisconnectedError, AuthResponse, AuthRoute, Cron, Rate, __version__ as chalice_version ) # We're reassigning version here to keep mypy happy. diff --git a/chalice/analyzer.py b/chalice/analyzer.py index 6f2cab841..1af23ee98 100644 --- a/chalice/analyzer.py +++ b/chalice/analyzer.py @@ -662,7 +662,8 @@ class AppViewTransformer(ast.NodeTransformer): _CHALICE_DECORATORS = [ 'route', 'authorizer', 'lambda_function', 'schedule', 'on_s3_event', 'on_sns_message', - 'on_sqs_message', + 'on_sqs_message', 'on_ws_connect', 'on_ws_message', + 'on_ws_disconnect', ] def visit_FunctionDef(self, node): diff --git a/chalice/app.py b/chalice/app.py index 3476f8270..145897f80 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -11,7 +11,7 @@ from collections import defaultdict -__version__ = '1.9.0' +__version__ = '1.9.1' _PARAMS = re.compile(r'{\w+}') # Implementation note: This file is intended to be a standalone file @@ -22,6 +22,7 @@ try: from urllib.parse import unquote_plus from collections.abc import Mapping + from collections.abc import MutableMapping unquote_str = unquote_plus @@ -31,6 +32,7 @@ except ImportError: from urllib import unquote_plus from collections import Mapping + from collections import MutableMapping # This is borrowed from botocore/compat.py def unquote_str(value, encoding='utf-8'): @@ -93,6 +95,11 @@ class ChaliceError(Exception): pass +class WebsocketDisconnectedError(ChaliceError): + def __init__(self, connection_id): + self.connection_id = connection_id + + class ChaliceViewError(ChaliceError): STATUS_CODE = 500 @@ -150,11 +157,12 @@ class TooManyRequestsError(ChaliceViewError): TooManyRequestsError] -class MultiDict(Mapping): - """A read only mapping of key to list of values. +class MultiDict(MutableMapping): # pylint: disable=too-many-ancestors + """A mapping of key to list of values. Accessing it in the usual way will return the last value in the list. - Calling getlist will return a list of values with the same key. + Calling getlist will return a list of all the values associated with + the same key. """ def __init__(self, mapping): @@ -164,15 +172,19 @@ def __init__(self, mapping): self._dict = mapping def __getitem__(self, k): - values_list = self._dict[k] - try: - return values_list[-1] + return self._dict[k][-1] except IndexError: raise KeyError(k) + def __setitem__(self, k, v): + self._dict[k] = [v] + + def __delitem__(self, k): + del self._dict[k] + def getlist(self, k): - return list(self._dict.get(k, [])) + return list(self._dict[k]) def __len__(self): return len(self._dict) @@ -180,6 +192,12 @@ def __len__(self): def __iter__(self): return iter(self._dict) + def __repr__(self): + return 'MultiDict(%s)' % self._dict + + def __str__(self): + return repr(self) + class CaseInsensitiveMapping(Mapping): """Case insensitive and read-only mapping.""" @@ -495,6 +513,47 @@ def default_binary_types(self): return list(self._DEFAULT_BINARY_TYPES) +class WebsocketAPI(object): + _WEBSOCKET_ENDPOINT_TEMPLATE = 'https://{domain_name}/{stage}' + + def __init__(self): + self.session = None + self._endpoint = None + self._client = None + + def configure(self, domain_name, stage): + if self._endpoint is not None: + return + self._endpoint = self._WEBSOCKET_ENDPOINT_TEMPLATE.format( + domain_name=domain_name, + stage=stage, + ) + + def send(self, connection_id, message): + if self.session is None: + raise ValueError( + 'Assign app.websocket_api.session to a boto3 session before ' + 'using the WebsocketAPI' + ) + if self._endpoint is None: + raise ValueError( + 'WebsocketAPI.configure must be called before using the ' + 'WebsocketAPI' + ) + if self._client is None: + self._client = self.session.client( + 'apigatewaymanagementapi', + endpoint_url=self._endpoint, + ) + try: + self._client.post_to_connection( + ConnectionId=connection_id, + Data=message, + ) + except self._client.exceptions.GoneException: + raise WebsocketDisconnectedError(connection_id) + + class DecoratorAPI(object): def authorizer(self, ttl_seconds=None, execution_role=None, name=None): return self._create_registration_function( @@ -551,6 +610,27 @@ def lambda_function(self, name=None): return self._create_registration_function( handler_type='lambda_function', name=name) + def on_ws_connect(self, name=None): + return self._create_registration_function( + handler_type='on_ws_connect', + name=name, + registration_kwargs={'route_key': '$connect'}, + ) + + def on_ws_disconnect(self, name=None): + return self._create_registration_function( + handler_type='on_ws_disconnect', + name=name, + registration_kwargs={'route_key': '$disconnect'}, + ) + + def on_ws_message(self, name=None): + return self._create_registration_function( + handler_type='on_ws_message', + name=name, + registration_kwargs={'route_key': '$default'}, + ) + def _create_registration_function(self, handler_type, name=None, registration_kwargs=None): def _register_handler(user_handler): @@ -578,6 +658,17 @@ def _wrap_handler(self, handler_type, handler_name, user_handler): if handler_type in event_classes: return EventSourceHandler( user_handler, event_classes[handler_type]) + + websocket_event_classes = [ + 'on_ws_connect', + 'on_ws_message', + 'on_ws_disconnect', + ] + if handler_type in websocket_event_classes: + return WebsocketEventSourceHandler( + user_handler, WebsocketEvent, + self.websocket_api # pylint: disable=no-member + ) if handler_type == 'authorizer': # Authorizer is special cased and doesn't quite fit the # EventSourceHandler pattern. @@ -593,6 +684,7 @@ class _HandlerRegistration(object): def __init__(self): self.routes = defaultdict(dict) + self.websocket_handlers = {} self.builtin_auth_handlers = [] self.event_sources = [] self.pure_lambda_functions = [] @@ -622,6 +714,50 @@ def _do_register_handler(self, handler_type, name, user_handler, kwargs=kwargs, ) + def _attach_websocket_handler(self, handler): + route_key = handler.route_key_handled + decorator_name = { + '$default': 'on_ws_message', + '$connect': 'on_ws_connect', + '$disconnect': 'on_ws_disconnect', + }.get(route_key) + if route_key in self.websocket_handlers: + raise ValueError( + "Duplicate websocket handler: '%s'. There can only be one " + "handler for each websocket decorator." % decorator_name + ) + self.websocket_handlers[route_key] = handler + + def _register_on_ws_connect(self, name, user_handler, handler_string, + kwargs, **unused): + wrapper = WebsocketConnectConfig( + name=name, + handler_string=handler_string, + user_handler=user_handler, + ) + self._attach_websocket_handler(wrapper) + + def _register_on_ws_message(self, name, user_handler, handler_string, + kwargs, **unused): + route_key = kwargs['route_key'] + wrapper = WebsocketMessageConfig( + name=name, + route_key_handled=route_key, + handler_string=handler_string, + user_handler=user_handler, + ) + self._attach_websocket_handler(wrapper) + self.websocket_handlers[route_key] = wrapper + + def _register_on_ws_disconnect(self, name, user_handler, + handler_string, kwargs, **unused): + wrapper = WebsocketDisconnectConfig( + name=name, + handler_string=handler_string, + user_handler=user_handler, + ) + self._attach_websocket_handler(wrapper) + def _register_lambda_function(self, name, user_handler, handler_string, **unused): wrapper = LambdaFunction( @@ -734,6 +870,7 @@ def __init__(self, app_name, debug=False, configure_logs=True, env=None): super(Chalice, self).__init__() self.app_name = app_name self.api = APIGateway() + self.websocket_api = WebsocketAPI() self.current_request = None self.lambda_context = None self._debug = debug @@ -802,6 +939,24 @@ def _register_handler(self, handler_type, name, user_handler, self._do_register_handler(handler_type, name, user_handler, wrapped_handler, kwargs, options) + def _register_on_ws_connect(self, name, user_handler, handler_string, + kwargs, **unused): + self._features_used.add('WEBSOCKETS') + super(Chalice, self)._register_on_ws_connect( + name, user_handler, handler_string, kwargs, **unused) + + def _register_on_ws_message(self, name, user_handler, handler_string, + kwargs, **unused): + self._features_used.add('WEBSOCKETS') + super(Chalice, self)._register_on_ws_message( + name, user_handler, handler_string, kwargs, **unused) + + def _register_on_ws_disconnect(self, name, user_handler, + handler_string, kwargs, **unused): + self._features_used.add('WEBSOCKETS') + super(Chalice, self)._register_on_ws_disconnect( + name, user_handler, handler_string, kwargs, **unused) + def __call__(self, event, context): # This is what's invoked via lambda. # Sometimes the event can be something that's not @@ -1175,6 +1330,31 @@ def __init__(self, name, handler_string, queue, batch_size): self.batch_size = batch_size +class WebsocketConnectConfig(BaseEventSourceConfig): + CONNECT_ROUTE = '$connect' + + def __init__(self, name, handler_string, user_handler): + super(WebsocketConnectConfig, self).__init__(name, handler_string) + self.route_key_handled = self.CONNECT_ROUTE + self.handler_function = user_handler + + +class WebsocketMessageConfig(BaseEventSourceConfig): + def __init__(self, name, route_key_handled, handler_string, user_handler): + super(WebsocketMessageConfig, self).__init__(name, handler_string) + self.route_key_handled = route_key_handled + self.handler_function = user_handler + + +class WebsocketDisconnectConfig(BaseEventSourceConfig): + DISCONNECT_ROUTE = '$disconnect' + + def __init__(self, name, handler_string, user_handler): + super(WebsocketDisconnectConfig, self).__init__(name, handler_string) + self.route_key_handled = self.DISCONNECT_ROUTE + self.handler_function = user_handler + + class EventSourceHandler(object): def __init__(self, func, event_class): @@ -1186,6 +1366,22 @@ def __call__(self, event, context): return self.func(event_obj) +class WebsocketEventSourceHandler(object): + def __init__(self, func, event_class, websocket_api): + self.func = func + self.event_class = event_class + self.websocket_api = websocket_api + + def __call__(self, event, context): + event_obj = self.event_class(event, context) + self.websocket_api.configure( + event_obj.domain_name, + event_obj.stage, + ) + self.func(event_obj) + return {'statusCode': 200} + + # These classes contain all the event types that are passed # in as arguments in the lambda event handlers. These are # part of Chalice's public API and must be backwards compatible. @@ -1216,6 +1412,28 @@ def _extract_attributes(self, event_dict): self.resources = event_dict['resources'] +class WebsocketEvent(BaseLambdaEvent): + def __init__(self, event_dict, context): + super(WebsocketEvent, self).__init__(event_dict, context) + self._json_body = None + + def _extract_attributes(self, event_dict): + request_context = event_dict['requestContext'] + self.domain_name = request_context['domainName'] + self.stage = request_context['stage'] + self.connection_id = request_context['connectionId'] + self.body = event_dict.get('body') + + @property + def json_body(self): + if self._json_body is None: + try: + self._json_body = json.loads(self.body) + except ValueError: + raise BadRequestError('Error Parsing JSON') + return self._json_body + + class SNSEvent(BaseLambdaEvent): def _extract_attributes(self, event_dict): first_record = event_dict['Records'][0] diff --git a/chalice/app.pyi b/chalice/app.pyi index d5a807147..46249c422 100644 --- a/chalice/app.pyi +++ b/chalice/app.pyi @@ -5,6 +5,7 @@ from chalice.local import LambdaContext __version__ = ... # type: str class ChaliceError(Exception): ... +class WebsocketDisconnectedError(Exception): ... class ChaliceViewError(ChaliceError): __name__ = ... # type: str STATUS_CODE = ... # type: int @@ -116,6 +117,18 @@ class APIGateway(object): binary_types = ... # type: List[str] +class WebsocketAPI(object): + session = ... # type: Optional[Any] + + def configure(self, + domain_name: str, + stage: str) -> None: ... + + def send(self, + connection_id: str, + message: str) -> None: ... + + class DecoratorAPI(object): def authorizer(self, ttl_seconds: Optional[int]=None, @@ -151,6 +164,8 @@ class Chalice(DecoratorAPI): app_name = ... # type: str api = ... # type: APIGateway routes = ... # type: Dict[str, Dict[str, RouteEntry]] + websocket_api = ... # type: WebsocketAPI + websocket_handlers = ... # type: Dict[str, Any] current_request = ... # type: Request lambda_context = ... # type: LambdaContext debug = ... # type: bool diff --git a/chalice/awsclient.py b/chalice/awsclient.py index c9011f7be..d5ff7373d 100644 --- a/chalice/awsclient.py +++ b/chalice/awsclient.py @@ -450,21 +450,23 @@ def get_rest_api_id(self, name): return api['id'] return None - def rest_api_exists(self, rest_api_id): - # type: (str) -> bool + def get_rest_api(self, rest_api_id): + # type: (str) -> Dict[str, Any] """Check if an an API Gateway REST API exists.""" client = self._client('apigateway') try: - client.get_rest_api(restApiId=rest_api_id) - return True + result = client.get_rest_api(restApiId=rest_api_id) + result.pop('ResponseMetadata', None) + return result except client.exceptions.NotFoundException: - return False + return {} - def import_rest_api(self, swagger_document): - # type: (Dict[str, Any]) -> str + def import_rest_api(self, swagger_document, endpoint_type): + # type: (Dict[str, Any], str) -> str client = self._client('apigateway') response = client.import_rest_api( - body=json.dumps(swagger_document, indent=2) + body=json.dumps(swagger_document, indent=2), + parameters={'endpointConfigurationTypes': endpoint_type} ) rest_api_id = response['id'] return rest_api_id @@ -520,6 +522,19 @@ def add_permission_for_apigateway(self, function_name, service_name='apigateway', ) + def add_permission_for_apigateway_v2(self, function_name, + region_name, account_id, + api_id, random_id=None): + # type: (str, str, str, str, Optional[str]) -> None + """Authorize API gateway v2 to invoke a lambda function.""" + source_arn = self._build_source_arn_str(region_name, account_id, + api_id) + self._add_lambda_permission_if_needed( + source_arn=source_arn, + function_arn=function_name, + service_name='apigateway' + ) + def get_function_policy(self, function_name): # type: (str) -> Dict[str, Any] """Return the function policy for a lambda function. @@ -984,6 +999,119 @@ def verify_event_source_current(self, event_uuid, resource_name, except client.exceptions.ResourceNotFoundException: return False + def create_websocket_api(self, name): + # type: (str) -> str + client = self._client('apigatewayv2') + return self._call_client_method_with_retries( + client.create_api, + kwargs={ + 'Name': name, + 'ProtocolType': 'WEBSOCKET', + 'RouteSelectionExpression': '$request.body.action', + }, + max_attempts=10, + should_retry=self._is_settling_error, + )['ApiId'] + + def get_websocket_api_id(self, name): + # type: (str) -> Optional[str] + apis = self._client('apigatewayv2').get_apis()['Items'] + for api in apis: + if api['Name'] == name: + return api['ApiId'] + return None + + def websocket_api_exists(self, api_id): + # type: (str) -> bool + """Check if an API Gateway WEBSOCKET API exists.""" + client = self._client('apigatewayv2') + try: + client.get_api(ApiId=api_id) + return True + except client.exceptions.NotFoundException: + return False + + def delete_websocket_api(self, api_id): + # type: (str) -> None + client = self._client('apigatewayv2') + try: + client.delete_api(ApiId=api_id) + except client.exceptions.NotFoundException: + raise ResourceDoesNotExistError(api_id) + + def create_websocket_integration( + self, + api_id, + lambda_function, + handler_type, + ): + # type: (str, str, str) -> str + client = self._client('apigatewayv2') + return client.create_integration( + ApiId=api_id, + ConnectionType='INTERNET', + ContentHandlingStrategy='CONVERT_TO_TEXT', + Description=handler_type, + IntegrationType='AWS_PROXY', + IntegrationUri=lambda_function, + )['IntegrationId'] + + def create_websocket_route(self, api_id, route_key, integration_id): + # type: (str, str, str, ) -> None + client = self._client('apigatewayv2') + client.create_route( + ApiId=api_id, + RouteKey=route_key, + RouteResponseSelectionExpression='$default', + Target='integrations/%s' % integration_id, + ) + + def delete_websocket_routes(self, api_id, routes): + # type: (str, List[str]) -> None + client = self._client('apigatewayv2') + for route_id in routes: + client.delete_route( + ApiId=api_id, + RouteId=route_id, + ) + + def delete_websocket_integrations(self, api_id, integrations): + # type: (str, Dict[str, str]) -> None + client = self._client('apigatewayv2') + for integration_id in integrations: + client.delete_integration( + ApiId=api_id, + IntegrationId=integration_id, + ) + + def deploy_websocket_api(self, api_id): + # type: (str) -> str + client = self._client('apigatewayv2') + return client.create_deployment( + ApiId=api_id, + )['DeploymentId'] + + def get_websocket_routes(self, api_id): + # type: (str) -> List[str] + client = self._client('apigatewayv2') + return [i['RouteId'] + for i in client.get_routes(ApiId=api_id,)['Items']] + + def get_websocket_integrations(self, api_id): + # type: (str) -> List[str] + client = self._client('apigatewayv2') + return [item['IntegrationId'] + for item in client.get_integrations(ApiId=api_id)['Items']] + + def create_stage(self, api_id, stage_name, deployment_id): + # type: (str, str, str) -> None + client = self._client('apigatewayv2') + client.create_stage( + ApiId=api_id, + StageName=stage_name, + DeploymentId=deployment_id, + ) + def _call_client_method_with_retries( self, method, # type: ClientMethod diff --git a/chalice/cli/factory.py b/chalice/cli/factory.py index 590d4cc3a..031447b0e 100644 --- a/chalice/cli/factory.py +++ b/chalice/cli/factory.py @@ -19,6 +19,7 @@ from chalice.package import AppPackager # noqa from chalice.constants import DEFAULT_STAGE_NAME from chalice.constants import DEFAULT_APIGATEWAY_STAGE_NAME +from chalice.constants import DEFAULT_ENDPOINT_TYPE from chalice.logs import LogRetriever from chalice import local from chalice.utils import UI # noqa @@ -142,6 +143,7 @@ def create_config_obj(self, chalice_stage_name=DEFAULT_STAGE_NAME, user_provided_params = {} # type: Dict[str, Any] default_params = {'project_dir': self.project_dir, 'api_gateway_stage': DEFAULT_APIGATEWAY_STAGE_NAME, + 'api_gateway_endpoint_type': DEFAULT_ENDPOINT_TYPE, 'autogen_policy': True} try: config_from_disk = self.load_project_config() diff --git a/chalice/config.py b/chalice/config.py index 89b6dd0d3..ffee34474 100644 --- a/chalice/config.py +++ b/chalice/config.py @@ -2,7 +2,7 @@ import sys import json -from typing import Dict, Any, Optional, List # noqa +from typing import Dict, Any, Optional, List, Union # noqa from chalice import __version__ as current_chalice_version from chalice.app import Chalice # noqa from chalice.constants import DEFAULT_STAGE_NAME @@ -223,6 +223,24 @@ def api_gateway_stage(self): return self._chain_lookup('api_gateway_stage', varies_per_chalice_stage=True) + @property + def api_gateway_endpoint_type(self): + # type: () -> str + return self._chain_lookup('api_gateway_endpoint_type', + varies_per_chalice_stage=True) + + @property + def api_gateway_endpoint_vpce(self): + # type: () -> Union[str, List[str]] + return self._chain_lookup('api_gateway_endpoint_vpce', + varies_per_chalice_stage=True) + + @property + def api_gateway_policy_file(self): + # type: () -> str + return self._chain_lookup('api_gateway_policy_file', + varies_per_chalice_stage=True) + @property def minimum_compression_size(self): # type: () -> int diff --git a/chalice/constants.py b/chalice/constants.py index 7708f7a79..0328e1a65 100644 --- a/chalice/constants.py +++ b/chalice/constants.py @@ -45,7 +45,7 @@ def index(): DEFAULT_STAGE_NAME = 'dev' DEFAULT_APIGATEWAY_STAGE_NAME = 'api' - +DEFAULT_ENDPOINT_TYPE = 'EDGE' DEFAULT_LAMBDA_TIMEOUT = 60 DEFAULT_LAMBDA_MEMORY_SIZE = 128 @@ -240,3 +240,12 @@ def index(): ], "Resource": "*", } + + +POST_TO_WEBSOCKET_CONNECTION_POLICY = { + "Effect": "Allow", + "Action": [ + "execute-api:ManageConnections" + ], + "Resource": "arn:aws:execute-api:*:*:*/@connections/*" +} diff --git a/chalice/deploy/deployer.py b/chalice/deploy/deployer.py index f86a8e65e..3b0a43689 100644 --- a/chalice/deploy/deployer.py +++ b/chalice/deploy/deployer.py @@ -81,7 +81,7 @@ """ - +# pylint: disable=too-many-lines import json import os import textwrap @@ -108,6 +108,7 @@ from chalice.constants import DEFAULT_LAMBDA_MEMORY_SIZE from chalice.constants import LAMBDA_TRUST_POLICY from chalice.constants import SQS_EVENT_SOURCE_POLICY +from chalice.constants import POST_TO_WEBSOCKET_CONNECTION_POLICY from chalice.deploy import models from chalice.deploy.executor import Executor from chalice.deploy.packager import PipRunner @@ -116,10 +117,10 @@ from chalice.deploy.packager import LambdaDeploymentPackager from chalice.deploy.planner import PlanStage from chalice.deploy.planner import RemoteState -from chalice.deploy.planner import ResourceSweeper from chalice.deploy.planner import NoopPlanner from chalice.deploy.swagger import TemplatedSwaggerGenerator from chalice.deploy.swagger import SwaggerGenerator # noqa +from chalice.deploy.sweeper import ResourceSweeper from chalice.deploy.validate import validate_configuration from chalice.policy import AppPolicyGenerator from chalice.utils import OSUtils @@ -189,7 +190,6 @@ def _get_error_message_for_connection_error(self, connection_error): # ) message = connection_error.args[0].args[0] underlying_error = connection_error.args[0].args[1] - if is_broken_pipe_error(underlying_error): message += ( ' Lambda closed the connection before chalice finished ' @@ -296,6 +296,7 @@ def create_build_stage(osutils, ui, swagger_gen): swagger_generator=swagger_gen, ), LambdaEventSourcePolicyInjector(), + WebsocketPolicyInjector() ], ) return build_stage @@ -315,7 +316,6 @@ def create_deletion_deployer(client, ui): class Deployer(object): - BACKEND_NAME = 'api' def __init__(self, @@ -395,6 +395,10 @@ def build(self, config, stage_name): rest_api = self._create_rest_api_model( config, deployment, stage_name) resources.append(rest_api) + if config.chalice_app.websocket_handlers: + websocket_api = self._create_websocket_api_model( + config, deployment, stage_name) + resources.append(websocket_api) return models.Application(stage_name, resources) def _create_lambda_event_resources(self, config, deployment, stage_name): @@ -455,13 +459,85 @@ def _create_rest_api_model(self, handler_name=auth.handler_string, stage_name=stage_name, ) authorizers.append(auth_lambda) + + policy = None + policy_path = config.api_gateway_policy_file + if (config.api_gateway_endpoint_type == 'PRIVATE' and not policy_path): + policy = models.IAMPolicy( + document=self._get_default_private_api_policy(config)) + elif policy_path: + policy = models.FileBasedIAMPolicy( + document=models.Placeholder.BUILD_STAGE, + filename=os.path.join( + config.project_dir, '.chalice', policy_path)) + return models.RestAPI( resource_name='rest_api', swagger_doc=models.Placeholder.BUILD_STAGE, + endpoint_type=config.api_gateway_endpoint_type, minimum_compression=minimum_compression, api_gateway_stage=config.api_gateway_stage, lambda_function=lambda_function, authorizers=authorizers, + policy=policy + ) + + def _get_default_private_api_policy(self, config): + # type: (Config) -> Dict[str, Any] + statements = [{ + "Effect": "Allow", + "Principal": "*", + "Action": "execute-api:Invoke", + "Resource": "arn:aws:execute-api:*:*:*", + "Condition": { + "StringEquals": { + "aws:SourceVpce": config.api_gateway_endpoint_vpce + } + } + }] + return {"Version": "2012-10-17", "Statement": statements} + + def _create_websocket_api_model( + self, + config, # type: Config + deployment, # type: models.DeploymentPackage + stage_name, # type: str + ): + # type: (...) -> models.WebsocketAPI + connect_handler = None # type: Optional[models.LambdaFunction] + message_handler = None # type: Optional[models.LambdaFunction] + disconnect_handler = None # type: Optional[models.LambdaFunction] + + routes = {h.route_key_handled: h.handler_string for h + in config.chalice_app.websocket_handlers.values()} + if '$connect' in routes: + connect_handler = self._create_lambda_model( + config=config, deployment=deployment, name='websocket_connect', + handler_name=routes['$connect'], stage_name=stage_name) + routes.pop('$connect') + if '$disconnect' in routes: + disconnect_handler = self._create_lambda_model( + config=config, deployment=deployment, + name='websocket_disconnect', + handler_name=routes['$disconnect'], stage_name=stage_name) + routes.pop('$disconnect') + if routes: + # If there are left over routes they are message handlers. + handler_string = list(routes.values())[0] + message_handler = self._create_lambda_model( + config=config, deployment=deployment, name='websocket_message', + handler_name=handler_string, stage_name=stage_name + ) + + return models.WebsocketAPI( + name='%s-%s-websocket-api' % (config.app_name, stage_name), + resource_name='websocket_api', + connect_function=connect_handler, + message_function=message_handler, + disconnect_function=disconnect_handler, + routes=[h.route_key_handled for h + in config.chalice_app.websocket_handlers.values()], + api_gateway_stage=config.api_gateway_stage, ) def _create_event_model(self, @@ -736,7 +812,6 @@ def _traverse(self, resource, ordered, seen): class BaseDeployStep(object): - def handle(self, config, resource): # type: (Config, models.Model) -> None name = 'handle_%s' % resource.__class__.__name__.lower() @@ -769,8 +844,7 @@ def handle_deploymentpackage(self, config, resource): # type: (Config, models.DeploymentPackage) -> None if isinstance(resource.filename, models.Placeholder): zip_filename = self._packager.create_deployment_package( - config.project_dir, config.lambda_python_version - ) + config.project_dir, config.lambda_python_version) resource.filename = zip_filename @@ -782,7 +856,7 @@ def __init__(self, swagger_generator): def handle_restapi(self, config, resource): # type: (Config, models.RestAPI) -> None swagger_doc = self._swagger_generator.generate_swagger( - config.chalice_app) + config.chalice_app, resource) resource.swagger_doc = swagger_doc @@ -811,6 +885,39 @@ def _inject_trigger_policy(self, document, policy): document['Statement'].append(policy) +class WebsocketPolicyInjector(BaseDeployStep): + def __init__(self): + # type: () -> None + self._policy_injected = False + + def handle_websocketapi(self, config, resource): + # type: (Config, models.WebsocketAPI) -> None + self._inject_into_function(config, resource.connect_function) + self._inject_into_function(config, resource.message_function) + self._inject_into_function(config, resource.disconnect_function) + + def _inject_into_function(self, config, lambda_function): + # type: (Config, Optional[models.LambdaFunction]) -> None + if lambda_function is None: + return + role = lambda_function.role + if role is None: + return + if (not self._policy_injected and + isinstance(role, models.ManagedIAMRole) and + isinstance(role.policy, models.AutoGenIAMPolicy) and + not isinstance(role.policy.document, + models.Placeholder)): + self._inject_policy( + role.policy.document, + POST_TO_WEBSOCKET_CONNECTION_POLICY.copy()) + self._policy_injected = True + + def _inject_policy(self, document, policy): + # type: (Dict[str, Any], Dict[str, Any]) -> None + document['Statement'].append(policy) + + class PolicyGenerator(BaseDeployStep): def __init__(self, policy_gen, osutils): # type: (AppPolicyGenerator, OSUtils) -> None @@ -869,9 +976,10 @@ def record_results(self, results, chalice_stage_name, project_dir): class DeploymentReporter(object): - # We want the Rest API to be displayed last. + # We want the API URLs to be displayed last. _SORT_ORDER = { 'rest_api': 100, + 'websocket_api': 100, } # The default is chosen to sort before the rest_api _DEFAULT_ORDERING = 50 @@ -896,16 +1004,17 @@ def generate_report(self, deployed_values): return '\n'.join(report) def _report_rest_api(self, resource, report): + # type: (Dict[str, Any], List[str]) -> None + report.append(' - Rest API URL: %s' % resource['rest_api_url']) + + def _report_websocket_api(self, resource, report): # type: (Dict[str, Any], List[str]) -> None report.append( - ' - Rest API URL: %s' % resource['rest_api_url'] - ) + ' - Websocket API URL: %s' % resource['websocket_api_url']) def _report_lambda_function(self, resource, report): # type: (Dict[str, Any], List[str]) -> None - report.append( - ' - Lambda ARN: %s' % resource['lambda_arn'] - ) + report.append(' - Lambda ARN: %s' % resource['lambda_arn']) def _default_report(self, resource, report): # type: (Dict[str, Any], List[str]) -> None diff --git a/chalice/deploy/models.py b/chalice/deploy/models.py index f2b70628e..d84b3bab9 100644 --- a/chalice/deploy/models.py +++ b/chalice/deploy/models.py @@ -1,3 +1,4 @@ +# pylint: disable=line-too-long import enum from typing import List, Dict, Optional, Any, TypeVar, Union, Set # noqa from typing import cast @@ -46,6 +47,13 @@ class CopyVariable(Instruction): to_var = attrib() # type: str +@attrs(frozen=True) +class CopyVariableFromDict(Instruction): + from_var = attrib() # type: str + key = attrib() # type: str + to_var = attrib() # type: str + + @attrs(frozen=True) class RecordResource(Instruction): resource_type = attrib() # type: str @@ -182,7 +190,9 @@ class RestAPI(ManagedModel): swagger_doc = attrib() # type: DV[Dict[str, Any]] minimum_compression = attrib() # type: str api_gateway_stage = attrib() # type: str + endpoint_type = attrib() # type: str lambda_function = attrib() # type: LambdaFunction + policy = attrib(default=None) # type: Optional[IAMPolicy] authorizers = attrib(default=Factory(list)) # type: List[LambdaFunction] def dependencies(self): @@ -190,6 +200,28 @@ def dependencies(self): return cast(List[Model], [self.lambda_function] + self.authorizers) +@attrs +class WebsocketAPI(ManagedModel): + resource_type = 'websocket_api' + name = attrib() # type: str + api_gateway_stage = attrib() # type: str + routes = attrib() # type: List[str] + connect_function = attrib() # type: Optional[LambdaFunction] + message_function = attrib() # type: Optional[LambdaFunction] + disconnect_function = attrib() # type: Optional[LambdaFunction] + + def dependencies(self): + # type: () -> List[Model] + functions = [] # type: List[Model] + if self.connect_function is not None: + functions.append(self.connect_function) + if self.message_function is not None: + functions.append(self.message_function) + if self.disconnect_function is not None: + functions.append(self.disconnect_function) + return functions + + @attrs class S3BucketNotification(ManagedModel): resource_type = 's3_event' diff --git a/chalice/deploy/planner.py b/chalice/deploy/planner.py index 34c1fc86a..18d731198 100644 --- a/chalice/deploy/planner.py +++ b/chalice/deploy/planner.py @@ -1,3 +1,6 @@ +# pylint: disable=too-many-lines +from collections import OrderedDict + from typing import List, Dict, Any, Optional, Union, Tuple, Set, cast # noqa from typing import Sequence # noqa @@ -105,145 +108,17 @@ def _resource_exists_restapi(self, resource): except ValueError: return False rest_api_id = deployed_values['rest_api_id'] - return self._client.rest_api_exists(rest_api_id) - - -class ResourceSweeper(object): - - def execute(self, plan, config): - # type: (models.Plan, Config) -> None - instructions = plan.instructions - marked = self._mark_resources(instructions) - deployed = config.deployed_resources(config.chalice_stage) - if deployed is not None: - remaining = self._determine_remaining(plan, deployed, marked) - self._plan_deletion(instructions, plan.messages, - remaining, deployed) - - def _determine_remaining(self, plan, deployed, marked): - # type: (models.Plan, DeployedResources, MarkedResource) -> List[str] - remaining = [] - deployed_resource_names = reversed(deployed.resource_names()) - for name in deployed_resource_names: - resource_values = deployed.resource_values(name) - if name not in marked: - remaining.append(name) - elif resource_values['resource_type'] == 's3_event': - # Special case, we have to check the resource values - # to see if they've changed. For s3 events, the resource - # name is not tied to the bucket, which means if you change - # the bucket, the resource name will stay the same. - # So we match up the bucket referenced in the instruction - # and the bucket recorded in the deployed values match up. - # If they don't then we need to clean up the bucket config - # referenced in the deployed values. - bucket = [instruction for instruction in marked[name] - if instruction.name == 'bucket' and - isinstance(instruction, - models.RecordResourceValue)][0] - if bucket.value != resource_values['bucket']: - remaining.append(name) - elif resource_values['resource_type'] == 'sns_event': - existing_topic = resource_values['topic'] - referenced_topic = [instruction for instruction in marked[name] - if instruction.name == 'topic' and - isinstance(instruction, - models.RecordResourceValue)][0] - if referenced_topic.value != existing_topic: - remaining.append(name) - elif resource_values['resource_type'] == 'sqs_event': - existing_queue = resource_values['queue'] - referenced_queue = [instruction for instruction in marked[name] - if instruction.name == 'queue' and - isinstance(instruction, - models.RecordResourceValue)][0] - if referenced_queue.value != existing_queue: - remaining.append(name) - return remaining - - def _mark_resources(self, plan): - # type: (List[models.Instruction]) -> MarkedResource - marked = {} # type: MarkedResource - for instruction in plan: - if isinstance(instruction, models.RecordResource): - marked.setdefault(instruction.resource_name, []).append( - instruction) - return marked - - def _plan_deletion(self, - plan, # type: List[models.Instruction] - messages, # type: Dict[int, str] - remaining, # type: List[str] - deployed, # type: DeployedResources - ): - # type: (...) -> None - for name in remaining: - resource_values = deployed.resource_values(name) - if resource_values['resource_type'] == 'lambda_function': - apicall = models.APICall( - method_name='delete_function', - params={'function_name': resource_values['lambda_arn']},) - messages[id(apicall)] = ( - "Deleting function: %s\n" % resource_values['lambda_arn']) - plan.append(apicall) - elif resource_values['resource_type'] == 'iam_role': - apicall = models.APICall( - method_name='delete_role', - params={'name': resource_values['role_name']}, - ) - messages[id(apicall)] = ( - "Deleting IAM role: %s\n" % resource_values['role_name']) - plan.append(apicall) - elif resource_values['resource_type'] == 'cloudwatch_event': - apicall = models.APICall( - method_name='delete_rule', - params={'rule_name': resource_values['rule_name']}, - ) - plan.append(apicall) - elif resource_values['resource_type'] == 'rest_api': - rest_api_id = resource_values['rest_api_id'] - apicall = models.APICall( - method_name='delete_rest_api', - params={'rest_api_id': rest_api_id} - ) - messages[id(apicall)] = ( - "Deleting Rest API: %s\n" % resource_values['rest_api_id']) - plan.append(apicall) - elif resource_values['resource_type'] == 's3_event': - bucket = resource_values['bucket'] - function_arn = resource_values['lambda_arn'] - plan.extend([ - models.APICall( - method_name='disconnect_s3_bucket_from_lambda', - params={'bucket': bucket, 'function_arn': function_arn} - ), - models.APICall( - method_name='remove_permission_for_s3_event', - params={'bucket': bucket, 'function_arn': function_arn} - ) - ]) - elif resource_values['resource_type'] == 'sns_event': - subscription_arn = resource_values['subscription_arn'] - plan.extend([ - models.APICall( - method_name='unsubscribe_from_topic', - params={'subscription_arn': subscription_arn}, - ), - models.APICall( - method_name='remove_permission_for_sns_topic', - params={ - 'topic_arn': resource_values['topic_arn'], - 'function_arn': resource_values['lambda_arn'], - }, - ) - ]) - elif resource_values['resource_type'] == 'sqs_event': - plan.extend([ - models.APICall( - method_name='remove_sqs_event_source', - params={'event_uuid': resource_values['event_uuid']}, - ) - ]) + return bool(self._client.get_rest_api(rest_api_id)) + + def _resource_exists_websocketapi(self, resource): + # type: (models.WebsocketAPI) -> bool + try: + deployed_values = self._deployed_resources.resource_values( + resource.resource_name) + except ValueError: + return False + api_id = deployed_values['websocket_api_id'] + return self._client.websocket_api_exists(api_id) class PlanStage(object): @@ -724,6 +599,217 @@ def _plan_scheduledevent(self, resource): ] return plan + def _create_websocket_function_configs(self, resource): + # type: (models.WebsocketAPI) -> Dict[str, Dict[str, Any]] + configs = OrderedDict() # type: Dict[str, Dict[str, Any]] + if resource.connect_function is not None: + configs['connect'] = self._create_websocket_function_config( + resource.connect_function) + if resource.message_function is not None: + configs['message'] = self._create_websocket_function_config( + resource.message_function) + if resource.disconnect_function is not None: + configs['disconnect'] = self._create_websocket_function_config( + resource.disconnect_function) + return configs + + def _create_websocket_function_config(self, function): + # type: (models.LambdaFunction) -> Dict[str, Any] + varname = '%s_lambda_arn' % function.resource_name + return { + 'function': function, + 'name': function.function_name, + 'varname': varname, + 'lambda_arn_var': Variable(varname), + } + + def _inject_websocket_integrations(self, configs): + # type: (Dict[str, Any]) -> Sequence[InstructionMsg] + instructions = [] # type: List[InstructionMsg] + for key, config in configs.items(): + instructions.append( + models.StoreValue( + name='websocket-%s-integration-lambda-path' % key, + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % config['name'], + ['region_name', 'account_id'], + ), + ), + ) + instructions.append( + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-%s-integration-lambda-path' % key), + 'handler_type': key, + }, + output_var='%s-integration-id' % key, + ), + ) + return instructions + + def _create_route_for_key(self, route_key): + # type: (str) -> models.APICall + integration_id = { + '$connect': 'connect-integration-id', + '$disconnect': 'disconnect-integration-id', + }.get(route_key, 'message-integration-id') + return models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': route_key, + 'integration_id': Variable(integration_id), + }, + ) + + def _plan_websocketapi(self, resource): + # type: (models.WebsocketAPI) -> Sequence[InstructionMsg] + configs = self._create_websocket_function_configs(resource) + routes = resource.routes + + # Which lambda function we use here does not matter. We are only using + # it to find the account id and the region. + lambda_arn_var = list(configs.values())[0]['lambda_arn_var'] + shared_plan_preamble = [ + # The various API gateway API calls need + # to know the region name and account id so + # we'll take care of that up front and store + # them in variables. + models.BuiltinFunction( + 'parse_arn', + [lambda_arn_var], + output_var='parsed_lambda_arn', + ), + models.JPSearch('account_id', + input_var='parsed_lambda_arn', + output_var='account_id'), + models.JPSearch('region', + input_var='parsed_lambda_arn', + output_var='region_name'), + ] # type: List[InstructionMsg] + + # There's also a set of instructions that are needed + # at the end of deploying a websocket API that apply to both + # the update and create case. + shared_plan_epilogue = [ + models.StoreValue( + name='websocket_api_url', + value=StringFormat( + 'wss://{websocket_api_id}.execute-api.{region_name}' + '.amazonaws.com/%s/' % resource.api_gateway_stage, + ['websocket_api_id', 'region_name'], + ), + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name=resource.resource_name, + name='websocket_api_url', + variable_name='websocket_api_url', + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name=resource.resource_name, + name='websocket_api_id', + variable_name='websocket_api_id', + ), + ] # type: List[InstructionMsg] + + shared_plan_epilogue += [ + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={'function_name': function_config['name'], + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ) for function_config in configs.values() + ] + + main_plan = [] # type: List[InstructionMsg] + if not self._remote_state.resource_exists(resource): + # The resource does not exist, we create it in full here. + main_plan += [ + (models.APICall( + method_name='create_websocket_api', + params={'name': resource.name}, + output_var='websocket_api_id', + ), "Creating websocket api: %s\n" % resource.name), + models.StoreValue( + name='routes', + value=[], + ), + ] + main_plan += self._inject_websocket_integrations(configs) + + for route_key in routes: + main_plan += [self._create_route_for_key(route_key)] + main_plan += [ + models.APICall( + method_name='deploy_websocket_api', + params={ + 'api_id': Variable('websocket_api_id'), + }, + output_var='deployment-id', + ), + models.APICall( + method_name='create_stage', + params={ + 'api_id': Variable('websocket_api_id'), + 'stage_name': resource.api_gateway_stage, + 'deployment_id': Variable('deployment-id'), + } + ), + ] + else: + # Already exists. Need to sync up the routes, the easiest way to do + # this is to delete them and their integrations and re-create them. + # They will not work if the lambda function changes from under + # them, and the logic for detecting that and making just the needed + # changes is complex. There is an integration test to ensure there + # no dropped messages during a redeployment. + deployed = self._remote_state.resource_deployed_values(resource) + main_plan += [ + models.StoreValue( + name='websocket_api_id', + value=deployed['websocket_api_id'] + ), + models.APICall( + method_name='get_websocket_routes', + params={'api_id': Variable('websocket_api_id')}, + output_var='routes', + ), + models.APICall( + method_name='delete_websocket_routes', + params={ + 'api_id': Variable('websocket_api_id'), + 'routes': Variable('routes'), + }, + ), + models.APICall( + method_name='get_websocket_integrations', + params={ + 'api_id': Variable('websocket_api_id'), + }, + output_var='integrations' + ), + models.APICall( + method_name='delete_websocket_integrations', + params={ + 'api_id': Variable('websocket_api_id'), + 'integrations': Variable('integrations'), + } + ) + ] + main_plan += self._inject_websocket_integrations(configs) + for route_key in routes: + main_plan += [self._create_route_for_key(route_key)] + return shared_plan_preamble + main_plan + shared_plan_epilogue + def _plan_restapi(self, resource): # type: (models.RestAPI) -> Sequence[InstructionMsg] function = resource.lambda_function @@ -758,17 +844,19 @@ def _plan_restapi(self, resource): # There's also a set of instructions that are needed # at the end of deploying a rest API that apply to both # the update and create case. + shared_plan_patch_ops = [{ + 'op': 'replace', + 'path': '/minimumCompressionSize', + 'value': resource.minimum_compression} + ] # type: List[Dict] + shared_plan_epilogue = [ models.APICall( method_name='update_rest_api', params={ 'rest_api_id': Variable('rest_api_id'), - 'patch_operations': [{ - 'op': 'replace', - 'path': '/minimumCompressionSize', - 'value': resource.minimum_compression, - }], - }, + 'patch_operations': shared_plan_patch_ops + } ), models.APICall( method_name='add_permission_for_apigateway', @@ -777,6 +865,11 @@ def _plan_restapi(self, resource): 'account_id': Variable('account_id'), 'rest_api_id': Variable('rest_api_id')}, ), + models.APICall( + method_name='deploy_rest_api', + params={'rest_api_id': Variable('rest_api_id'), + 'api_gateway_stage': resource.api_gateway_stage}, + ), models.StoreValue( name='rest_api_url', value=StringFormat( @@ -806,7 +899,8 @@ def _plan_restapi(self, resource): plan = shared_plan_preamble + [ (models.APICall( method_name='import_rest_api', - params={'swagger_document': resource.swagger_doc}, + params={'swagger_document': resource.swagger_doc, + 'endpoint_type': resource.endpoint_type}, output_var='rest_api_id', ), "Creating Rest API\n"), models.RecordResourceVariable( @@ -815,14 +909,24 @@ def _plan_restapi(self, resource): name='rest_api_id', variable_name='rest_api_id', ), - models.APICall( - method_name='deploy_rest_api', - params={'rest_api_id': Variable('rest_api_id'), - 'api_gateway_stage': resource.api_gateway_stage}, - ), - ] + shared_plan_epilogue + ] else: deployed = self._remote_state.resource_deployed_values(resource) + shared_plan_epilogue.insert( + 0, + models.APICall( + method_name='get_rest_api', + params={'rest_api_id': Variable('rest_api_id')}, + output_var='rest_api') + ) + shared_plan_patch_ops.append({ + 'op': 'replace', + 'path': StringFormat( + '/endpointConfiguration/types/%s' % ( + '{rest_api[endpointConfiguration][types][0]}'), + ['rest_api']), + 'value': resource.endpoint_type} + ) plan = shared_plan_preamble + [ models.StoreValue( name='rest_api_id', @@ -840,19 +944,9 @@ def _plan_restapi(self, resource): 'swagger_document': resource.swagger_doc, }, ), "Updating rest API\n"), - models.APICall( - method_name='deploy_rest_api', - params={'rest_api_id': Variable('rest_api_id'), - 'api_gateway_stage': resource.api_gateway_stage}, - ), - models.APICall( - method_name='add_permission_for_apigateway', - params={'function_name': function_name, - 'region_name': Variable('region_name'), - 'account_id': Variable('account_id'), - 'rest_api_id': Variable('rest_api_id')}, - ), - ] + shared_plan_epilogue + ] + + plan.extend(shared_plan_epilogue) return plan def _get_role_arn(self, resource): diff --git a/chalice/deploy/swagger.py b/chalice/deploy/swagger.py index 725ab7721..588188fa3 100644 --- a/chalice/deploy/swagger.py +++ b/chalice/deploy/swagger.py @@ -1,11 +1,12 @@ import copy import inspect -from typing import Any, List, Dict, Optional # noqa +from typing import Any, List, Dict, Optional, Union # noqa from chalice.app import Chalice, RouteEntry, Authorizer, CORSConfig # noqa from chalice.app import ChaliceAuthorizer from chalice.deploy.planner import StringFormat +from chalice.deploy.models import RestAPI # noqa from chalice.utils import to_cfn_resource_name @@ -32,14 +33,20 @@ def __init__(self, region, deployed_resources): self._region = region self._deployed_resources = deployed_resources - def generate_swagger(self, app): - # type: (Chalice) -> Dict[str, Any] + def generate_swagger(self, app, rest_api=None): + # type: (Chalice, Optional[RestAPI]) -> Dict[str, Any] api = copy.deepcopy(self._BASE_TEMPLATE) api['info']['title'] = app.app_name self._add_binary_types(api, app) self._add_route_paths(api, app) + self._add_resource_policy(api, rest_api) return api + def _add_resource_policy(self, api, rest_api): + # type: (Dict[str, Any], Optional[RestAPI]) -> None + if rest_api and rest_api.policy: + api['x-amazon-apigateway-policy'] = rest_api.policy.document + def _add_binary_types(self, api, app): # type: (Dict[str, Any], Chalice) -> None api['x-amazon-apigateway-binary-media-types'] = app.api.binary_types diff --git a/chalice/deploy/sweeper.py b/chalice/deploy/sweeper.py new file mode 100644 index 000000000..bdb3cc3f0 --- /dev/null +++ b/chalice/deploy/sweeper.py @@ -0,0 +1,156 @@ +from typing import List, Dict # noqa + +from chalice.config import Config, DeployedResources # noqa +from chalice.deploy import models + + +MarkedResource = Dict[str, List[models.RecordResource]] + + +class ResourceSweeper(object): + + def execute(self, plan, config): + # type: (models.Plan, Config) -> None + instructions = plan.instructions + marked = self._mark_resources(instructions) + deployed = config.deployed_resources(config.chalice_stage) + if deployed is not None: + remaining = self._determine_remaining(plan, deployed, marked) + self._plan_deletion(instructions, plan.messages, + remaining, deployed) + + def _determine_remaining(self, plan, deployed, marked): + # type: (models.Plan, DeployedResources, MarkedResource) -> List[str] + remaining = [] + deployed_resource_names = reversed(deployed.resource_names()) + for name in deployed_resource_names: + resource_values = deployed.resource_values(name) + if name not in marked: + remaining.append(name) + elif resource_values['resource_type'] == 's3_event': + # Special case, we have to check the resource values + # to see if they've changed. For s3 events, the resource + # name is not tied to the bucket, which means if you change + # the bucket, the resource name will stay the same. + # So we match up the bucket referenced in the instruction + # and the bucket recorded in the deployed values match up. + # If they don't then we need to clean up the bucket config + # referenced in the deployed values. + bucket = [instruction for instruction in marked[name] + if instruction.name == 'bucket' and + isinstance(instruction, + models.RecordResourceValue)][0] + if bucket.value != resource_values['bucket']: + remaining.append(name) + elif resource_values['resource_type'] == 'sns_event': + existing_topic = resource_values['topic'] + referenced_topic = [instruction for instruction in marked[name] + if instruction.name == 'topic' and + isinstance(instruction, + models.RecordResourceValue)][0] + if referenced_topic.value != existing_topic: + remaining.append(name) + elif resource_values['resource_type'] == 'sqs_event': + existing_queue = resource_values['queue'] + referenced_queue = [instruction for instruction in marked[name] + if instruction.name == 'queue' and + isinstance(instruction, + models.RecordResourceValue)][0] + if referenced_queue.value != existing_queue: + remaining.append(name) + return remaining + + def _mark_resources(self, plan): + # type: (List[models.Instruction]) -> MarkedResource + marked = {} # type: MarkedResource + for instruction in plan: + if isinstance(instruction, models.RecordResource): + marked.setdefault(instruction.resource_name, []).append( + instruction) + return marked + + def _plan_deletion(self, + plan, # type: List[models.Instruction] + messages, # type: Dict[int, str] + remaining, # type: List[str] + deployed, # type: DeployedResources + ): + # type: (...) -> None + for name in remaining: + resource_values = deployed.resource_values(name) + if resource_values['resource_type'] == 'lambda_function': + apicall = models.APICall( + method_name='delete_function', + params={'function_name': resource_values['lambda_arn']},) + messages[id(apicall)] = ( + "Deleting function: %s\n" % resource_values['lambda_arn']) + plan.append(apicall) + elif resource_values['resource_type'] == 'iam_role': + apicall = models.APICall( + method_name='delete_role', + params={'name': resource_values['role_name']}, + ) + messages[id(apicall)] = ( + "Deleting IAM role: %s\n" % resource_values['role_name']) + plan.append(apicall) + elif resource_values['resource_type'] == 'cloudwatch_event': + apicall = models.APICall( + method_name='delete_rule', + params={'rule_name': resource_values['rule_name']}, + ) + plan.append(apicall) + elif resource_values['resource_type'] == 'rest_api': + rest_api_id = resource_values['rest_api_id'] + apicall = models.APICall( + method_name='delete_rest_api', + params={'rest_api_id': rest_api_id} + ) + messages[id(apicall)] = ( + "Deleting Rest API: %s\n" % resource_values['rest_api_id']) + plan.append(apicall) + elif resource_values['resource_type'] == 's3_event': + bucket = resource_values['bucket'] + function_arn = resource_values['lambda_arn'] + plan.extend([ + models.APICall( + method_name='disconnect_s3_bucket_from_lambda', + params={'bucket': bucket, 'function_arn': function_arn} + ), + models.APICall( + method_name='remove_permission_for_s3_event', + params={'bucket': bucket, 'function_arn': function_arn} + ) + ]) + elif resource_values['resource_type'] == 'sns_event': + subscription_arn = resource_values['subscription_arn'] + plan.extend([ + models.APICall( + method_name='unsubscribe_from_topic', + params={'subscription_arn': subscription_arn}, + ), + models.APICall( + method_name='remove_permission_for_sns_topic', + params={ + 'topic_arn': resource_values['topic_arn'], + 'function_arn': resource_values['lambda_arn'], + }, + ) + ]) + elif resource_values['resource_type'] == 'sqs_event': + plan.extend([ + models.APICall( + method_name='remove_sqs_event_source', + params={'event_uuid': resource_values['event_uuid']}, + ) + ]) + elif resource_values['resource_type'] == 'websocket_api': + plan.append( + models.APICall( + method_name='delete_websocket_api', + params={'api_id': resource_values['websocket_api_id']}, + ) + ) + messages[id(plan[-1])] = ( + "Deleting Websocket API: %s\n" % + resource_values['websocket_api_id'] + ) diff --git a/chalice/deploy/validate.py b/chalice/deploy/validate.py index f394ccfea..9861aa693 100644 --- a/chalice/deploy/validate.py +++ b/chalice/deploy/validate.py @@ -45,6 +45,40 @@ def validate_configuration(config): validate_python_version(config) validate_unique_function_names(config) validate_feature_flags(config.chalice_app) + validate_endpoint_type(config) + validate_resource_policy(config) + + +def validate_resource_policy(config): + # type: (Config) -> None + if (config.api_gateway_endpoint_type != 'PRIVATE' and + config.api_gateway_endpoint_vpce): + raise ValueError( + "config.api_gateway_endpoint_vpce should only be " + "specified for PRIVATE api_gateway_endpoint_type") + if config.api_gateway_endpoint_type != 'PRIVATE': + return + if config.api_gateway_policy_file and config.api_gateway_endpoint_vpce: + raise ValueError( + "Can only specify one of api_gateway_policy_file and " + "api_gateway_endpoint_vpce") + if config.api_gateway_policy_file: + return + if not config.api_gateway_endpoint_vpce: + raise ValueError( + ("Private Endpoints require api_gateway_policy_file or " + "api_gateway_endpoint_vpce specified")) + + +def validate_endpoint_type(config): + # type: (Config) -> None + if not config.api_gateway_endpoint_type: + return + valid_types = ('EDGE', 'REGIONAL', 'PRIVATE') + if config.api_gateway_endpoint_type not in valid_types: + raise ValueError( + "api gateway endpoint type must be one of %s" % ( + ", ".join(valid_types))) def validate_feature_flags(chalice_app): diff --git a/chalice/package.py b/chalice/package.py index bc3a5333d..9f8c17dfa 100644 --- a/chalice/package.py +++ b/chalice/package.py @@ -215,6 +215,7 @@ def _generate_restapi(self, resource, template): resources['RestAPI'] = { 'Type': 'AWS::Serverless::Api', 'Properties': { + 'EndpointConfiguration': resource.endpoint_type, 'StageName': resource.api_gateway_stage, 'DefinitionBody': resource.swagger_doc, } @@ -296,6 +297,178 @@ def _inject_restapi_outputs(self, template): } } + def _add_websocket_lambda_integration( + self, api_ref, websocket_handler, resources): + # type: (Dict[str, Any], str, Dict[str, Any]) -> None + resources['%sAPIIntegration' % websocket_handler] = { + 'Type': 'AWS::ApiGatewayV2::Integration', + 'Properties': { + 'ApiId': api_ref, + 'ConnectionType': 'INTERNET', + 'ContentHandlingStrategy': 'CONVERT_TO_TEXT', + 'IntegrationType': 'AWS_PROXY', + 'IntegrationUri': { + 'Fn::Sub': [ + ( + 'arn:aws:apigateway:${AWS::Region}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:' + '${AWS::Region}:' '${AWS::AccountId}:function:' + '${WebsocketHandler}/invocations' + ), + {'WebsocketHandler': {'Ref': websocket_handler}} + ], + } + } + } + + def _add_websocket_lambda_invoke_permission( + self, api_ref, websocket_handler, resources): + # type: (Dict[str, str], str, Dict[str, Any]) -> None + resources['%sInvokePermission' % websocket_handler] = { + 'Type': 'AWS::Lambda::Permission', + 'Properties': { + 'FunctionName': {'Ref': websocket_handler}, + 'Action': 'lambda:InvokeFunction', + 'Principal': 'apigateway.amazonaws.com', + 'SourceArn': { + 'Fn::Sub': [ + ('arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}' + ':${WebsocketAPIId}/*'), + {'WebsocketAPIId': api_ref}, + ] + }, + } + } + + def _add_websocket_lambda_integrations(self, api_ref, resources): + # type: (Dict[str, str], Dict[str, Any]) -> None + websocket_handlers = [ + 'WebsocketConnect', + 'WebsocketMessage', + 'WebsocketDisconnect', + ] + for handler in websocket_handlers: + if handler in resources: + self._add_websocket_lambda_integration( + api_ref, handler, resources) + self._add_websocket_lambda_invoke_permission( + api_ref, handler, resources) + + def _create_route_for_key(self, route_key, api_ref): + # type: (str, Dict[str, str]) -> Dict[str, Any] + integration_ref = { + '$connect': 'WebsocketConnectAPIIntegration', + '$disconnect': 'WebsocketDisconnectAPIIntegration', + }.get(route_key, 'WebsocketMessageAPIIntegration') + + return { + 'Type': 'AWS::ApiGatewayV2::Route', + 'Properties': { + 'ApiId': api_ref, + 'RouteKey': route_key, + 'Target': { + 'Fn::Join': [ + '/', + [ + 'integrations', + {'Ref': integration_ref}, + ] + ] + }, + }, + } + + def _generate_websocketapi(self, resource, template): + # type: (models.WebsocketAPI, Dict[str, Any]) -> None + resources = template['Resources'] + api_ref = {'Ref': 'WebsocketAPI'} + resources['WebsocketAPI'] = { + 'Type': 'AWS::ApiGatewayV2::Api', + 'Properties': { + 'Name': resource.name, + 'RouteSelectionExpression': '$request.body.action', + 'ProtocolType': 'WEBSOCKET', + } + } + + self._add_websocket_lambda_integrations(api_ref, resources) + + route_key_names = [] + for route in resource.routes: + key_name = 'Websocket%sRoute' % route.replace( + '$', '').replace('default', 'message').capitalize() + route_key_names.append(key_name) + resources[key_name] = self._create_route_for_key(route, api_ref) + + resources['WebsocketAPIDeployment'] = { + 'Type': 'AWS::ApiGatewayV2::Deployment', + 'DependsOn': route_key_names, + 'Properties': { + 'ApiId': api_ref, + } + } + + resources['WebsocketAPIStage'] = { + 'Type': 'AWS::ApiGatewayV2::Stage', + 'Properties': { + 'ApiId': api_ref, + 'DeploymentId': {'Ref': 'WebsocketAPIDeployment'}, + 'StageName': resource.api_gateway_stage, + } + } + + self._inject_websocketapi_outputs(template) + + def _inject_websocketapi_outputs(self, template): + # type: (Dict[str, Any]) -> None + # The 'Outputs' of the SAM template are considered + # part of the public API of chalice and therefore + # need to maintain backwards compatibility. This + # method uses the same output key names as the old + # deployer. + # For now, we aren't adding any of the new resources + # to the Outputs section until we can figure out + # a consist naming scheme. Ideally we don't use + # the autogen'd names that contain the md5 suffixes. + stage_name = template['Resources']['WebsocketAPIStage'][ + 'Properties']['StageName'] + outputs = template['Outputs'] + resources = template['Resources'] + outputs['WebsocketAPIId'] = { + 'Value': {'Ref': 'WebsocketAPI'} + } + if 'WebsocketConnect' in resources: + outputs['WebsocketConnectHandlerArn'] = { + 'Value': {'Fn::GetAtt': ['WebsocketConnect', 'Arn']} + } + outputs['WebsocketConnectHandlerName'] = { + 'Value': {'Ref': 'WebsocketConnect'} + } + if 'WebsocketMessage' in resources: + outputs['WebsocketMessageHandlerArn'] = { + 'Value': {'Fn::GetAtt': ['WebsocketMessage', 'Arn']} + } + outputs['WebsocketMessageHandlerName'] = { + 'Value': {'Ref': 'WebsocketMessage'} + } + if 'WebsocketDisconnect' in resources: + outputs['WebsocketDisconnectHandlerArn'] = { + 'Value': {'Fn::GetAtt': ['WebsocketDisconnect', 'Arn']} + } # There is not a lot of green in here. + outputs['WebsocketDisconnectHandlerName'] = { + 'Value': {'Ref': 'WebsocketDisconnect'} + } + outputs['WebsocketConnectEndpointURL'] = { + 'Value': { + 'Fn::Sub': ( + 'wss://${WebsocketAPI}.execute-api.${AWS::Region}' + # The api_gateway_stage is filled in when + # the template is built. + '.amazonaws.com/%s/' + ) % stage_name + } + } + # The various IAM roles/policies are handled in the # Lambda function generation. We're creating these # noop methods to indicate we've accounted for these diff --git a/docs/source/api.rst b/docs/source/api.rst index 37db01d77..a23343a6d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -48,6 +48,11 @@ Chalice app = Chalice(app_name="appname") app.debug = True + .. attribute:: websocket_api + + An object of type :class:`WebsocketAPI`. This attribute can be used to + send messages to websocket clients connected through API Gateway. + .. method:: route(path, \*\*options) Register a view function for a particular URI path. This method @@ -320,6 +325,37 @@ Chalice routes defined the Blueprint. This allows you to set the root mount point for all URLs in a Blueprint. + .. method:: on_ws_connect(event) + + Create a Websocket API connect event handler. + + :param event: The :class:`WebsocketEvent` received to indicate a new + connection has been registered with API Gateway. The identifier of this + connection is under the :attr:`WebsocketEvent.connection_id` attribute. + + see :doc:`topics/websockets` for more information. + + .. method:: on_ws_message(event) + + Create a Websocket API message event handler. + + :param event: The :class:`WebsocketEvent` received to indicate API Gateway + received a message from a connected client. The identifier of the + client that sent the message is under the + :attr:`WebsocketEvent.connection_id` attribute. The content of the + message is available in the :attr:`WebsocketEvent.body` attribute. + + see :doc:`topics/websockets` for more information. + + .. method:: on_ws_disconnect(event) + + Create a Websocket API disconnect event handler. + + :param event: The :class:`WebsocketEvent` received to indicate an existing + connection has been disconnected from API Gateway. The identifier of this + connection is under the :attr:`WebsocketEvent.connection_id` attribute. + + see :doc:`topics/websockets` for more information. Request ======= @@ -622,6 +658,64 @@ APIGateway will manifest as a ``502`` Bad Gateway error. +WebsocketAPI +============ + +.. class:: WebsocketAPI + + This class is used to send messages to websocket clients connected to an API + Gateway Websocket API. + + .. attribute:: session + + A boto3 Session that will be used to send websocket messages to + clients. Any custom configuration can be set through a botocore + ``session``. This **must** be manually set before websocket features can + be used. + + .. code-block:: python + + import botocore + from boto3.session import Session + from chalice import Chalice + + app = Chalice('example') + session = botocore.session.Session() + session.set_config_variable('retries', {'max_attempts': 0}) + app.websocket_api.session = Session(botocore_session=session) + + .. method:: configure(domain_name, stage) + + Configure prepares the :class:`WebsocketAPI` to call the :meth:`send` + method. Without first calling this method calls to :meth:`send` will fail + with the message ``WebsocketAPI needs to be configured before sending + messages.``. This is because a boto3 ``apigatewaymanagmentapi`` client + must be created from the :attr:`session` with a custom endpoint in order + to properly communicate with our API Gateway WebsocketAPI. This method is + called on your behalf before each of the websocket handlers: + ``on_ws_connect``, ``on_ws_message``, ``on_ws_disconnect``. This ensures + that the :meth:`send` method is available in each of those handlers. + +.. _websocket-send: + + .. method:: send(connection_id, message) + + Method to send a message to a client. The ``connection_id`` is the unique + identifier of the socket to send the ``message`` to. The ``message`` must + be a utf-8 string. + + If the socket is disconnected it raises a :class:`WebsocketDisconnectedError` + error. + +.. class:: WebsocketDisconnectedError + + An exception raised when a message is sent to a websocket that has disconnected. + + .. attribute:: connection_id + + The unique identifier of the websocket that was disconnected. + + CORS ==== @@ -1045,3 +1139,38 @@ Blueprints @myblueprint.route('/') def index(): return {'hello': 'world'} + + +Websockets +========== +.. _websocket-api: + +.. class:: WebsocketEvent() + + Event object event that is passed as the sole arugment to any handler + function decorated with one of the three websocket related handlers: + ``on_ws_connect``, ``on_ws_disconnect``, ``on_ws_message``. + + .. attribute:: domain_name + + The domain name of the endpoint for the API Gateway Websocket API. + + .. attribute:: stage + + The API Gateway stage of the Websocket API. + + .. attribute:: connection_id + + A handle that uniquely identifies a connection with API Gateway. + + .. attribute:: body + + The message body received. This is only populated on the ``on_ws_message`` + otherwise it will be set to ``None``. + + .. attribute:: json_body + + The parsed JSON body (``json.loads(body)``) of the message. If the body is + not JSON parsable then using this attribute will raise a ``ValueError``. + + See :doc:`topics/websockets` for more information. diff --git a/docs/source/conf.py b/docs/source/conf.py index 371d368b3..1ddd32c10 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -63,7 +63,7 @@ # The short X.Y version. version = u'1.9' # The full version, including alpha/beta/rc tags. -release = u'1.9.0' +release = u'1.9.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index a9d86dd81..47e1aae57 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -65,6 +65,7 @@ Topics topics/events topics/purelambda topics/blueprints + topics/websockets topics/cd topics/experimental @@ -78,6 +79,14 @@ API Reference api +Tutorials +--------- + +.. toctree:: + :maxdepth: 2 + + tutorials/websockets + Upgrade Notes ------------- diff --git a/docs/source/topics/configfile.rst b/docs/source/topics/configfile.rst index bb5022f86..0a6a39978 100644 --- a/docs/source/topics/configfile.rst +++ b/docs/source/topics/configfile.rst @@ -40,11 +40,26 @@ a stage specific configuration value is needed, the ``stages`` mapping is checked first. If no value is found then the top level keys will be checked. - * ``api_gateway_stage`` - The name of the API gateway stage. This will also be the URL prefix for your API (``https://endpoint/prefix/your-api``). +* ``api_gateway_endpoint_type`` - The endpoint configuration of the + deployed API Gateway which determines how the API will be accessed, + can be EDGE, REGIONAL, PRIVATE. Note this value can only be set as a + top level key and defaults to EDGE. For more information see + https://amzn.to/2LofApt + +* ``api_gateway_endpoint_vpce`` - When configuring a Private API a VPC + Endpoint id must be specified to configure a default resource policy on + the API if an explicit policy is not specified. This value can be a + list or a string of endpoint ids. + +* ``api_gateway_policy_file`` - A file pointing to an IAM resource + policy for the REST API. If not specified chalice will autogenerate + this policy when endpoint_type is PRIVATE. This filename is relative + to the ``.chalice`` directory. + * ``minimum_compression_size`` - An integer value that indicates the minimum compression size to apply to the API gateway. If this key is specified in both a stage specific config option diff --git a/docs/source/topics/experimental.rst b/docs/source/topics/experimental.rst index 25384922b..6b8dc18dc 100644 --- a/docs/source/topics/experimental.rst +++ b/docs/source/topics/experimental.rst @@ -84,6 +84,12 @@ The status of an experimental API can be: - Trial - `#1023 `__, `#651 `__ + * - :doc:`websockets` + - ``WEBSOCKETS`` + - 1.10.0 + - Trial + - `#1041 `__, + `#1017 `__ See the `original discussion `__ diff --git a/docs/source/topics/views.rst b/docs/source/topics/views.rst index 2b0c701e2..190a3d975 100644 --- a/docs/source/topics/views.rst +++ b/docs/source/topics/views.rst @@ -100,7 +100,7 @@ was instantiated. For example: from chalice import Chalice from chalice import BadRequestError - app = Chalice(app_name="badrequset") + app = Chalice(app_name="badrequest") @app.route('/badrequest') def badrequest(): diff --git a/docs/source/topics/websockets.rst b/docs/source/topics/websockets.rst new file mode 100644 index 000000000..6639a3fa0 --- /dev/null +++ b/docs/source/topics/websockets.rst @@ -0,0 +1,96 @@ +Websockets +========== + +.. warning:: + + Websockets are considered an experimental API. You'll need to opt-in + to this feature using the ``WEBSOCKETS`` feature flag: + + .. code-block:: python + + app = Chalice('myapp') + app.experimental_feature_flags.extend([ + 'WEBSOCKETS' + ]) + + See :doc:`experimental` for more information. + + +Chalice supports websockets through integration with an API Gateway Websocket +API. If any of the decorators are present in a Chalice app, then an API +Gateway Websocket API will be deployed and wired to Lambda Functions. + + +Responding to websocket events +------------------------------ + +In a Chalice app the websocket API is accessed through the three decorators +``on_ws_connect``, ``on_ws_message``, ``on_ws_disconnect``. These handle a new +websocket connection, an incoming message on an existing connection, and a +connection being cleaned up respectively. + +A decorated websocket handler function takes one argument ``event`` with the +type :ref:`WebsocketEvent `. This class allows easy access to +information about the API Gateway Websocket API, and information about the +particular socket the handler is being invoked to serve. + +Below is a simple working example application that prints to CloudWatch Logs +for each of the events. + +.. code-block:: python + + from boto3.session import Session + from chalice import Chalice + + app = Chalice(app_name='test-websockets') + app.experimental_feature_flags.update([ + 'WEBSOCKETS', + ]) + app.websocket_api.session = Session() + + + @app.on_ws_connect() + def connect(event): + print('New connection: %s' % event.connection_id) + + + @app.on_ws_message() + def message(event): + print('%s: %s' % (event.connection_id, event.body)) + + + @app.on_ws_disconnect() + def disconnect(event): + print('%s disconnected' % event.connection_id) + + +Sending a message over a websocket +---------------------------------- + +To send a message to a websocket client Chalice, use the +:ref:`app.websocket_api.send() ` method. This method will work in any +of the decorated functions outlined in the above section. + +Two pieces of information are needed to send a message. The identifier of the +websocket, and the contents for the message. Below is a simple example that when +it receives a message, it sends back the message ``"I got your message!"`` over +the same socket. + +.. code-block:: python + + from boto3.session import Session + from chalice import Chalice + + app = Chalice(app_name='test-websockets') + app.experimental_feature_flags.update([ + 'WEBSOCKETS', + ]) + app.websocket_api.session = Session() + + + @app.on_ws_message() + def message(event): + app.websocket_api.send(event.connection_id, 'I got your message!') + + +See :ref:`websocket-tutorial` for completely worked example applications. diff --git a/docs/source/tutorials/websockets.rst b/docs/source/tutorials/websockets.rst new file mode 100644 index 000000000..8df920c3a --- /dev/null +++ b/docs/source/tutorials/websockets.rst @@ -0,0 +1,1113 @@ +.. _websocket-tutorial: + +Websocket Tutorials +=================== + +Echo Server Example +------------------- + +An echo server is a simple server that echos any message it receives back to +the client that sent it. + +First install a copy of Chalice in a fresh environment, create a new project +and cd into the directory:: + + $ pip install -U chalice + $ chalice new-project echo-server + $ cd echo-server + +Our Chalice application will need boto3 as a dependency for both API Gateway +to send websocket messages. Let's add a boto3 to the ``requirements.txt`` +file:: + + $ echo "boto3>=1.9.91" > requirements.txt + + +Now that the requirement has been added. Let's install it locally since our +next script will need it as well:: + + $ pip install -r requirements.txt + + +Next replace the contents of the ``app.py`` file with the code below. + +.. code-block:: python + :caption: app.py + :linenos: + + from boto3.session import Session + + from chalice import Chalice + from chalice import WebsocketDisconnectedError + + app = Chalice(app_name="echo-server") + app.websocket_api.session = Session() + app.experimental_feature_flags.update([ + 'WEBSOCKETS' + ]) + + + @app.on_ws_message() + def message(event): + try: + app.websocket_api.send( + connection_id=event.connection_id, + message=event.body, + ) + except WebsocketDisconnectedError as e: + pass # Disconnected so we can't send the message back. + + +Stepping through this app line by line, the first thing to note is that we +need to import and instantiate a boto3 session. This session is manually +assigned to ``app.websocket_api.session``. +This is needed because in order to send websocket responses to API Gateway we +need to construct a boto3 client. Chalice does not take a direct dependency +on boto3 or botocore, so we need to provide the Session ourselves. + +.. code-block:: python + + from boto3.session import Session + app.websocket_api.session = Session() + + +Next we enable the experimental feature ``WEBSOCKETS``. Websockets are an +experimental feature and are subject to API changes. This includes all aspects +of the Websocket API exposted in Chalice. Including any public members of +``app.websocket_api``, and the three decorators ``on_ws_connect``, +``on_ws_message``, and ``on_ws_disconnect``. + +.. code-block:: python + + app.experimental_feature_flags.update([ + 'WEBSOCKETS' + ]) + + +To register a websocket handler, and cause Chalice to deploy an +API Gateway Websocket API we use the ``app.on_ws_message()`` decorator. +The event parameter here is a wrapper object with some convenience +parameters attached. The most useful are ``event.connection_id`` and +``event.body``. The ``connection_id`` is an API Gateway specific identifier +that allows you to refer to the connection that sent the message. The ``body`` +is the content of the message. + +.. code-block:: python + + @app.on_ws_message() + def message(event): + + +Since this is an echo server, the message handler simply reads the content it +received on the socket, and rewrites it back to the same socket. To send a +message to a socket we call ``app.websocket_api.send(connection_id, message)``. +In this case, we just use the same ``connection_id`` we got the message from, +and use the ``body`` we got from the event as the ``message`` to send. + +.. code-block:: python + + app.websocket_api.send( + connection_id=event.connection_id, + message=event.body, + ) + + +Finally, we catch the exception ``WebsocketDisconnectedError`` which is raised +by ``app.websocket_api.send`` if the provided ``connection_id`` is not +connected anymore. In our case this doesn't really matter since we don't have +anything tracking our connections. The error has a ``connection_id`` property +that contains the offending connection id. + +.. code-block:: python + + except WebsocketDisconnectedError as e: + pass # Disconnected so we can't send the message back. + + +Now that we understand the code, lets deploy it with ``chalice deploy``:: + + $ chalice deploy + Creating deployment package. + Creating IAM role: echo-server-dev + Creating lambda function: echo-server-dev-websocket_message + Creating websocket api: echo-server-dev-websocket-api + Resources deployed: + - Lambda ARN: arn:aws:lambda:region:0123456789:function:echo-server-dev-websocket_message + - Websocket API URL: wss://{websocket_api_id}.execute-api.region.amazonaws.com/api/ + +To test out the echo server we will use the ``websocket-client`` package. You +install it from PyPI:: + + $ pip install websocket-client + + +After deploying the Chalice app the output will contain a URL for connecting +to the websocket API labeled: ``- Websocket API URL:``. The +``websocket-client`` package installs a command line tool called ``wsdump.py`` +which can be used to test websocket echo server:: + + $ wsdump.py wss://{websocket_api_id}.execute-api.region.amazonaws.com/api/ + Press Ctrl+C to quit + > foo + < foo + > bar + < bar + > foo bar baz + < foo bar baz + > + + +Every message sent to the server (lines that start with ``>``) result in a +message sent to us (lines that start with ``<``) with the same content. + +If something goes wrong, you can check the chalice error logs using the +following command:: + + $ chalice logs -n websocket_message + +.. note:: + If you encounter an Internal Server Error here it is likely that you forgot + to include ``boto3>=1.9.91`` in the ``requirements.txt`` file. + +To tear down the example. Just run:: + + $ chalice delete + Deleting Websocket API: {websocket_api_id} + Deleting function: arn:aws:lambda:us-west-2:0123456789:function:echo-server-dev-websocket_message + Deleting IAM role: echo-server-dev + +Chat Server Example +------------------- + + +Note:: + + This example is for illustration purposes and does not represent best + practices. + +A simple chat server example application. This example will walk through +deploying a chat application with separate chat rooms and nicknames. It uses +a DynamoDB table to store state like connection IDs between websocket messages. + + +First install a copy of Chalice in a fresh environment, create a new project +and cd into the directory:: + + $ pip install -U chalice + $ chalice new-project chalice-chat-example + $ cd chalice-chat-example + + +Our Chalice application will need boto3 as a dependency for both DynamoDB +access and in order to communicate back with API Gateway to send websocket +messages. Let's add a boto3 to the ``requirements.txt`` file:: + + $ echo "boto3>=1.9.91" > requirements.txt + + +Now that the requirement has been added. Let's install it locally since our +next script will need it as well:: + + $ pip install -r requirements.txt + +To set up the DynamoDB table use the following script. Create a new file +in the root of the project called ``create-resources.py``. + + +.. code-block:: python + :caption: create-resources.py + + import json + + import boto3 + + + def iam_policy(table_arn): + resources = [ + table_arn, + '%s/index/ReverseLookup' % table_arn, + ] + return { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "dynamodb:DeleteItem", + "dynamodb:PutItem", + "dynamodb:GetItem", + "dynamodb:UpdateItem", + "dynamodb:Query", + "dynamodb:Scan" + ], + "Resource": resources, + }, + { + "Effect": "Allow", + "Action": [ + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:PutLogEvents" + ], + "Resource": "arn:aws:logs:*:*:*" + }, + { + "Effect": "Allow", + "Action": [ + "execute-api:ManageConnections" + ], + "Resource": "arn:aws:execute-api:*:*:*/@connections/*" + } + ] + } + + + def main(): + ddb = boto3.client('dynamodb') + result = ddb.create_table( + AttributeDefinitions=[ + { + 'AttributeName': 'PK', + 'AttributeType': 'S', + }, + { + 'AttributeName': 'SK', + 'AttributeType': 'S', + }, + ], + TableName='ChaliceChatTable', + KeySchema=[ + { + 'AttributeName': 'PK', + 'KeyType': 'HASH', + }, + { + 'AttributeName': 'SK', + 'KeyType': 'RANGE', + }, + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 5, + 'WriteCapacityUnits': 5, + }, + GlobalSecondaryIndexes=[ + { + 'IndexName': 'ReverseLookup', + 'KeySchema': [ + { + 'AttributeName': 'SK', + 'KeyType': 'HASH', + }, + { + 'AttributeName': 'PK', + 'KeyType': 'RANGE', + }, + ], + 'Projection': { + 'ProjectionType': 'ALL', + }, + 'ProvisionedThroughput': { + 'ReadCapacityUnits': 1, + 'WriteCapacityUnits': 1, + } + }, + ], + ) + table_arn = result['TableDescription']['TableArn'] + with open('.chalice/config.json', 'r') as f: + config = json.loads(f.read()) + + config['stages']['dev']['environment_variables'] = { + 'TABLE': 'ChaliceChatTable', + } + config['autogen_policy'] = False + + with open('.chalice/config.json', 'w') as f: + f.write(json.dumps(config, indent=2)) + + with open('.chalice/policy-dev.json', 'w') as f: + f.write(json.dumps(iam_policy(table_arn), indent=2)) + + + if __name__ == "__main__": + main() + + +The current directory layout should now look like this:: + + tree -a . + . + ├── .chalice + │   └── config.json + ├── .gitignore + ├── app.py + ├── create-resources.py + └── requirements.txt + + 1 directory, 5 files + +Run the python script we just created (``create-resources.py``), which will +deploy our DynamoDB table, and setup the Chalice configuration to have an +environment variable with the table name in it, as well as a policy that allows +the Lambda function to access the table:: + + $ python create-resources.py + + +You can verify the configuration is correct by checking config file looks +correct:: + + $ cat .chalice/config.json + { + "version": "2.0", + "app_name": "chalice-chat-example", + "stages": { + "dev": { + "api_gateway_stage": "api", + "environment_variables": { + "TABLE": "ChaliceChatTable" + } + } + }, + "autogen_policy": false + } + +And the policy file is correct:: + + $ cat .chalice/policy-dev.json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "dynamodb:DeleteItem", + "dynamodb:PutItem", + "dynamodb:GetItem", + "dynamodb:UpdateItem", + "dynamodb:Query", + "dynamodb:Scan" + ], + "Resource": [ + "arn:aws:dynamodb:{region}:{id}:table/ChaliceChatTable", + "arn:aws:dynamodb:{region}:{id}:table/ChaliceChatTable/index/ReverseLookup" + ] + }, + { + "Effect": "Allow", + "Action": [ + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:PutLogEvents" + ], + "Resource": "arn:aws:logs:*:*:*" + }, + { + "Effect": "Allow", + "Action": [ + "execute-api:ManageConnections" + ], + "Resource": "arn:aws:execute-api:*:*:*/@connections/*" + } + ] + } + + +Next let's fill out the ``app.py`` file since it is pretty simple. Most of this +example is contained in the ``chalicelib/`` directory. + +.. code-block:: python + :caption: chalice-chat-example/app.py + + from boto3.session import Session + + from chalice import Chalice + + from chalicelib import Storage + from chalicelib import Sender + from chalicelib import Handler + + app = Chalice(app_name="chalice-chat-example") + app.websocket_api.session = Session() + app.experimental_feature_flags.update([ + 'WEBSOCKETS' + ]) + + STORAGE = Storage.from_env() + SENDER = Sender(app, STORAGE) + HANDLER = Handler(STORAGE, SENDER) + + + @app.on_ws_connect() + def connect(event): + STORAGE.create_connection(event.connection_id) + + + @app.on_ws_disconnect() + def disconnect(event): + STORAGE.delete_connection(event.connection_id) + + + @app.on_ws_message() + def message(event): + HANDLER.handle(event.connection_id, event.body) + + +Similar to the previous example. We need to use ``boto3`` to construct a +Session and pass it to ``app.websocket_api.session``. We opt into the +usage of the ``WEBSOCKET`` experimental feature. Most of the actual work is +done in some classes that we import from ``chalicelib/``. These classes are +detailed below, and the various parts are explained in comments and Doc +strings. In addition to the previous example, we register a handler for +``on_ws_connect`` and ``on_ws_disconnect`` to handle events from API gateway +when a new socket is trying to connect, or an existing socket is disconnected. + + +Finally before being able to deploy and test the app out, we need to fill out +the chalicelib directory. This is the bulk of the app and it is explained +inline in comments. Create a new directory called ``chalicelib`` and inside +that directory create an ``__init__.py`` file and fill it out with the +following file. + +.. code-block:: python + :caption: chalice-chat-example/chalicelib/__init__.py + + import os + + import boto3 + from boto3.dynamodb.conditions import Key + + from chalice import WebsocketDisconnectedError + + + class Storage(object): + """An abstraction to interact with the DynamoDB Table.""" + def __init__(self, table): + """Initialize Storage object + + :param table: A boto3 dynamodb Table resource object. + """ + self._table = table + + @classmethod + def from_env(cls): + """Create table from the environment. + + The environment variable TABLE is assumed to be present + as it is set by the create-resources.py file. + """ + table_name = os.environ.get('TABLE') + table = boto3.resource('dynamodb').Table(table_name) + return cls(table) + + def create_connection(self, connection_id): + """Create a new connection object in the dtabase. + + When a new connection is created, we create a stub for + it in the table. The stub uses a primary key of the + connection_id and a sort key of username_. This translates + to a connection with an unset username. The first message + sent over the wire from the connection is to be used as the + username, and this entry will be re-written. + + :param connection_id: The connection id to write to + the table. + """ + self._table.put_item( + Item={ + 'PK': connection_id, + 'SK': 'username_', + }, + ) + + def set_username(self, connection_id, old_name, username): + """Set the username. + + The SK entry that goes with this conneciton id that starts + with username_ is taken to be the username. The previous + entry needs to be deleted, and a new entry needs to be + written. + + :param connection_id: Connection id of the user trying to + change their name. + + :param old_name: The original username. Since this is part of + the key, it needs to be deleted and re-created rather than + updated. + + :param username: The new username the user wants. + """ + self._table.delete_item( + Key={ + 'PK': connection_id, + 'SK': 'username_%s' % old_name, + }, + ) + self._table.put_item( + Item={ + 'PK': connection_id, + 'SK': 'username_%s' % username, + }, + ) + + def list_rooms(self): + """Get a list of all rooms that exist. + + Scan through the table looking for SKs that start with room_ + which indicates a room that a user is in. Collect a unique set + of those and return them. + """ + r = self._table.scan() + rooms = set([item['SK'].split('_', 1)[1] for item in r['Items'] + if item['SK'].startswith('room_')]) + return rooms + + def set_room(self, connection_id, room): + """Set the room a user is currently in. + + The room a user is in is in the form of an SK that starts with + room_ prefix. + + :param connection_id: The connection id to move to a room. + + :param room: The room name to join. + """ + self._table.put_item( + Item={ + 'PK': connection_id, + 'SK': 'room_%s' % room, + }, + ) + + def remove_room(self, connection_id, room): + """Remove a user from a room. + + The room a user is in is in the form of an SK that starts with + room_ prefix. To leave a room we need to delete this entry. + + :param connection_id: The connection id to move to a room. + + :param room: The room name to join. + """ + self._table.delete_item( + Key={ + 'PK': connection_id, + 'SK': 'room_%s' % room, + }, + ) + + def get_connection_ids_by_room(self, room): + """Find all connection ids that go to a room. + + This is needed whenever we broadcast to a room. We collect all + their connection ids so we can send messages to them. We use a + ReverseLookup table here which inverts the PK, SK relationship + creating a partition called room_{room}. Everything in that + partition is a connection in the room. + + :param room: Room name to get all connection ids from. + """ + r = self._table.query( + IndexName='ReverseLookup', + KeyConditionExpression=( + Key('SK').eq('room_%s' % room) + ), + Select='ALL_ATTRIBUTES', + ) + return [item['PK'] for item in r['Items']] + + def delete_connection(self, connection_id): + """Delete a connection. + + Called when a connection is disconnected and all its entries need + to be deleted. + + :param connection_id: The connection partition to delete from + the table. + """ + try: + r = self._table.query( + KeyConditionExpression=( + Key('PK').eq(connection_id) + ), + Select='ALL_ATTRIBUTES', + ) + for item in r['Items']: + self._table.delete_item( + Key={ + 'PK': connection_id, + 'SK': item['SK'], + }, + ) + except Exception as e: + print(e) + + def get_record_by_connection(self, connection_id): + """Get all the properties associated with a connection. + + Each connection_id creates a partition in the table with multiple + SK entries. Each SK entry is in the format {property}_{value}. + This method reads all those records from the database and puts them + all into dictionary and returns it. + + :param connection_id: The connection to get properties for. + """ + r = self._table.query( + KeyConditionExpression=( + Key('PK').eq(connection_id) + ), + Select='ALL_ATTRIBUTES', + ) + r = { + entry['SK'].split('_', 1)[0]: entry['SK'].split('_', 1)[1] + for entry in r['Items'] + } + return r + + + class Sender(object): + """Class to send messages over websockets.""" + def __init__(self, app, storage): + """Initialize a sender object. + + :param app: A Chalice application object. + + :param storage: A Storage object. + """ + self._app = app + self._storage = storage + + def send(self, connection_id, message): + """Send a message over a websocket. + + :param connection_id: API Gateway Connection ID to send a + message to. + + :param message: The message to send to the connection. + """ + try: + # Call the chalice websocket api send method + self._app.websocket_api.send(connection_id, message) + except WebsocketDisconnectedError as e: + # If the websocket has been closed, we delete the connection + # from our database. + self._storage.delete_connection(e.connection_id) + + def broadcast(self, connection_ids, message): + """"Send a message to multiple connections. + + :param connection_id: A list of API Gateway Connection IDs to + send the message to. + + :param message: The message to send to the connections. + """ + for cid in connection_ids: + self.send(cid, message) + + + class Handler(object): + """Handler object that handles messages received from a websocket. + + This class implements the bulk of our app behavior. + """ + def __init__(self, storage, sender): + """Initialize a Handler object. + + :param storage: Storage object to interact with database. + + :param sender: Sender object to send messages to websockets. + """ + self._storage = storage + self._sender = sender + # Command table to translate a string command name into a + # method to call. + self._command_table = { + 'help': self._help, + 'nick': self._nick, + 'join': self._join, + 'room': self._room, + 'quit': self._quit, + 'ls': self._list, + } + + def handle(self, connection_id, message): + """Entry point for our application. + + :param connection_id: Connection id that the message came from. + + :param message: Message we got from the connection. + """ + # First look the user up in the database and get a record for it. + record = self._storage.get_record_by_connection(connection_id) + if record['username'] == '': + # If the user does not have a username, we assume that the message + # is the username they want and we call _handle_login_message. + self._handle_login_message(connection_id, message) + else: + # Otherwise we assume the user is logged in. So we call + # a method to handle the message. We pass along the + # record we loaded from the database so we don't need to + # again. + self._handle_message(connection_id, message, record) + + def _handle_login_message(self, connection_id, message): + """Handle a login message. + + The message is the username to give the user. Re-write the + database entry for this user to reset their username from '' + to {message}. Once that is done send a message back to the user + to confirm the name choice. Also send a /help prompt. + """ + self._storage.set_username(connection_id, '', message) + self._sender.send( + connection_id, + 'Using nickname: %s\nType /help for list of commands.' % message + ) + + def _handle_message(self, connection_id, message, record): + """"Handle a message from a connected and logged in user. + + If the message starts with a / it's a command. Otherwise its a + text message to send to all rooms in the room. + + :param connection_id: Connection id that the message came from. + + :param message: Message we got from the connection. + + :param record: A data record about the sender. + """ + if message.startswith('/'): + self._handle_command(connection_id, message[1:], record) + else: + self._handle_text(connection_id, message, record) + + def _handle_command(self, connection_id, message, record): + """Handle a command message. + + Check the command name and look it up in our command table. + If there is an entry, we call that method and pass along + the connection_id, arguments, and the loaded record. + + :param connection_id: Connection id that the message came from. + + :param message: Message we got from the connection. + + :param record: A data record about the sender. + """ + args = message.split(' ') + command_name = args.pop(0).lower() + command = self._command_table.get(command_name) + if command: + command(connection_id, args, record) + else: + # If no command method is found, send an error message + # back to the user. + self._sender( + connection_id, 'Unknown command: %s' % command_name) + + def _handle_text(self, connection_id, message, record): + """Handle a raw text message. + + :param connection_id: Connection id that the message came from. + + :param message: Message we got from the connection. + + :param record: A data record about the sender. + """ + if 'room' not in record: + # If the user is not in a room send them an error message + # and return early. + self._sender.send( + connection_id, 'Cannot send message if not in chatroom.') + return + # Collect a list of connection_ids in the same room as the message + # sender. + connection_ids = self._storage.get_connection_ids_by_room( + record['room']) + # Prefix the message with the sender's name. + message = '%s: %s' % (record['username'], message) + # Broadcast the new message to everyone in the room. + self._sender.broadcast(connection_ids, message) + + def _help(self, connection_id, _message, _record): + """Send the help message. + + Build a help message and send back to the same connection. + + :param connection_id: Connection id that the message came from. + """ + self._sender.send( + connection_id, + '\n'.join([ + 'Commands available:', + ' /help', + ' Display this message.', + ' /join {chat_room_name}', + ' Join a chatroom named {chat_room_name}.', + ' /nick {nickname}', + ' Change your name to {nickname}. If no {nickname}', + ' is provided then your current name will be printed', + ' /room', + ' Print out the name of the room you are currently ', + ' in.', + ' /ls', + ' If you are in a room, list all users also in the', + ' room. Otherwise, list all rooms.', + ' /quit', + ' Leave current room.', + '', + 'If you are in a room, raw text messages that do not start ', + 'with a / will be sent to everyone else in the room.', + ]), + ) + + def _nick(self, connection_id, args, record): + """Change or check nickname (username). + + :param connection_id: Connection id that the message came from. + + :param args: Argument list that came after the command. + + :param record: A data record about the sender. + """ + if not args: + # If a nickname argument was not provided, we just want to + # report the current nickname to the user. + self._sender.send( + connection_id, 'Current nickname: %s' % record['username']) + return + # The first argument is assumed to be the new desired nickname. + nick = args[0] + # Change the username from record['username'] to nick in the storage + # layer. + self._storage.set_username(connection_id, record['username'], nick) + # Send a message to the requestor to confirm the nickname change. + self._sender.send(connection_id, 'Nickname is: %s' % nick) + # Get the room the user is in. + room = record.get('room') + if room: + # If the user was in a room, announce to the room they have + # changed their name. Don't send this me sage to the user since + # they already got a name change message. + room_connections = self._storage.get_connection_ids_by_room(room) + room_connections.remove(connection_id) + self._sender.broadcast( + room_connections, + '%s is now known as %s.' % (record['username'], nick)) + + def _join(self, connection_id, args, record): + """Join a chat room. + + :param connection_id: Connection id that the message came from. + + :param args: Argument list. The first argument should be the + name of the room to join. + + :param record: A data record about the sender. + """ + # Get the room name to join. + room = args[0] + # Call quit to leave the current room we are in if there is any. + self._quit(connection_id, '', record) + # Get a list of connections in the target chat room. + room_connections = self._storage.get_connection_ids_by_room(room) + # Join the target chat room. + self._storage.set_room(connection_id, room) + # Send a message to the requestor that they have joined the room. + # At the same time send an announcement to everyone who was already + # in the room to alert them of the new user. + self._sender.send( + connection_id, 'Joined chat room "%s"' % room) + message = '%s joined room.' % record['username'] + self._sender.broadcast(room_connections, message) + + def _room(self, connection_id, _args, record): + """Report the name of the current room. + + :param connection_id: Connection id that the message came from. + + :param record: A data record about the sender. + """ + if 'room' in record: + # If the user is in a room send them the name back. + self._sender.send(connection_id, record['room']) + else: + # If the user is not in a room. Tell them so, and how to + # join a room. + self._sender.send( + connection_id, + 'Not currently in a room. Type /join {room_name} to do so.' + ) + + def _quit(self, connection_id, _args, record): + """Quit from a room. + + :param connection_id: Connection id that the message came from. + + :param record: A data record about the sender. + """ + if 'room' not in record: + # If the user is not in a room there is nothing to do. + return + # Find the current room name, and delete that entry from + # the database. + room_name = record['room'] + self._storage.remove_room(connection_id, room_name) + # Send a message to the user to inform them they left the room. + self._sender.send( + connection_id, 'Left chat room "%s"' % room_name) + # Tell everyone in the room that the user has left. + self._sender.broadcast( + self._storage.get_connection_ids_by_room(room_name), + '%s left room.' % record['username'], + ) + + def _list(self, connection_id, _args, record): + """Show a context dependent listing. + + :param connection_id: Connection id that the message came from. + + :param record: A data record about the sender. + """ + room = record.get('room') + if room: + # If the user is in a room, get a listing of everyone + # in the room. + result = [ + self._storage.get_record_by_connection(c_id)['username'] + for c_id in self._storage.get_connection_ids_by_room(room) + ] + else: + # If they are not in a room. Get a listing of all rooms + # currently open. + result = self._storage.list_rooms() + # Send the result list back to the requestor. + self._sender.send(connection_id, '\n'.join(result)) + + +The final directory layout should be :: + + $ tree -a . + . + ├── .chalice + │   ├── config.json + │   └── policy-dev.json + ├── .gitignore + ├── app.py + ├── chalicelib + │   └── __init__.py + ├── create-resources.py + └── requirements.txt + + 2 directories, 7 files + + +To deploy the app run the following command:: + + $ chalice deploy + Creating deployment package. + Creating IAM role: chalice-chat-example-dev-websocket_handler + Creating lambda function: chalice-chat-example-dev-websocket_handler + Creating websocket api: chalice-chat-example-dev-websocket-api + Resources deployed: + - Lambda ARN: arn:aws:lambda:::chalice-chat-example-dev-websocket_handler + - Websocket API URL: wss://{id}.execute-api.{region}.amazonaws.com/api/ + +Once deployed we can take the ``Websocket API URL`` and connect to it in the +same way we did in the previous example using the ``wsdump.py`` command line +tool. Below is a sample of two running clients, the first message sent to the +server is used as the client's username. + + +.. code-block:: bash + :caption: client-1 + + $ wsdump.py wss://{id}.execute-api.{region}.amazonaws.com/api/ + Press Ctrl+C to quit + > John + < Using nickname: John + Type /help for list of commands. + > /help + < Commands available: + /help + Display this message. + /join {chat_room_name} + Join a chatroom named {chat_room_name}. + /nick {nickname} + Change your name to {nickname}. If no {nickname} + is provided then your current name will be printed + /room + Print out the name of the room you are currently + in. + /ls + If you are in a room, list all users also in the + room. Otherwise, list all rooms. + /quit + Leave current room. + + If you are in a room, raw text messages that do not start + with a / will be sent to everyone else in the room. + > /join chalice + < Joined chat room "chalice" + < Jenny joined room. + > Hi + < John: Hi + < Jenny is now known as JennyJones. + > /quit + < Left chat room "chalice" + > /ls + < chalice + > Ctrl-C + +.. code-block:: bash + :caption: client-2 + + $ wsdump.py wss://{id}.execute-api.{region}.amazonaws.com/api/ + Press Ctrl+C to quit + > Jenny + < Using nickname: Jenny + Type /help for list of commands. + > /help + < Commands available: + /help + Display this message. + /join {chat_room_name} + Join a chatroom named {chat_room_name}. + /nick {nickname} + Change your name to {nickname}. If no {nickname} + is provided then your current name will be printed + /room + Print out the name of the room you are currently + in. + /ls + If you are in a room, list all users also in the + room. Otherwise, list all rooms. + /quit + Leave current room. + + If you are in a room, raw text messages that do not start + with a / will be sent to everyone else in the room. + > /join chalice + < Joined chat room "chalice" + > /ls + < John + Jenny + < John: Hi + > /nick JennyJones + < Nickname is: JennyJones + < John left room. + > /ls + < JennyJones + > /room + < chalice + > /nick + < Current nickname: JennyJones + > Ctrl-C + + +To delete the resources you can run chalice delete and use the AWS CLI +to delete the DynamoDB table:: + + $ chalice delete + $ pip install -U awscli + $ aws dynamodb delete-table --table-name ChaliceChatTable diff --git a/requirements-dev.txt b/requirements-dev.txt index 1cbcd22a2..13c79e6a9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -20,6 +20,8 @@ py==1.5.3 pygments==2.1.3 mock==2.0.0 requests==2.20.0 +boto3==1.9.188 +websocket-client==0.54.0 hypothesis==3.56.3 # pip does not catch the <0.7 requirement in pytest so we need to add it to # the top level requirements. diff --git a/setup.py b/setup.py index a362ad21f..0e47da630 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ 'botocore>=1.12.86,<2.0.0', 'typing==3.6.4;python_version<"3.7"', 'six>=1.10.0,<2.0.0', - 'pip>=9,<=19.2', + 'pip>=9,<19.3', 'attrs>=17.4.0,<20.0.0', 'enum-compat>=0.0.2', 'jmespath>=0.9.3,<1.0.0', @@ -21,7 +21,7 @@ setup( name='chalice', - version='1.9.0', + version='1.9.1', description="Microframework", long_description=README, author="James Saryerwinnie", diff --git a/tests/functional/test_awsclient.py b/tests/functional/test_awsclient.py index ec03d3ab1..b1205a342 100644 --- a/tests/functional/test_awsclient.py +++ b/tests/functional/test_awsclient.py @@ -64,11 +64,11 @@ def test_put_role_policy(stubbed_session): def test_rest_api_exists(stubbed_session): stubbed_session.stub('apigateway').get_rest_api( - restApiId='api').returns({}) + restApiId='api').returns({'id': 'api'}) stubbed_session.activate_stubs() awsclient = TypedAWSClient(stubbed_session) - assert awsclient.rest_api_exists('api') + assert awsclient.get_rest_api('api') stubbed_session.verify_stubs() @@ -81,7 +81,7 @@ def test_rest_api_not_exists(stubbed_session): stubbed_session.activate_stubs() awsclient = TypedAWSClient(stubbed_session) - assert not awsclient.rest_api_exists('api') + assert not awsclient.get_rest_api('api') stubbed_session.verify_stubs() @@ -1230,6 +1230,351 @@ def test_can_add_permission_when_policy_does_not_exist(self, stubbed_session.verify_stubs() +class TestAddPermissionsForAPIGatewayV2(object): + def should_call_add_permission(self, lambda_stub, + statement_id=stub.ANY): + lambda_stub.add_permission( + Action='lambda:InvokeFunction', + FunctionName='name', + StatementId=statement_id, + Principal='apigateway.amazonaws.com', + SourceArn='arn:aws:execute-api:us-west-2:123:websocket-api-id/*', + ).returns({}) + + def test_can_add_permission_for_apigateway_v2_needed(self, + stubbed_session): + # An empty policy means we need to add permissions. + lambda_stub = stubbed_session.stub('lambda') + lambda_stub.get_policy(FunctionName='name').returns({'Policy': '{}'}) + self.should_call_add_permission(lambda_stub) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.add_permission_for_apigateway_v2( + 'name', 'us-west-2', '123', 'websocket-api-id') + stubbed_session.verify_stubs() + + def test_can_add_permission_random_id_optional(self, stubbed_session): + lambda_stub = stubbed_session.stub('lambda') + lambda_stub.get_policy(FunctionName='name').returns({'Policy': '{}'}) + self.should_call_add_permission(lambda_stub) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.add_permission_for_apigateway_v2( + 'name', 'us-west-2', '123', 'websocket-api-id') + stubbed_session.verify_stubs() + + def test_can_add_permission_for_apigateway_v2_not_needed(self, + stubbed_session): + source_arn = 'arn:aws:execute-api:us-west-2:123:websocket-api-id/*' + wrong_action = { + 'Action': 'lambda:NotInvoke', + 'Condition': { + 'ArnLike': { + 'AWS:SourceArn': source_arn, + } + }, + 'Effect': 'Allow', + 'Principal': {'Service': 'apigateway.amazonaws.com'}, + 'Resource': 'arn:aws:lambda:us-west-2:account_id:function:name', + 'Sid': 'e4755709-067e-4254-b6ec-e7f9639e6f7b', + } + wrong_service_name = { + 'Action': 'lambda:Invoke', + 'Condition': { + 'ArnLike': { + 'AWS:SourceArn': source_arn, + } + }, + 'Effect': 'Allow', + 'Principal': {'Service': 'NOT-apigateway.amazonaws.com'}, + 'Resource': 'arn:aws:lambda:us-west-2:account_id:function:name', + 'Sid': 'e4755709-067e-4254-b6ec-e7f9639e6f7b', + } + correct_statement = { + 'Action': 'lambda:InvokeFunction', + 'Condition': { + 'ArnLike': { + 'AWS:SourceArn': source_arn, + } + }, + 'Effect': 'Allow', + 'Principal': {'Service': 'apigateway.amazonaws.com'}, + 'Resource': 'arn:aws:lambda:us-west-2:account_id:function:name', + 'Sid': 'e4755709-067e-4254-b6ec-e7f9639e6f7b', + } + policy = { + 'Id': 'default', + 'Statement': [ + wrong_action, + wrong_service_name, + correct_statement, + ], + 'Version': '2012-10-17' + } + stubbed_session.stub('lambda').get_policy( + FunctionName='name').returns({'Policy': json.dumps(policy)}) + + # Because the policy above indicates that API gateway already has the + # necessary permissions, we should not call add_permission. + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.add_permission_for_apigateway( + 'name', 'us-west-2', '123', 'websocket-api-id') + stubbed_session.verify_stubs() + + def test_can_add_permission_when_policy_does_not_exist(self, + stubbed_session): + # It's also possible to receive a ResourceNotFoundException + # if you call get_policy() on a lambda function with no policy. + lambda_stub = stubbed_session.stub('lambda') + lambda_stub.get_policy(FunctionName='name').raises_error( + error_code='ResourceNotFoundException', message='Does not exist.') + self.should_call_add_permission(lambda_stub) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.add_permission_for_apigateway_v2( + 'name', 'us-west-2', '123', 'websocket-api-id', 'random-id') + stubbed_session.verify_stubs() + + +class TestWebsocketAPI(object): + def test_can_create_websocket_api(self, stubbed_session): + stubbed_session.stub('apigatewayv2').create_api( + Name='name', + ProtocolType='WEBSOCKET', + RouteSelectionExpression='$request.body.action', + ).returns({'ApiId': 'id'}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + api_id = client.create_websocket_api('name') + stubbed_session.verify_stubs() + assert api_id == 'id' + + def test_can_get_websocket_api(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_apis( + ).returns({ + 'Items': [ + {'Name': 'some-other-api', + 'ApiId': 'foo bar', + 'RouteSelectionExpression': 'unused', + 'ProtocolType': 'WEBSOCKET'}, + {'Name': 'target-api', + 'ApiId': 'id', + 'RouteSelectionExpression': 'unused', + 'ProtocolType': 'WEBSOCKET'}, + ], + }) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + api_id = client.get_websocket_api_id('target-api') + stubbed_session.verify_stubs() + assert api_id == 'id' + + def test_does_return_none_on_websocket_api_missing(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_apis( + ).returns({ + 'Items': [], + }) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + api_id = client.get_websocket_api_id('target-api') + stubbed_session.verify_stubs() + assert api_id is None + + def test_can_check_get_websocket_api_exists(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_api( + ApiId='api-id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + exists = client.websocket_api_exists('api-id') + stubbed_session.verify_stubs() + assert exists is True + + def test_can_check_get_websocket_api_not_exists(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_api( + ApiId='api-id', + ).raises_error( + error_code='NotFoundException', + message='Does not exists.', + ) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + exists = client.websocket_api_exists('api-id') + stubbed_session.verify_stubs() + assert exists is False + + def test_can_delete_websocket_api(self, stubbed_session): + stubbed_session.stub('apigatewayv2').delete_api( + ApiId='id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.delete_websocket_api('id') + stubbed_session.verify_stubs() + + def test_rest_api_delete_already_deleted(self, stubbed_session): + stubbed_session.stub('apigatewayv2')\ + .delete_api(ApiId='name')\ + .raises_error(error_code='NotFoundException', + message='Unknown') + stubbed_session.activate_stubs() + + awsclient = TypedAWSClient(stubbed_session) + with pytest.raises(ResourceDoesNotExistError): + assert awsclient.delete_websocket_api('name') + + def test_can_create_integration(self, stubbed_session): + stubbed_session.stub('apigatewayv2').create_integration( + ApiId='api-id', + ConnectionType='INTERNET', + ContentHandlingStrategy='CONVERT_TO_TEXT', + Description='connect', + IntegrationType='AWS_PROXY', + IntegrationUri='arn:aws:lambda', + ).returns({'IntegrationId': 'integration-id'}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + integration_id = client.create_websocket_integration( + api_id='api-id', + lambda_function='arn:aws:lambda', + handler_type='connect', + ) + stubbed_session.verify_stubs() + assert integration_id == 'integration-id' + + def test_can_create_route(self, stubbed_session): + stubbed_session.stub('apigatewayv2').create_route( + ApiId='api-id', + RouteKey='route-key', + RouteResponseSelectionExpression='$default', + Target='integrations/integration-id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.create_websocket_route( + api_id='api-id', + route_key='route-key', + integration_id='integration-id', + ) + stubbed_session.verify_stubs() + + def test_can_delete_all_websocket_routes(self, stubbed_session): + stubbed_session.stub('apigatewayv2').delete_route( + ApiId='api-id', + RouteId='route-id', + ).returns({}) + stubbed_session.stub('apigatewayv2').delete_route( + ApiId='api-id', + RouteId='old-route-id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.delete_websocket_routes( + api_id='api-id', + routes=['route-id', 'old-route-id'], + ) + stubbed_session.verify_stubs() + + def test_can_delete_all_websocket_integrations(self, stubbed_session): + stubbed_session.stub('apigatewayv2').delete_integration( + ApiId='api-id', + IntegrationId='integration-id', + ).returns({}) + stubbed_session.stub('apigatewayv2').delete_integration( + ApiId='api-id', + IntegrationId='old-integration-id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.delete_websocket_integrations( + api_id='api-id', + integrations=['integration-id', 'old-integration-id'], + ) + stubbed_session.verify_stubs() + + def test_can_deploy_websocket_api(self, stubbed_session): + stubbed_session.stub('apigatewayv2').create_deployment( + ApiId='api-id', + ).returns({'DeploymentId': 'deployment-id'}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + deployment_id = client.deploy_websocket_api( + api_id='api-id', + ) + stubbed_session.verify_stubs() + assert deployment_id == 'deployment-id' + + def test_can_get_routes(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_routes( + ApiId='api-id', + ).returns( + { + 'Items': [ + {'RouteKey': 'route-key-foo', + 'RouteId': 'route-id-foo'}, + {'RouteKey': 'route-key-bar', + 'RouteId': 'route-id-bar'}, + ], + } + ) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + routes = client.get_websocket_routes( + api_id='api-id', + ) + stubbed_session.verify_stubs() + assert routes == ['route-id-foo', 'route-id-bar'] + + def test_can_get_integrations(self, stubbed_session): + stubbed_session.stub('apigatewayv2').get_integrations( + ApiId='api-id', + ).returns( + { + 'Items': [ + { + 'Description': 'connect', + 'IntegrationId': 'connect-integration-id' + }, + { + 'Description': 'message', + 'IntegrationId': 'message-integration-id' + }, + { + 'Description': 'disconnect', + 'IntegrationId': 'disconnect-integration-id' + }, + ] + } + ) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + integration_ids = client.get_websocket_integrations( + api_id='api-id', + ) + stubbed_session.verify_stubs() + assert integration_ids == [ + 'connect-integration-id', + 'message-integration-id', + 'disconnect-integration-id', + ] + + def test_can_create_stage(self, stubbed_session): + stubbed_session.stub('apigatewayv2').create_stage( + ApiId='api-id', + StageName='stage-name', + DeploymentId='deployment-id', + ).returns({}) + stubbed_session.activate_stubs() + client = TypedAWSClient(stubbed_session) + client.create_stage( + api_id='api-id', + stage_name='stage-name', + deployment_id='deployment-id', + ) + stubbed_session.verify_stubs() + + class TestAddPermissionsForAuthorizer(object): FUNCTION_ARN = ( @@ -1325,12 +1670,13 @@ def test_import_rest_api(stubbed_session): apig = stubbed_session.stub('apigateway') swagger_doc = {'swagger': 'doc'} apig.import_rest_api( + parameters={'endpointConfigurationTypes': 'EDGE'}, body=json.dumps(swagger_doc, indent=2)).returns( {'id': 'rest_api_id'}) stubbed_session.activate_stubs() awsclient = TypedAWSClient(stubbed_session) - rest_api_id = awsclient.import_rest_api(swagger_doc) + rest_api_id = awsclient.import_rest_api(swagger_doc, 'EDGE') stubbed_session.verify_stubs() assert rest_api_id == 'rest_api_id' diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_features.py b/tests/integration/test_features.py index 4fcb42d04..d9bbdc3ae 100644 --- a/tests/integration/test_features.py +++ b/tests/integration/test_features.py @@ -5,9 +5,11 @@ import shutil import uuid +import mock import botocore.session import pytest import requests +import websocket from chalice.cli.factory import CLIFactory from chalice.utils import OSUtils, UI @@ -18,7 +20,7 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) PROJECT_DIR = os.path.join(CURRENT_DIR, 'testapp') APP_FILE = os.path.join(PROJECT_DIR, 'app.py') -RANDOM_APP_NAME = 'smoketest-%s' % str(uuid.uuid4()) +RANDOM_APP_NAME = 'smoketest-%s' % str(uuid.uuid4())[:13] def retry(max_attempts, delay): @@ -67,6 +69,22 @@ def rest_api_id(self): return self._deployed_resources.resource_values( 'rest_api')['rest_api_id'] + @property + def websocket_api_id(self): + return self._deployed_resources.resource_values( + 'websocket_api')['websocket_api_id'] + + @property + def websocket_connect_url(self): + return ( + "wss://{websocket_api_id}.execute-api.{region}.amazonaws.com/" + "{api_gateway_stage}".format( + websocket_api_id=self.websocket_api_id, + region=self._region, + api_gateway_stage='api', + ) + ) + def get_json(self, url): if not url.startswith('/'): url = '/' + url @@ -499,6 +517,24 @@ def test_empty_raw_body(smoke_test_app): assert response.json() == {'repr-raw-body': ''} +def test_websocket_lifecycle(smoke_test_app): + ws = websocket.create_connection(smoke_test_app.websocket_connect_url) + ws.send("Hello, World 1") + ws.recv() + ws.close() + ws = websocket.create_connection(smoke_test_app.websocket_connect_url) + ws.send("Hello, World 2") + second_response = json.loads(ws.recv()) + ws.close() + + expected_second_response = [ + [mock.ANY, 'Hello, World 1'], + [mock.ANY, 'Hello, World 2'] + ] + assert expected_second_response == second_response + assert second_response[0][0] != second_response[1][0] + + @pytest.mark.on_redeploy def test_redeploy_no_change_view(smoke_test_app): smoke_test_app.redeploy_once() diff --git a/tests/integration/test_websockets.py b/tests/integration/test_websockets.py new file mode 100644 index 000000000..fa3ffaf46 --- /dev/null +++ b/tests/integration/test_websockets.py @@ -0,0 +1,316 @@ +import os +import sys +import json +import uuid +import threading +import shutil +import time + +import pytest +import websocket + +from chalice.cli.factory import CLIFactory +from chalice.utils import OSUtils, UI +from chalice.deploy.deployer import ChaliceDeploymentError +from chalice.config import DeployedResources + + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_DIR = os.path.join(CURRENT_DIR, 'testwebsocketapp') +APP_FILE = os.path.join(PROJECT_DIR, 'app.py') +RANDOM_APP_NAME = 'smoketest-%s' % str(uuid.uuid4())[:13] + + +def retry(max_attempts, delay): + def _create_wrapped_retry_function(function): + def _wrapped_with_retry(*args, **kwargs): + for _ in range(max_attempts): + result = function(*args, **kwargs) + if result is not None: + return result + time.sleep(delay) + raise RuntimeError("Exhausted max retries of %s for function: %s" + % (max_attempts, function)) + return _wrapped_with_retry + return _create_wrapped_retry_function + + +def _inject_app_name(dirname): + config_filename = os.path.join(dirname, '.chalice', 'config.json') + with open(config_filename) as f: + data = json.load(f) + data['app_name'] = RANDOM_APP_NAME + data['stages']['dev']['environment_variables']['APP_NAME'] = \ + RANDOM_APP_NAME + with open(config_filename, 'w') as f: + f.write(json.dumps(data, indent=2)) + + +def _deploy_app(temp_dirname): + factory = CLIFactory(temp_dirname) + config = factory.create_config_obj( + chalice_stage_name='dev', + autogen_policy=True + ) + session = factory.create_botocore_session() + d = factory.create_default_deployer(session, config, UI()) + region = session.get_config_variable('region') + deployed = _deploy_with_retries(d, config) + application = SmokeTestApplication( + region=region, + deployed_values=deployed, + stage_name='dev', + app_name=RANDOM_APP_NAME, + app_dir=temp_dirname, + ) + return application + + +@retry(max_attempts=10, delay=20) +def _deploy_with_retries(deployer, config): + try: + deployed_stages = deployer.deploy(config, 'dev') + return deployed_stages + except ChaliceDeploymentError as e: + # API Gateway aggressively throttles deployments. + # If we run into this case, we just wait and try + # again. + error_code = _get_error_code_from_exception(e) + if error_code != 'TooManyRequestsException': + raise + + +def _get_error_code_from_exception(exception): + error_response = getattr(exception.original_error, 'response', None) + if error_response is None: + return None + return error_response.get('Error', {}).get('Code') + + +def _delete_app(application, temp_dirname): + factory = CLIFactory(temp_dirname) + config = factory.create_config_obj( + chalice_stage_name='dev', + autogen_policy=True + ) + session = factory.create_botocore_session() + d = factory.create_deletion_deployer(session, UI()) + _deploy_with_retries(d, config) + + +class SmokeTestApplication(object): + + # Number of seconds to wait after redeploy before starting + # to poll for successful 200. + _REDEPLOY_SLEEP = 20 + # Seconds to wait between poll attempts after redeploy. + _POLLING_DELAY = 5 + + def __init__(self, deployed_values, stage_name, app_name, + app_dir, region): + self._deployed_resources = DeployedResources(deployed_values) + self.stage_name = stage_name + self.app_name = app_name + # The name of the tmpdir where the app is copied. + self.app_dir = app_dir + self._has_redeployed = False + self._region = region + + @property + def websocket_api_id(self): + return self._deployed_resources.resource_values( + 'websocket_api')['websocket_api_id'] + + @property + def websocket_connect_url(self): + return ( + "wss://{websocket_api_id}.execute-api.{region}.amazonaws.com/" + "{api_gateway_stage}".format( + websocket_api_id=self.websocket_api_id, + region=self._region, + api_gateway_stage='api', + ) + ) + + @property + def websocket_message_handler_arn(self): + return self._deployed_resources.resource_values( + 'websocket_message')['lambda_arn'] + + @property + def region(self): + return self._region + + def redeploy_once(self): + # Redeploy the application once. If a redeploy + # has already happened, this function is a noop. + if self._has_redeployed: + return + new_file = os.path.join(self.app_dir, 'app-redeploy.py') + original_app_py = os.path.join(self.app_dir, 'app.py') + shutil.move(original_app_py, original_app_py + '.bak') + shutil.copy(new_file, original_app_py) + self._clear_app_import() + _deploy_app(self.app_dir) + self._has_redeployed = True + # Give it settling time before running more tests. + time.sleep(self._REDEPLOY_SLEEP) + + def _clear_app_import(self): + # Now that we're using `import` instead of `exec` we need + # to clear out sys.modules in order to pick up the new + # version of the app we just copied over. + del sys.modules['app'] + + +@pytest.fixture(scope='module') +def smoke_test_app(tmpdir_factory): + # We can't use the monkeypatch fixture here because this is a module scope + # fixture and monkeypatch is a function scoped fixture. + os.environ['APP_NAME'] = RANDOM_APP_NAME + tmpdir = str(tmpdir_factory.mktemp(RANDOM_APP_NAME)) + _create_dynamodb_table(RANDOM_APP_NAME, tmpdir) + OSUtils().copytree(PROJECT_DIR, tmpdir) + _inject_app_name(tmpdir) + application = _deploy_app(tmpdir) + yield application + _delete_app(application, tmpdir) + _delete_dynamodb_table(RANDOM_APP_NAME, tmpdir) + os.environ.pop('APP_NAME') + + +def _create_dynamodb_table(table_name, temp_dirname): + factory = CLIFactory(temp_dirname) + session = factory.create_botocore_session() + ddb = session.create_client('dynamodb') + ddb.create_table( + TableName=table_name, + AttributeDefinitions=[ + { + 'AttributeName': 'entry', + 'AttributeType': 'N', + }, + ], + KeySchema=[ + { + 'AttributeName': 'entry', + 'KeyType': 'HASH', + }, + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 5, + 'WriteCapacityUnits': 5, + }, + ) + + +def _delete_dynamodb_table(table_name, temp_dirname): + factory = CLIFactory(temp_dirname) + session = factory.create_botocore_session() + ddb = session.create_client('dynamodb') + ddb.delete_table( + TableName=table_name, + ) + + +class Task(threading.Thread): + def __init__(self, action, delay=0.05): + threading.Thread.__init__(self) + self._action = action + self._done = threading.Event() + self._delay = delay + + def run(self): + while not self._done.is_set(): + self._action() + time.sleep(self._delay) + + def stop(self): + self._done.set() + + +def counter(): + """Generator of sequential increasing numbers""" + yield + count = 1 + while True: + yield count + count += 1 + + +class CountingMessageSender(object): + """Class to send values from a counter over a websocket.""" + def __init__(self, ws, counter): + self._ws = ws + self._counter = counter + self._last_sent = None + + def send(self): + value = next(self._counter) + self._ws.send('%s' % value) + self._last_sent = value + + @property + def last_sent(self): + return self._last_sent + + +def get_numbers_from_dynamodb(temp_dirname): + """Get numbers from DynamoDB in the format written by testwebsocketapp. + """ + factory = CLIFactory(temp_dirname) + session = factory.create_botocore_session() + ddb = session.create_client('dynamodb') + paginator = ddb.get_paginator('scan') + numbers = sorted([ + int(item['entry']['N']) + for page in paginator.paginate( + TableName=RANDOM_APP_NAME, + ConsistentRead=True, + ) + for item in page['Items'] + ]) + return numbers + + +def find_skips_in_seq(numbers): + """Find non-sequential gaps in a sequence of numbers + + :type numbers: Iterable of ints + :param numbers: Iterable to check for gaps + + :returns: List of tuples with the gaps in the format + [(start_of_gap, end_of_gap, ...)]. If the list is empty then there + are no gaps. + """ + last = numbers[0] - 1 + skips = [] + for elem in numbers: + if elem != last + 1: + skips.append((last, elem)) + last = elem + return skips + + +def test_websocket_redployment_does_not_lose_messages(smoke_test_app): + # This test is to check if one persistant connection is affected by an app + # redeployment. A connetion is made to the app, and a sequence of numbers + # is sent over the socket and written to a DynamoDB table. The app is + # redeployed in a seprate thread. After the redeployment we wait a + # second to ensure more numbers have been sent. Finally we inspect the + # DynamoDB table to ensure there are no gaps in the numbers we saw on the + # server side, and that the first and last number we sent is also present. + ws = websocket.create_connection(smoke_test_app.websocket_connect_url) + counter_generator = counter() + sender = CountingMessageSender(ws, counter_generator) + ping_endpoint = Task(sender.send) + ping_endpoint.start() + smoke_test_app.redeploy_once() + time.sleep(1) + ping_endpoint.stop() + + numbers = get_numbers_from_dynamodb(smoke_test_app.app_dir) + assert 1 in numbers + assert sender.last_sent in numbers + skips = find_skips_in_seq(numbers) + assert skips == [] diff --git a/tests/integration/testapp/app.py b/tests/integration/testapp/app.py index ad634225b..c70289f91 100644 --- a/tests/integration/testapp/app.py +++ b/tests/integration/testapp/app.py @@ -1,10 +1,12 @@ import os +import json try: from urllib.parse import parse_qs except ImportError: from urlparse import parse_qs +import boto3.session from chalice import Chalice, BadRequestError, NotFoundError, Response,\ CORSConfig, UnauthorizedError, AuthResponse, AuthRoute @@ -14,6 +16,10 @@ # and helps prevent regressions. app = Chalice(app_name=os.environ['APP_NAME']) +app.websocket_api.session = boto3.session.Session() +app.experimental_feature_flags.update([ + 'WEBSOCKETS' +]) app.api.binary_types.append('application/binary') @@ -213,3 +219,22 @@ def fake_profile_post(): @app.route('/repr-raw-body', methods=['POST']) def repr_raw_body(): return {'repr-raw-body': app.current_request.raw_body.decode('utf-8')} + + +SOCKET_MESSAGES = [] + + +@app.on_ws_connect() +def connect(event): + pass + + +@app.on_ws_message() +def message(event): + SOCKET_MESSAGES.append((event.connection_id, event.body)) + app.websocket_api.send(event.connection_id, json.dumps(SOCKET_MESSAGES)) + + +@app.on_ws_disconnect() +def disconnect(event): + pass diff --git a/tests/integration/testapp/requirements.txt b/tests/integration/testapp/requirements.txt index e69de29bb..58806b636 100644 --- a/tests/integration/testapp/requirements.txt +++ b/tests/integration/testapp/requirements.txt @@ -0,0 +1 @@ +boto3==1.9.91 diff --git a/tests/integration/testwebsocketapp/.chalice/config.json b/tests/integration/testwebsocketapp/.chalice/config.json new file mode 100644 index 000000000..3698f1de0 --- /dev/null +++ b/tests/integration/testwebsocketapp/.chalice/config.json @@ -0,0 +1,10 @@ +{ + "version": "2.0", + "app_name": "testwebsocketapp", + "stages": { + "dev": { + "api_gateway_stage": "api", + "environment_variables": {} + } + } +} diff --git a/tests/integration/testwebsocketapp/.gitignore b/tests/integration/testwebsocketapp/.gitignore new file mode 100644 index 000000000..3dd60a972 --- /dev/null +++ b/tests/integration/testwebsocketapp/.gitignore @@ -0,0 +1,2 @@ +.chalice/deployments/ +.chalice/venv/ diff --git a/tests/integration/testwebsocketapp/app-redeploy.py b/tests/integration/testwebsocketapp/app-redeploy.py new file mode 100644 index 000000000..39be56dfd --- /dev/null +++ b/tests/integration/testwebsocketapp/app-redeploy.py @@ -0,0 +1,25 @@ +import os + +import boto3 +from chalice import Chalice + +app = Chalice(app_name=os.environ['APP_NAME']) +app.websocket_api.session = boto3.session.Session() +app.experimental_feature_flags.update([ + 'WEBSOCKETS' +]) +ddb = boto3.client('dynamodb') + + +# This comment is to cause a change which triggers a redeployment +# of the Lambda Function, this is needed to properly test redeployment. +@app.on_ws_message() +def message(event): + ddb.put_item( + TableName=os.environ['APP_NAME'], + Item={ + 'entry': { + 'N': event.body + }, + }, + ) diff --git a/tests/integration/testwebsocketapp/app.py b/tests/integration/testwebsocketapp/app.py new file mode 100644 index 000000000..9f82e1da8 --- /dev/null +++ b/tests/integration/testwebsocketapp/app.py @@ -0,0 +1,23 @@ +import os + +import boto3 +from chalice import Chalice + +app = Chalice(app_name=os.environ['APP_NAME']) +app.websocket_api.session = boto3.session.Session() +app.experimental_feature_flags.update([ + 'WEBSOCKETS' +]) +ddb = boto3.client('dynamodb') + + +@app.on_ws_message() +def message(event): + ddb.put_item( + TableName=os.environ['APP_NAME'], + Item={ + 'entry': { + 'N': event.body + }, + }, + ) diff --git a/tests/integration/testwebsocketapp/requirements.txt b/tests/integration/testwebsocketapp/requirements.txt new file mode 100644 index 000000000..58806b636 --- /dev/null +++ b/tests/integration/testwebsocketapp/requirements.txt @@ -0,0 +1 @@ +boto3==1.9.91 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ff5542fc9..8032a9a63 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -77,6 +77,25 @@ def myfunction(event, context): return app +@fixture +def sample_websocket_app(): + app = Chalice('sample') + + @app.on_ws_connect() + def foo(): + pass + + @app.on_ws_message() + def bar(): + pass + + @app.on_ws_disconnect() + def baz(): + pass + + return app + + @fixture def create_event(): def create_event_inner(uri, method, path, content_type='application/json'): @@ -96,6 +115,21 @@ def create_event_inner(uri, method, path, content_type='application/json'): return create_event_inner +@fixture +def create_websocket_event(): + def create_event_inner(route_key, body=''): + return { + 'requestContext': { + 'routeKey': route_key, + 'domainName': 'abcd1234.us-west-2.amazonaws.com', + 'stage': 'api', + 'connectionId': 'ABCD1234=', + }, + 'body': body, + } + return create_event_inner + + @fixture def create_empty_header_event(): def create_empty_header_event_inner(uri, method, path, diff --git a/tests/unit/deploy/test_deployer.py b/tests/unit/deploy/test_deployer.py index 32cf8f92a..3f6d44dbc 100644 --- a/tests/unit/deploy/test_deployer.py +++ b/tests/unit/deploy/test_deployer.py @@ -1,6 +1,7 @@ -import botocore.session import os + import socket +import botocore.session import pytest import mock @@ -33,12 +34,15 @@ from chalice.deploy.executor import Executor from chalice.deploy.swagger import SwaggerGenerator, TemplatedSwaggerGenerator from chalice.deploy.planner import PlanStage -from chalice.deploy.planner import ResourceSweeper, StringFormat +from chalice.deploy.planner import StringFormat +from chalice.deploy.sweeper import ResourceSweeper from chalice.deploy.models import APICall from chalice.constants import LAMBDA_TRUST_POLICY, VPC_ATTACH_POLICY from chalice.constants import SQS_EVENT_SOURCE_POLICY +from chalice.constants import POST_TO_WEBSOCKET_CONNECTION_POLICY from chalice.deploy.deployer import ChaliceBuildError from chalice.deploy.deployer import LambdaEventSourcePolicyInjector +from chalice.deploy.deployer import WebsocketPolicyInjector _SESSION = None @@ -316,6 +320,70 @@ def handler(event): return app +@fixture +def websocket_app(): + app = Chalice('websocket-event') + + @app.on_ws_connect() + def connect(event): + pass + + @app.on_ws_message() + def message(event): + pass + + @app.on_ws_disconnect() + def disconnect(event): + pass + + return app + + +@fixture +def websocket_app_without_connect(): + app = Chalice('websocket-event-no-connect') + + @app.on_ws_message() + def message(event): + pass + + @app.on_ws_disconnect() + def disconnect(event): + pass + + return app + + +@fixture +def websocket_app_without_message(): + app = Chalice('websocket-event-no-message') + + @app.on_ws_connect() + def connect(event): + pass + + @app.on_ws_disconnect() + def disconnect(event): + pass + + return app + + +@fixture +def websocket_app_without_disconnect(): + app = Chalice('websocket-event-no-disconnect') + + @app.on_ws_connect() + def connect(event): + pass + + @app.on_ws_message() + def message(event): + pass + + return app + + @fixture def mock_client(): return mock.Mock(spec=TypedAWSClient) @@ -404,12 +472,19 @@ def create_config(self, app, app_name='lambda-only', iam_role_arn=None, policy_file=None, api_gateway_stage='api', autogen_policy=False, security_group_ids=None, - subnet_ids=None, reserved_concurrency=None, layers=None): + subnet_ids=None, reserved_concurrency=None, layers=None, + api_gateway_endpoint_type=None, + api_gateway_endpoint_vpce=None, + api_gateway_policy_file=None, + project_dir='.'): kwargs = { 'chalice_app': app, 'app_name': app_name, - 'project_dir': '.', + 'project_dir': project_dir, 'api_gateway_stage': api_gateway_stage, + 'api_gateway_policy_file': api_gateway_policy_file, + 'api_gateway_endpoint_type': api_gateway_endpoint_type, + 'api_gateway_endpoint_vpce': api_gateway_endpoint_vpce } if iam_role_arn is not None: # We want to use an existing role. @@ -631,6 +706,45 @@ def test_scheduled_event_models(self, scheduled_event_app): assert isinstance(event.lambda_function, models.LambdaFunction) assert event.lambda_function.resource_name == 'foo' + def test_can_build_private_rest_api(self, rest_api_app): + config = self.create_config(rest_api_app, + app_name='rest-api-app', + api_gateway_endpoint_type='PRIVATE', + api_gateway_endpoint_vpce='vpce-abc123') + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + rest_api = application.resources[0] + assert isinstance(rest_api, models.RestAPI) + assert rest_api.policy.document == { + 'Version': '2012-10-17', + 'Statement': [ + {'Action': 'execute-api:Invoke', + 'Effect': 'Allow', + 'Principal': '*', + 'Resource': 'arn:aws:execute-api:*:*:*', + 'Condition': { + 'StringEquals': { + 'aws:SourceVpce': 'vpce-abc123'}}}, + ] + } + + def test_can_build_private_rest_api_custom_policy( + self, tmpdir, rest_api_app): + config = self.create_config(rest_api_app, + app_name='rest-api-app', + api_gateway_policy_file='foo.json', + api_gateway_endpoint_type='PRIVATE', + project_dir=str(tmpdir)) + tmpdir.mkdir('.chalice').join('foo.json').write( + serialize_to_json({'Version': '2012-10-17', 'Statement': []})) + + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + rest_api = application.resources[0] + rest_api.policy.document == { + 'Version': '2012-10-17', 'Statement': [] + } + def test_can_build_rest_api(self, rest_api_app): config = self.create_config(rest_api_app, app_name='rest-api-app', @@ -713,6 +827,104 @@ def test_can_create_sqs_event_handler(self, sqs_event_app): assert lambda_function.resource_name == 'handler' assert lambda_function.handler == 'app.handler' + def test_can_create_websocket_event_handler(self, websocket_app): + config = self.create_config(websocket_app, + app_name='websocket-app', + autogen_policy=True) + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + assert len(application.resources) == 1 + websocket_api = application.resources[0] + assert isinstance(websocket_api, models.WebsocketAPI) + assert websocket_api.resource_name == 'websocket_api' + assert sorted(websocket_api.routes) == sorted( + ['$connect', '$default', '$disconnect']) + assert websocket_api.api_gateway_stage == 'api' + + connect_function = websocket_api.connect_function + assert connect_function.resource_name == 'websocket_connect' + assert connect_function.handler == 'app.connect' + + message_function = websocket_api.message_function + assert message_function.resource_name == 'websocket_message' + assert message_function.handler == 'app.message' + + disconnect_function = websocket_api.disconnect_function + assert disconnect_function.resource_name == 'websocket_disconnect' + assert disconnect_function.handler == 'app.disconnect' + + def test_can_create_websocket_app_missing_connect( + self, websocket_app_without_connect): + config = self.create_config(websocket_app_without_connect, + app_name='websocket-app', + autogen_policy=True) + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + assert len(application.resources) == 1 + websocket_api = application.resources[0] + assert isinstance(websocket_api, models.WebsocketAPI) + assert websocket_api.resource_name == 'websocket_api' + assert sorted(websocket_api.routes) == sorted( + ['$default', '$disconnect']) + assert websocket_api.api_gateway_stage == 'api' + + connect_function = websocket_api.connect_function + assert connect_function is None + + message_function = websocket_api.message_function + assert message_function.resource_name == 'websocket_message' + assert message_function.handler == 'app.message' + + disconnect_function = websocket_api.disconnect_function + assert disconnect_function.resource_name == 'websocket_disconnect' + assert disconnect_function.handler == 'app.disconnect' + + def test_can_create_websocket_app_missing_message( + self, websocket_app_without_message): + config = self.create_config(websocket_app_without_message, + app_name='websocket-app', + autogen_policy=True) + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + assert len(application.resources) == 1 + websocket_api = application.resources[0] + assert isinstance(websocket_api, models.WebsocketAPI) + assert websocket_api.resource_name == 'websocket_api' + assert sorted(websocket_api.routes) == sorted( + ['$connect', '$disconnect']) + assert websocket_api.api_gateway_stage == 'api' + + connect_function = websocket_api.connect_function + assert connect_function.resource_name == 'websocket_connect' + assert connect_function.handler == 'app.connect' + + disconnect_function = websocket_api.disconnect_function + assert disconnect_function.resource_name == 'websocket_disconnect' + assert disconnect_function.handler == 'app.disconnect' + + def test_can_create_websocket_app_missing_disconnect( + self, websocket_app_without_disconnect): + config = self.create_config(websocket_app_without_disconnect, + app_name='websocket-app', + autogen_policy=True) + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + assert len(application.resources) == 1 + websocket_api = application.resources[0] + assert isinstance(websocket_api, models.WebsocketAPI) + assert websocket_api.resource_name == 'websocket_api' + assert sorted(websocket_api.routes) == sorted( + ['$connect', '$default']) + assert websocket_api.api_gateway_stage == 'api' + + connect_function = websocket_api.connect_function + assert connect_function.resource_name == 'websocket_connect' + assert connect_function.handler == 'app.connect' + + message_function = websocket_api.message_function + assert message_function.resource_name == 'websocket_message' + assert message_function.handler == 'app.message' + class RoleTestCase(object): def __init__(self, given, roles, app_name='appname'): @@ -1062,6 +1274,7 @@ def test_can_generate_swagger_builder(self): resource_name='foo', swagger_doc=models.Placeholder.BUILD_STAGE, minimum_compression='', + endpoint_type='EDGE', api_gateway_stage='api', lambda_function=None, ) @@ -1070,7 +1283,7 @@ def test_can_generate_swagger_builder(self): p = SwaggerBuilder(generator) p.handle(config, rest_api) assert rest_api.swagger_doc == {'swagger': '2.0'} - generator.generate_swagger.assert_called_with(app) + generator.generate_swagger.assert_called_with(app, rest_api) class TestDeploymentPackager(object): @@ -1291,6 +1504,10 @@ def test_can_generate_report(self): "rest_api_id": "rest_api_id", "rest_api_url": "https://host/api", "resource_type": "rest_api"}, + {"name": "websocket_api", + "websocket_api_id": "websocket_api_id", + "websocket_api_url": "wss://host/api", + "resource_type": "websocket_api"}, ], } report = self.reporter.generate_report(deployed_values) @@ -1299,6 +1516,7 @@ def test_can_generate_report(self): " - Lambda ARN: lambda-arn-foo\n" " - Lambda ARN: lambda-arn-dev\n" " - Rest API URL: https://host/api\n" + " - Websocket API URL: wss://host/api\n" ) def test_can_display_report(self): @@ -1360,3 +1578,34 @@ def second_handler(event): assert role.policy.document == { 'Statement': [SQS_EVENT_SOURCE_POLICY.copy()], } + + +class TestWebsocketPolicyInjector(object): + def create_model_from_app(self, app, config): + builder = ApplicationGraphBuilder() + application = builder.build(config, stage_name='dev') + return application.resources[0] + + def test_can_inject_policy(self, websocket_app): + config = Config.create(chalice_app=websocket_app, + autogen_policy=True, + project_dir='.') + event_source = self.create_model_from_app(websocket_app, config) + role = event_source.connect_function.role + role.policy.document = {'Statement': []} + injector = WebsocketPolicyInjector() + injector.handle(config, event_source) + assert role.policy.document == { + 'Statement': [POST_TO_WEBSOCKET_CONNECTION_POLICY.copy()], + } + + def test_no_inject_if_not_autogen_policy(self, websocket_app): + config = Config.create(chalice_app=websocket_app, + autogen_policy=False, + project_dir='.') + event_source = self.create_model_from_app(websocket_app, config) + role = event_source.connect_function.role + role.policy.document = {'Statement': []} + injector = LambdaEventSourcePolicyInjector() + injector.handle(config, event_source) + assert role.policy.document == {'Statement': []} diff --git a/tests/unit/deploy/test_executor.py b/tests/unit/deploy/test_executor.py index 940747260..fef3faf24 100644 --- a/tests/unit/deploy/test_executor.py +++ b/tests/unit/deploy/test_executor.py @@ -6,7 +6,8 @@ from chalice.deploy.executor import Executor, UnresolvedValueError, \ VariableResolver from chalice.deploy.models import APICall, RecordResourceVariable, \ - RecordResourceValue, StoreValue, JPSearch, BuiltinFunction, Instruction + RecordResourceValue, StoreValue, JPSearch, BuiltinFunction, Instruction, \ + CopyVariable from chalice.deploy.planner import Variable, StringFormat from chalice.utils import UI @@ -169,6 +170,13 @@ def test_can_jp_search(self): ]) assert self.executor.variables['result'] == 'baz' + def test_can_copy_variable(self): + self.execute([ + StoreValue(name='foo', value='bar'), + CopyVariable(from_var='foo', to_var='baz'), + ]) + assert self.executor.variables['baz'] == 'bar' + def test_can_call_builtin_function(self): self.execute([ StoreValue( diff --git a/tests/unit/deploy/test_models.py b/tests/unit/deploy/test_models.py index b70ecfe0c..21a54e490 100644 --- a/tests/unit/deploy/test_models.py +++ b/tests/unit/deploy/test_models.py @@ -41,6 +41,7 @@ def test_can_default_to_no_auths_in_rest_api(lambda_function): swagger_doc={'swagger': '2.0'}, minimum_compression='', api_gateway_stage='api', + endpoint_type='EDGE', lambda_function=lambda_function, ) assert rest_api.dependencies() == [lambda_function] @@ -54,7 +55,47 @@ def test_can_add_authorizers_to_dependencies(lambda_function): swagger_doc={'swagger': '2.0'}, minimum_compression='', api_gateway_stage='api', + endpoint_type='EDGE', lambda_function=lambda_function, authorizers=[auth1, auth2], ) assert rest_api.dependencies() == [lambda_function, auth1, auth2] + + +def test_can_add_connect_to_dependencies(lambda_function): + api = models.WebsocketAPI( + resource_name='websocket_api', + name='name', + api_gateway_stage='api', + routes=['$connect'], + connect_function=lambda_function, + message_function=None, + disconnect_function=None, + ) + assert api.dependencies() == [lambda_function] + + +def test_can_add_message_to_dependencies(lambda_function): + api = models.WebsocketAPI( + resource_name='websocket_api', + name='name', + api_gateway_stage='api', + routes=['$default'], + connect_function=None, + message_function=lambda_function, + disconnect_function=None, + ) + assert api.dependencies() == [lambda_function] + + +def test_can_add_disconnect_to_dependencies(lambda_function): + api = models.WebsocketAPI( + resource_name='websocket_api', + name='name', + api_gateway_stage='api', + routes=['$disconnect'], + connect_function=None, + message_function=None, + disconnect_function=lambda_function, + ) + assert api.dependencies() == [lambda_function] diff --git a/tests/unit/deploy/test_planner.py b/tests/unit/deploy/test_planner.py index ac86393c2..02d7603ef 100644 --- a/tests/unit/deploy/test_planner.py +++ b/tests/unit/deploy/test_planner.py @@ -9,7 +9,8 @@ from chalice.utils import OSUtils from chalice.deploy.planner import PlanStage, Variable, RemoteState from chalice.deploy.planner import StringFormat -from chalice.deploy.planner import ResourceSweeper +from chalice.deploy.models import APICall +from chalice.deploy.sweeper import ResourceSweeper def create_function_resource(name, function_name=None, @@ -490,6 +491,371 @@ def test_can_plan_scheduled_event(self): ) +class TestPlanWebsocketAPI(BasePlannerTests): + def assert_loads_needed_variables(self, plan): + # Parse arn and store region/account id for future + # API calls. + assert plan[0:3] == [ + models.BuiltinFunction( + 'parse_arn', [Variable('function_name_connect_lambda_arn')], + output_var='parsed_lambda_arn', + ), + models.JPSearch('account_id', + input_var='parsed_lambda_arn', + output_var='account_id'), + models.JPSearch('region', + input_var='parsed_lambda_arn', + output_var='region_name'), + ] + + def test_can_plan_websocket_api(self): + connect_function = create_function_resource( + 'function_name_connect') + message_function = create_function_resource( + 'function_name_message') + disconnect_function = create_function_resource( + 'function_name_disconnect') + websocket_api = models.WebsocketAPI( + resource_name='websocket_api', + name='app-dev-websocket-api', + api_gateway_stage='api', + routes=['$connect', '$default', '$disconnect'], + connect_function=connect_function, + message_function=message_function, + disconnect_function=disconnect_function, + ) + plan = self.determine_plan(websocket_api) + self.assert_loads_needed_variables(plan) + assert plan[3:] == [ + models.APICall( + method_name='create_websocket_api', + params={'name': 'app-dev-websocket-api'}, + output_var='websocket_api_id', + ), + models.StoreValue( + name='routes', + value=[], + ), + models.StoreValue( + name='websocket-connect-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_connect', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-connect-integration-lambda-path'), + 'handler_type': 'connect', + }, + output_var='connect-integration-id', + ), + models.StoreValue( + name='websocket-message-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_message', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-message-integration-lambda-path'), + 'handler_type': 'message', + }, + output_var='message-integration-id', + ), + models.StoreValue( + name='websocket-disconnect-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_disconnect', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-disconnect-integration-lambda-path'), + 'handler_type': 'disconnect', + }, + output_var='disconnect-integration-id', + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$connect', + 'integration_id': Variable('connect-integration-id'), + }, + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$default', + 'integration_id': Variable('message-integration-id'), + }, + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$disconnect', + 'integration_id': Variable('disconnect-integration-id'), + }, + ), + models.APICall( + method_name='deploy_websocket_api', + params={ + 'api_id': Variable('websocket_api_id'), + }, + output_var='deployment-id', + ), + models.APICall( + method_name='create_stage', + params={ + 'api_id': Variable('websocket_api_id'), + 'stage_name': 'api', + 'deployment_id': Variable('deployment-id'), + } + ), + models.StoreValue( + name='websocket_api_url', + value=StringFormat( + 'wss://{websocket_api_id}.execute-api.{region_name}' + '.amazonaws.com/%s/' % 'api', + ['websocket_api_id', 'region_name'], + ), + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name='websocket_api', + name='websocket_api_url', + variable_name='websocket_api_url', + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name='websocket_api', + name='websocket_api_id', + variable_name='websocket_api_id', + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={'function_name': 'appname-dev-function_name_connect', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={'function_name': 'appname-dev-function_name_message', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={ + 'function_name': 'appname-dev-function_name_disconnect', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ), + ] + + def test_can_update_websocket_api(self): + connect_function = create_function_resource( + 'function_name_connect') + message_function = create_function_resource( + 'function_name_message') + disconnect_function = create_function_resource( + 'function_name_disconnect') + websocket_api = models.WebsocketAPI( + resource_name='websocket_api', + name='app-dev-websocket-api', + api_gateway_stage='api', + routes=['$connect', '$default', '$disconnect'], + connect_function=connect_function, + message_function=message_function, + disconnect_function=disconnect_function, + ) + self.remote_state.declare_resource_exists(websocket_api) + self.remote_state.deployed_values['websocket_api'] = { + 'websocket_api_id': 'my_websocket_api_id', + } + plan = self.determine_plan(websocket_api) + self.assert_loads_needed_variables(plan) + assert plan[3:] == [ + models.StoreValue( + name='websocket_api_id', + value='my_websocket_api_id', + ), + models.APICall( + method_name='get_websocket_routes', + params={'api_id': Variable('websocket_api_id')}, + output_var='routes', + ), + models.APICall( + method_name='delete_websocket_routes', + params={'api_id': Variable('websocket_api_id'), + 'routes': Variable('routes')}, + ), + models.APICall( + method_name='get_websocket_integrations', + params={'api_id': Variable('websocket_api_id')}, + output_var='integrations', + ), + models.APICall( + method_name='delete_websocket_integrations', + params={'api_id': Variable('websocket_api_id'), + 'integrations': Variable('integrations')}, + ), + models.StoreValue( + name='websocket-connect-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_connect', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-connect-integration-lambda-path'), + 'handler_type': 'connect', + }, + output_var='connect-integration-id', + ), + models.StoreValue( + name='websocket-message-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_message', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-message-integration-lambda-path'), + 'handler_type': 'message', + }, + output_var='message-integration-id', + ), + models.StoreValue( + name='websocket-disconnect-integration-lambda-path', + value=StringFormat( + 'arn:aws:apigateway:{region_name}:lambda:path/' + '2015-03-31/functions/arn:aws:lambda:{region_name}:' + '{account_id}:function:%s/' + 'invocations' % 'appname-dev-function_name_disconnect', + ['region_name', 'account_id'], + ), + ), + models.APICall( + method_name='create_websocket_integration', + params={ + 'api_id': Variable('websocket_api_id'), + 'lambda_function': Variable( + 'websocket-disconnect-integration-lambda-path'), + 'handler_type': 'disconnect', + }, + output_var='disconnect-integration-id', + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$connect', + 'integration_id': Variable('connect-integration-id'), + }, + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$default', + 'integration_id': Variable('message-integration-id'), + }, + ), + models.APICall( + method_name='create_websocket_route', + params={ + 'api_id': Variable('websocket_api_id'), + 'route_key': '$disconnect', + 'integration_id': Variable('disconnect-integration-id'), + }, + ), + models.StoreValue( + name='websocket_api_url', + value=StringFormat( + 'wss://{websocket_api_id}.execute-api.{region_name}' + '.amazonaws.com/%s/' % 'api', + ['websocket_api_id', 'region_name'], + ), + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name='websocket_api', + name='websocket_api_url', + variable_name='websocket_api_url', + ), + models.RecordResourceVariable( + resource_type='websocket_api', + resource_name='websocket_api', + name='websocket_api_id', + variable_name='websocket_api_id', + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={'function_name': 'appname-dev-function_name_connect', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={'function_name': 'appname-dev-function_name_message', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id')}, + ), + models.APICall( + method_name='add_permission_for_apigateway_v2', + params={ + 'function_name': 'appname-dev-function_name_disconnect', + 'region_name': Variable('region_name'), + 'account_id': Variable('account_id'), + 'api_id': Variable('websocket_api_id'), + }, + ), + ] + + class TestPlanRestAPI(BasePlannerTests): def assert_loads_needed_variables(self, plan): # Parse arn and store region/account id for future @@ -516,16 +882,19 @@ def test_can_plan_rest_api(self): rest_api = models.RestAPI( resource_name='rest_api', swagger_doc={'swagger': '2.0'}, + endpoint_type='EDGE', minimum_compression='100', api_gateway_stage='api', lambda_function=function, ) plan = self.determine_plan(rest_api) self.assert_loads_needed_variables(plan) + assert plan[4:] == [ models.APICall( method_name='import_rest_api', - params={'swagger_document': {'swagger': '2.0'}}, + params={'swagger_document': {'swagger': '2.0'}, + 'endpoint_type': 'EDGE'}, output_var='rest_api_id', ), models.RecordResourceVariable( @@ -534,9 +903,6 @@ def test_can_plan_rest_api(self): name='rest_api_id', variable_name='rest_api_id', ), - models.APICall(method_name='deploy_rest_api', - params={'rest_api_id': Variable('rest_api_id'), - 'api_gateway_stage': 'api'}), models.APICall( method_name='update_rest_api', params={ @@ -546,7 +912,7 @@ def test_can_plan_rest_api(self): 'path': '/minimumCompressionSize', 'value': '100', }], - }, + } ), models.APICall( method_name='add_permission_for_apigateway', @@ -557,6 +923,9 @@ def test_can_plan_rest_api(self): 'rest_api_id': Variable('rest_api_id'), } ), + models.APICall(method_name='deploy_rest_api', + params={'rest_api_id': Variable('rest_api_id'), + 'api_gateway_stage': 'api'}), models.StoreValue( name='rest_api_url', value=StringFormat( @@ -576,6 +945,38 @@ def test_can_plan_rest_api(self): 'Creating Rest API\n' ] + def test_can_update_rest_api_with_policy(self): + function = create_function_resource('function_name') + rest_api = models.RestAPI( + resource_name='rest_api', + swagger_doc={'swagger': '2.0'}, + minimum_compression='', + api_gateway_stage='api', + endpoint_type='EDGE', + policy="{'Statement': []}", + lambda_function=function, + ) + self.remote_state.declare_resource_exists(rest_api) + self.remote_state.deployed_values['rest_api'] = { + 'rest_api_id': 'my_rest_api_id', + } + plan = self.determine_plan(rest_api) + + assert plan[8].params == { + 'patch_operations': [ + {'op': 'replace', + 'path': '/minimumCompressionSize', + 'value': ''}, + {'op': 'replace', + 'path': StringFormat( + ("/endpointConfiguration/types/" + "{rest_api[endpointConfiguration][types][0]}"), + ['rest_api']), + 'value': 'EDGE'} + ], + 'rest_api_id': Variable("rest_api_id") + } + def test_can_update_rest_api(self): function = create_function_resource('function_name') rest_api = models.RestAPI( @@ -583,6 +984,7 @@ def test_can_update_rest_api(self): swagger_doc={'swagger': '2.0'}, minimum_compression='', api_gateway_stage='api', + endpoint_type='REGIONAL', lambda_function=function, ) self.remote_state.declare_resource_exists(rest_api) @@ -591,6 +993,7 @@ def test_can_update_rest_api(self): } plan = self.determine_plan(rest_api) self.assert_loads_needed_variables(plan) + assert plan[4:] == [ models.StoreValue(name='rest_api_id', value='my_rest_api_id'), models.RecordResourceVariable( @@ -607,16 +1010,9 @@ def test_can_update_rest_api(self): }, ), models.APICall( - method_name='deploy_rest_api', - params={'rest_api_id': Variable('rest_api_id'), - 'api_gateway_stage': 'api'}, - ), - models.APICall( - method_name='add_permission_for_apigateway', - params={'function_name': 'appname-dev-function_name', - 'region_name': Variable('region_name'), - 'account_id': Variable('account_id'), - 'rest_api_id': Variable('rest_api_id')}, + method_name='get_rest_api', + params={'rest_api_id': Variable('rest_api_id')}, + output_var='rest_api' ), models.APICall( method_name='update_rest_api', @@ -625,8 +1021,14 @@ def test_can_update_rest_api(self): 'patch_operations': [{ 'op': 'replace', 'path': '/minimumCompressionSize', - 'value': '', - }], + 'value': ''}, + {'op': 'replace', + 'value': 'REGIONAL', + 'path': StringFormat( + '/endpointConfiguration/types/%s' % ( + '{rest_api[endpointConfiguration][types][0]}'), + ['rest_api'])}, + ], }, ), models.APICall( @@ -636,6 +1038,11 @@ def test_can_update_rest_api(self): 'account_id': Variable("account_id"), 'function_name': 'appname-dev-function_name'}, output_var=None), + models.APICall( + method_name='deploy_rest_api', + params={'rest_api_id': Variable('rest_api_id'), + 'api_gateway_stage': 'api'}, + ), models.StoreValue( name='rest_api_url', value=StringFormat( @@ -994,6 +1401,41 @@ def test_sqs_event_source_exists_updates_batch_size(self): ) ] + @pytest.mark.parametrize('functions,integration_injected', [ + ( + (create_function_resource('connect'), None, None), + 'connect' + ), + ( + (None, create_function_resource('message'), None), + 'message' + ), + ( + (None, None, create_function_resource('disconnect')), + 'disconnect' + ), + ]) + def test_websocket_api_plan_omits_unused_lambdas( + self, functions, integration_injected): + websocket_api = models.WebsocketAPI( + resource_name='websocket_api', + name='app-dev-websocket-api', + api_gateway_stage='api', + routes=['$connect', '$default', '$disconnect'], + connect_function=functions[0], + message_function=functions[1], + disconnect_function=functions[2], + ) + plan = self.determine_plan(websocket_api) + integrations = [ + code.params['handler_type'] for code in plan + if isinstance(code, APICall) + and code.method_name == 'create_websocket_integration' + ] + + assert len(integrations) == 1 + assert integrations[0] == integration_injected + class TestRemoteState(object): def setup_method(self): @@ -1008,11 +1450,24 @@ def create_rest_api_model(self): resource_name='rest_api', swagger_doc={'swagger': '2.0'}, minimum_compression='', + endpoint_type='EDGE', api_gateway_stage='api', lambda_function=None, ) return rest_api + def create_websocket_api_model(self): + websocket_api = models.WebsocketAPI( + resource_name='websocket_api', + name='app-stage-websocket-api', + api_gateway_stage='api', + routes=[], + connect_function=None, + message_function=None, + disconnect_function=None, + ) + return websocket_api + def test_role_exists(self): self.client.get_role_arn_for_name.return_value = 'role:arn' role = models.ManagedIAMRole('my_role', @@ -1060,9 +1515,9 @@ def test_rest_api_exists_no_deploy(self, no_deployed_values): remote_state = RemoteState( self.client, no_deployed_values) assert not remote_state.resource_exists(rest_api) - assert not self.client.rest_api_exists.called + assert not self.client.get_rest_api.called - def test_api_exists_with_existing_deploy(self): + def test_rest_api_exists_with_existing_deploy(self): rest_api = self.create_rest_api_model() deployed_resources = { 'resources': [{ @@ -1071,11 +1526,11 @@ def test_api_exists_with_existing_deploy(self): 'rest_api_id': 'my_rest_api_id', }] } - self.client.rest_api_exists.return_value = True + self.client.get_rest_api.return_value = {'apiId': 'my_rest_api_id'} remote_state = RemoteState( self.client, DeployedResources(deployed_resources)) assert remote_state.resource_exists(rest_api) - self.client.rest_api_exists.assert_called_with('my_rest_api_id') + self.client.get_rest_api.assert_called_with('my_rest_api_id') def test_rest_api_not_exists_with_preexisting_deploy(self): rest_api = self.create_rest_api_model() @@ -1086,11 +1541,50 @@ def test_rest_api_not_exists_with_preexisting_deploy(self): 'rest_api_id': 'my_rest_api_id', }] } - self.client.rest_api_exists.return_value = False + self.client.get_rest_api.return_value = {} remote_state = RemoteState( self.client, DeployedResources(deployed_resources)) assert not remote_state.resource_exists(rest_api) - self.client.rest_api_exists.assert_called_with('my_rest_api_id') + self.client.get_rest_api.assert_called_with('my_rest_api_id') + + def test_websocket_api_exists_no_deploy(self, no_deployed_values): + rest_api = self.create_websocket_api_model() + remote_state = RemoteState( + self.client, no_deployed_values) + assert not remote_state.resource_exists(rest_api) + assert not self.client.websocket_api_exists.called + + def test_websocket_api_exists_with_existing_deploy(self): + websocket_api = self.create_websocket_api_model() + deployed_resources = { + 'resources': [{ + 'name': 'websocket_api', + 'resource_type': 'websocket_api', + 'websocket_api_id': 'my_websocket_api_id', + }] + } + self.client.websocket_api_exists.return_value = True + remote_state = RemoteState( + self.client, DeployedResources(deployed_resources)) + assert remote_state.resource_exists(websocket_api) + self.client.websocket_api_exists.assert_called_with( + 'my_websocket_api_id') + + def test_websocket_api_not_exists_with_preexisting_deploy(self): + websocket_api = self.create_websocket_api_model() + deployed_resources = { + 'resources': [{ + 'name': 'websocket_api', + 'resource_type': 'websocket_api', + 'websocket_api_id': 'my_websocket_api_id', + }] + } + self.client.websocket_api_exists.return_value = False + remote_state = RemoteState( + self.client, DeployedResources(deployed_resources)) + assert not remote_state.resource_exists(websocket_api) + self.client.websocket_api_exists.assert_called_with( + 'my_websocket_api_id') def test_can_get_deployed_values(self): remote_state = RemoteState( @@ -1429,6 +1923,24 @@ def test_can_delete_rest_api(self): ) ] + def test_can_delete_websocket_api(self): + plan = [] + deployed = { + 'resources': [{ + 'name': 'websocket_api', + 'websocket_api_id': 'my_websocket_api_id', + 'resource_type': 'websocket_api', + }] + } + config = FakeConfig(deployed) + self.execute(plan, config) + assert plan == [ + models.APICall( + method_name='delete_websocket_api', + params={'api_id': 'my_websocket_api_id'}, + ) + ] + def test_can_handle_when_resource_changes_values(self): plan = self.determine_plan( models.S3BucketNotification( diff --git a/tests/unit/deploy/test_swagger.py b/tests/unit/deploy/test_swagger.py index 025dec67b..0934b94c8 100644 --- a/tests/unit/deploy/test_swagger.py +++ b/tests/unit/deploy/test_swagger.py @@ -3,7 +3,7 @@ from chalice import CORSConfig from chalice.app import CustomAuthorizer, CognitoUserPoolAuthorizer from chalice.app import IAMAuthorizer, Chalice - +from chalice.deploy.models import RestAPI, IAMPolicy import mock from pytest import fixture @@ -566,6 +566,85 @@ def foo(): } +def test_can_custom_resource_policy_with_cfn(sample_app): + swagger_gen = CFNSwaggerGenerator() + rest_api = RestAPI( + resource_name='dev', + swagger_doc={}, + lambda_function=None, + minimum_compression="", + api_gateway_stage="xyz", + endpoint_type="PRIVATE", + policy=IAMPolicy({ + 'Statement': [{ + "Effect": "Allow", + "Principal": "*", + "Action": "execute-api:Invoke", + "Resource": [ + "arn:aws:execute-api:*:*:*", + "arn:aws:exceute-api:*:*:*/*" + ], + "Condition": { + "StringEquals": { + "aws:SourceVpce": "vpce-abc123" + } + } + }] + }) + ) + + doc = swagger_gen.generate_swagger(sample_app, rest_api) + assert doc['x-amazon-apigateway-policy'] == { + 'Statement': [{ + 'Action': 'execute-api:Invoke', + 'Condition': {'StringEquals': { + 'aws:SourceVpce': 'vpce-abc123'}}, + 'Effect': 'Allow', + 'Principal': '*', + 'Resource': [ + 'arn:aws:execute-api:*:*:*', + "arn:aws:exceute-api:*:*:*/*"] + }] + } + + +def test_can_auto_resource_policy_with_cfn(sample_app): + swagger_gen = CFNSwaggerGenerator() + rest_api = RestAPI( + resource_name='dev', + swagger_doc={}, + lambda_function=None, + minimum_compression="", + api_gateway_stage="xyz", + endpoint_type="PRIVATE", + policy=IAMPolicy({ + 'Statement': [{ + "Effect": "Allow", + "Principal": "*", + "Action": "execute-api:Invoke", + "Resource": "arn:aws:execute-api:*:*:*/*", + "Condition": { + "StringEquals": { + "aws:SourceVpce": "vpce-abc123" + } + } + }] + }) + ) + + doc = swagger_gen.generate_swagger(sample_app, rest_api) + assert doc['x-amazon-apigateway-policy'] == { + 'Statement': [{ + 'Action': 'execute-api:Invoke', + 'Condition': {'StringEquals': { + 'aws:SourceVpce': 'vpce-abc123'}}, + 'Effect': 'Allow', + 'Principal': '*', + 'Resource': 'arn:aws:execute-api:*:*:*/*', + }] + } + + def test_will_custom_auth_with_cfn(sample_app): swagger_gen = CFNSwaggerGenerator() diff --git a/tests/unit/deploy/test_validate.py b/tests/unit/deploy/test_validate.py index 0e13c8c0a..8e9ab082f 100644 --- a/tests/unit/deploy/test_validate.py +++ b/tests/unit/deploy/test_validate.py @@ -12,6 +12,8 @@ from chalice.deploy.validate import validate_route_content_types from chalice.deploy.validate import validate_unique_function_names from chalice.deploy.validate import validate_feature_flags +from chalice.deploy.validate import validate_endpoint_type +from chalice.deploy.validate import validate_resource_policy from chalice.deploy.validate import ExperimentalFeatureError @@ -233,6 +235,51 @@ def index(): sample_app.api.binary_types) is None +def test_can_validate_resource_policy(sample_app): + config = Config.create( + chalice_app=sample_app, api_gateway_endpoint_type='PRIVATE') + with pytest.raises(ValueError): + validate_resource_policy(config) + + config = Config.create( + chalice_app=sample_app, + api_gateway_endpoint_vpce='vpce-abc123', + api_gateway_endpoint_type='PRIVATE') + validate_resource_policy(config) + + config = Config.create( + chalice_app=sample_app, + api_gateway_endpoint_vpce='vpce-abc123', + api_gateway_endpoint_type='REGIONAL') + with pytest.raises(ValueError): + validate_resource_policy(config) + + config = Config.create( + chalice_app=sample_app, + api_gateway_policy_file='xyz.json', + api_gateway_endpoint_type='PRIVATE') + validate_resource_policy(config) + + config = Config.create( + chalice_app=sample_app, + api_gateway_endpoint_vpce=['vpce-abc123', 'vpce-bdef'], + api_gateway_policy_file='bar.json', + api_gateway_endpoint_type='PRIVATE') + with pytest.raises(ValueError): + validate_resource_policy(config) + + +def test_can_validate_endpoint_type(sample_app): + config = Config.create( + chalice_app=sample_app, api_gateway_endpoint_type='EDGE2') + with pytest.raises(ValueError): + validate_endpoint_type(config) + + config = Config.create( + chalice_app=sample_app, api_gateway_endpoint_type='REGIONAL') + validate_endpoint_type(config) + + def test_can_validate_feature_flags(sample_app): # The _features_used is marked internal because we don't want # chalice users to access it, but this attribute is intended to be diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 5b3093959..6cfd433e6 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -12,7 +12,6 @@ from hypothesis import given, assume import six - from chalice import app from chalice import NotFoundError from chalice.app import ( @@ -21,6 +20,10 @@ Response, handle_extra_types, MultiDict, + WebsocketEvent, + BadRequestError, + WebsocketDisconnectedError, + WebsocketEventSourceHandler, ) from chalice import __version__ as chalice_version from chalice.deploy.validate import ExperimentalFeatureError @@ -78,6 +81,40 @@ def serialize(self): return serialized +class FakeGoneException(Exception): + pass + + +class FakeExceptionFactory(object): + def __init__(self): + self.GoneException = FakeGoneException + + +class FakeClient(object): + def __init__(self, errors=None): + if errors is None: + errors = [] + self._errors = errors + self.calls = [] + self.exceptions = FakeExceptionFactory() + + def post_to_connection(self, ConnectionId, Data): + self.calls.append((ConnectionId, Data)) + if self._errors: + error = self._errors.pop() + raise error + + +class FakeSession(object): + def __init__(self, client=None): + self.calls = [] + self._client = client + + def client(self, name, endpoint_url=None): + self.calls.append((name, endpoint_url)) + return self._client + + @pytest.fixture def view_function(): def _func(): @@ -115,6 +152,13 @@ def assert_requires_opt_in(app, flag): ) +def websocket_handler_for_route(route, app): + fn = app.websocket_handlers[route].handler_function + handler = WebsocketEventSourceHandler( + fn, WebsocketEvent, app.websocket_api) + return handler + + @fixture def sample_app(): demo = app.Chalice('demo-app') @@ -142,6 +186,31 @@ def image(): return demo +@fixture +def sample_websocket_app(): + demo = app.Chalice('app-name') + demo.websocket_api.session = FakeSession() + + calls = [] + + @demo.on_ws_connect() + def connect(event): + demo.websocket_api.send(event.connection_id, 'connected') + calls.append(('connect', event)) + + @demo.on_ws_disconnect() + def disconnect(event): + demo.websocket_api.send(event.connection_id, 'message') + calls.append(('disconnect', event)) + + @demo.on_ws_message() + def message(event): + demo.websocket_api.send(event.connection_id, 'disconnected') + calls.append(('default', event)) + + return demo, calls + + @fixture def auth_request(): method_arn = ( @@ -1918,13 +1987,16 @@ def test_multidict_raises_keyerror(input_dict): assert val is val -@pytest.mark.parametrize('input_dict', [ - {}, - {'key': []} -]) -def test_multidict_returns_emptylist(input_dict): - d = MultiDict(input_dict) - assert d.getlist('key') == [] +def test_multidict_pop_raises_del_error(): + d = MultiDict({}) + with pytest.raises(KeyError): + del d['key'] + + +def test_multidict_getlist_does_raise_keyerror(): + d = MultiDict({}) + with pytest.raises(KeyError): + d.getlist('key') @pytest.mark.parametrize('input_dict', [ @@ -1962,7 +2034,448 @@ def test_multidict_list_wont_change_source(input_dict): assert d.getlist('key') == dict_copy['key'] -def test_multidict_is_readonly(): - d = MultiDict(None) - with pytest.raises(TypeError): - d['key'] = 'value' +@pytest.mark.parametrize('input_dict,key,popped,leftover', [ + ( + {'key': ['value'], 'key2': [[]]}, + 'key', + 'value', + {'key2': []}, + ), + ( + {'key': [''], 'key2': [[]]}, + 'key', + '', + {'key2': []}, + ), + ( + {'key': ['value1', 'value2', 'value3'], + 'key2': [[]]}, + 'key', + 'value3', + {'key2': []}, + ), +]) +def test_multidict_list_can_pop_value(input_dict, key, popped, leftover): + d = MultiDict(input_dict) + pop_result = d.pop(key) + assert popped == pop_result + assert leftover == {key: d[key] for key in d} + + +def test_multidict_assignment(): + d = MultiDict({}) + d['key'] = 'value' + assert d['key'] == 'value' + + +def test_multidict_get_reassigned_value(): + d = MultiDict({}) + d['key'] = 'value' + assert d['key'] == 'value' + assert d.get('key') == 'value' + assert d.getlist('key') == ['value'] + + +def test_multidict_get_list_wraps_key(): + d = MultiDict({}) + d['key'] = ['value'] + assert d.getlist('key') == [['value']] + + +def test_multidict_repr(): + d = MultiDict({ + 'foo': ['bar', 'baz'], + 'buz': ['qux'], + }) + rep = repr(d) + assert rep.startswith('MultiDict({') + assert "'foo': ['bar', 'baz']" in rep + assert "'buz': ['qux']" in rep + + +def test_multidict_str(): + d = MultiDict({ + 'foo': ['bar', 'baz'], + 'buz': ['qux'], + }) + rep = str(d) + assert rep.startswith('MultiDict({') + assert "'foo': ['bar', 'baz']" in rep + assert "'buz': ['qux']" in rep + + +def test_can_configure_websockets(sample_websocket_app): + demo, _ = sample_websocket_app + + assert len(demo.websocket_handlers) == 3, demo.websocket_handlers + assert '$connect' in demo.websocket_handlers, demo.websocket_handlers + assert '$disconnect' in demo.websocket_handlers, demo.websocket_handlers + assert '$default' in demo.websocket_handlers, demo.websocket_handlers + + +def test_websocket_event_json_body_available(sample_websocket_app, + create_websocket_event): + demo = app.Chalice('demo-app') + called = {'wascalled': False} + + @demo.on_ws_message() + def message(event): + called['wascalled'] = True + assert event.json_body == {'foo': 'bar'} + # Second access hits the cache. Test that that works as well. + assert event.json_body == {'foo': 'bar'} + + event = create_websocket_event('$default', body='{"foo": "bar"}') + handler = websocket_handler_for_route('$default', demo) + + handler(event, context=None) + assert called['wascalled'] is True + + +def test_websocket_event_json_body_can_raise_error(sample_websocket_app, + create_websocket_event): + demo = app.Chalice('demo-app') + called = {'wascalled': False} + + @demo.on_ws_message() + def message(event): + called['wascalled'] = True + with pytest.raises(BadRequestError): + event.json_body + + event = create_websocket_event('$default', body='{"foo": "bar"') + handler = websocket_handler_for_route('$default', demo) + + handler(event, context=None) + assert called['wascalled'] is True + + +def test_can_route_websocket_connect_message(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$connect') + handler = websocket_handler_for_route('$connect', demo) + response = handler(event, context=None) + + assert response == {'statusCode': 200} + assert len(calls) == 1 + assert calls[0][0] == 'connect' + event = calls[0][1] + assert isinstance(event, WebsocketEvent) + assert event.domain_name == 'abcd1234.us-west-2.amazonaws.com' + assert event.stage == 'api' + assert event.connection_id == 'ABCD1234=' + + +def test_can_route_websocket_disconnect_message(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$disconnect') + handler = websocket_handler_for_route('$disconnect', demo) + response = handler(event, context=None) + + assert response == {'statusCode': 200} + assert len(calls) == 1 + assert calls[0][0] == 'disconnect' + event = calls[0][1] + assert isinstance(event, WebsocketEvent) + assert event.domain_name == 'abcd1234.us-west-2.amazonaws.com' + assert event.stage == 'api' + assert event.connection_id == 'ABCD1234=' + + +def test_can_route_websocket_default_message(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + response = handler(event, context=None) + + assert response == {'statusCode': 200} + assert len(calls) == 1 + assert calls[0][0] == 'default' + event = calls[0][1] + assert isinstance(event, WebsocketEvent) + assert event.domain_name == 'abcd1234.us-west-2.amazonaws.com' + assert event.stage == 'api' + assert event.connection_id == 'ABCD1234=' + assert event.body == 'foo bar' + + +def test_can_configure_client_on_connect(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$connect') + handler = websocket_handler_for_route('$connect', demo) + handler(event, context=None) + + assert demo.websocket_api.session.calls == [ + ('apigatewaymanagementapi', + 'https://abcd1234.us-west-2.amazonaws.com/api'), + ] + + +def test_can_configure_client_on_disconnect(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$disconnect') + handler = websocket_handler_for_route('$disconnect', demo) + handler(event, context=None) + + assert demo.websocket_api.session.calls == [ + ('apigatewaymanagementapi', + 'https://abcd1234.us-west-2.amazonaws.com/api'), + ] + + +def test_can_configure_client_on_message(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + + handler(event, context=None) + + assert demo.websocket_api.session.calls == [ + ('apigatewaymanagementapi', + 'https://abcd1234.us-west-2.amazonaws.com/api'), + ] + + +def test_does_only_configure_client_once(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + + handler(event, context=None) + handler(event, context=None) + + assert demo.websocket_api.session.calls == [ + ('apigatewaymanagementapi', + 'https://abcd1234.us-west-2.amazonaws.com/api'), + ] + + +def test_cannot_configure_client_without_session(sample_websocket_app, + create_websocket_event): + demo, calls = sample_websocket_app + demo.websocket_api.session = None + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + with pytest.raises(ValueError) as e: + handler(event, context=None) + + assert str(e.value) == ( + 'Assign app.websocket_api.session to a boto3 session before using ' + 'the WebsocketAPI' + ) + + +def test_cannot_send_websocket_message_without_configure( + sample_websocket_app, create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message_handler(event): + demo.websocket_api.send('connection_id', event.body) + + event = create_websocket_event('$default', body='foo bar') + event_obj = WebsocketEvent(event, None) + handler = demo.websocket_handlers['$default'].handler_function + with pytest.raises(ValueError) as e: + handler(event_obj) + assert str(e.value) == ( + 'WebsocketAPI.configure must be called before using the WebsocketAPI' + ) + + +def test_can_send_websocket_message(create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message_handler(event): + demo.websocket_api.send('connection_id', event.body) + + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + handler(event, context=None) + + assert len(client.calls) == 1 + call = client.calls[0] + connection_id, message = call + assert connection_id == 'connection_id' + assert message == 'foo bar' + + +def test_does_raise_on_send_to_bad_websocket(create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient(errors=[FakeGoneException]) + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message_handler(event): + with pytest.raises(WebsocketDisconnectedError) as e: + demo.websocket_api.send('connection_id', event.body) + assert e.value.connection_id == 'connection_id' + + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + handler(event, context=None) + + +def test_does_reraise_on_websocket_send_error(create_websocket_event): + class SomeOtherError(Exception): + pass + + demo = app.Chalice('app-name') + fake_418_error = SomeOtherError() + fake_418_error.response = {'ResponseMetadata': {'HTTPStatusCode': 418}} + client = FakeClient(errors=[fake_418_error]) + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message_handler(event): + with pytest.raises(SomeOtherError): + demo.websocket_api.send('connection_id', event.body) + + event = create_websocket_event('$default', body='foo bar') + handler = websocket_handler_for_route('$default', demo) + handler(event, context=None) + + +def test_does_reraise_on_other_send_exception(create_websocket_event): + demo = app.Chalice('app-name') + fake_500_error = Exception() + fake_500_error.response = {'ResponseMetadata': {'HTTPStatusCode': 500}} + fake_500_error.key = 'foo' + client = FakeClient(errors=[fake_500_error]) + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message_handler(event): + with pytest.raises(Exception) as e: + demo.websocket_api.send('connection_id', event.body) + assert e.value.key == 'foo' + + event = create_websocket_event('$default', body='foo bar') + demo(event, context=None) + + +def test_cannot_send_message_on_unconfigured_app(): + demo = app.Chalice('app-name') + demo.websocket_api.session = None + + with pytest.raises(ValueError) as e: + demo.websocket_api.send('connection_id', 'body') + + assert str(e.value) == ( + 'Assign app.websocket_api.session to a boto3 session before ' + 'using the WebsocketAPI' + ) + + +def test_cannot_re_register_websocket_handlers(create_websocket_event): + demo = app.Chalice('app-name') + + @demo.on_ws_message() + def message_handler(event): + pass + + with pytest.raises(ValueError) as e: + @demo.on_ws_message() + def message_handler_2(event): + pass + + assert str(e.value) == ( + "Duplicate websocket handler: 'on_ws_message'. There can only be one " + "handler for each websocket decorator." + ) + + @demo.on_ws_connect() + def connect_handler(event): + pass + + with pytest.raises(ValueError) as e: + @demo.on_ws_connect() + def conncet_handler_2(event): + pass + + assert str(e.value) == ( + "Duplicate websocket handler: 'on_ws_connect'. There can only be one " + "handler for each websocket decorator." + ) + + @demo.on_ws_disconnect() + def disconnect_handler(event): + pass + + with pytest.raises(ValueError) as e: + @demo.on_ws_disconnect() + def disconncet_handler_2(event): + pass + + assert str(e.value) == ( + "Duplicate websocket handler: 'on_ws_disconnect'. There can only be " + "one handler for each websocket decorator." + ) + + +def test_can_parse_json_websocket_body(create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message(event): + assert event.json_body == {'foo': 'bar'} + + event = create_websocket_event('$default', body='{"foo": "bar"}') + demo(event, context=None) + + +def test_can_access_websocket_json_body_twice(create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message(event): + assert event.json_body == {'foo': 'bar'} + assert event.json_body == {'foo': 'bar'} + + event = create_websocket_event('$default', body='{"foo": "bar"}') + demo(event, context=None) + + +def test_does_raise_on_invalid_json_wbsocket_body(create_websocket_event): + demo = app.Chalice('app-name') + client = FakeClient() + demo.websocket_api.session = FakeSession(client) + + @demo.on_ws_message() + def message(event): + with pytest.raises(BadRequestError) as e: + event.json_body + assert 'Error Parsing JSON' in str(e.value) + + event = create_websocket_event('$default', body='foo bar') + demo(event, context=None) diff --git a/tests/unit/test_package.py b/tests/unit/test_package.py index 3172cb2e2..a0fa7a655 100644 --- a/tests/unit/test_package.py +++ b/tests/unit/test_package.py @@ -632,6 +632,168 @@ def test_can_generate_rest_api(self, sample_app_with_auth): 'RestAPIId': {'Value': {'Ref': 'RestAPI'}} } + @pytest.mark.parametrize('route_key,route', [ + ('$default', 'WebsocketMessageRoute'), + ('$connect', 'WebsocketConnectRoute'), + ('$disconnect', 'WebsocketDisconnectRoute')] + ) + def test_generate_partial_websocket_api( + self, route_key, route, sample_websocket_app): + # Remove all but one websocket route. + sample_websocket_app.websocket_handlers = { + name: handler for name, handler in + sample_websocket_app.websocket_handlers.items() + if name == route_key + } + config = Config.create(chalice_app=sample_websocket_app, + project_dir='.', + api_gateway_stage='api') + template = self.generate_template(config, 'dev') + resources = template['Resources'] + + # Check that the template's deployment only depends on the one route. + depends_on = resources['WebsocketAPIDeployment'].pop('DependsOn') + assert [route] == depends_on + + def test_generate_websocket_api(self, sample_websocket_app): + config = Config.create(chalice_app=sample_websocket_app, + project_dir='.', + api_gateway_stage='api') + template = self.generate_template(config, 'dev') + resources = template['Resources'] + + assert resources['WebsocketAPI']['Type'] == 'AWS::ApiGatewayV2::Api' + + for handler, route in (('WebsocketConnect', '$connect'), + ('WebsocketMessage', '$default'), + ('WebsocketDisconnect', '$disconnect'),): + # Lambda function should be created. + assert resources[handler][ + 'Type'] == 'AWS::Serverless::Function' + + # Along with permission to invoke from API Gateway. + assert resources['%sInvokePermission' % handler] == { + 'Type': 'AWS::Lambda::Permission', + 'Properties': { + 'Action': 'lambda:InvokeFunction', + 'FunctionName': {'Ref': handler}, + 'Principal': 'apigateway.amazonaws.com', + 'SourceArn': { + 'Fn::Sub': [ + ( + 'arn:aws:execute-api:${AWS::Region}:${AWS::' + 'AccountId}:${WebsocketAPIId}/*' + ), + {'WebsocketAPIId': {'Ref': 'WebsocketAPI'}}]}}, + } + + # Ensure Integration is created. + assert resources['%sAPIIntegration' % handler] == { + 'Type': 'AWS::ApiGatewayV2::Integration', + 'Properties': { + 'ApiId': { + 'Ref': 'WebsocketAPI' + }, + 'ConnectionType': 'INTERNET', + 'ContentHandlingStrategy': 'CONVERT_TO_TEXT', + 'IntegrationType': 'AWS_PROXY', + 'IntegrationUri': { + 'Fn::Sub': [ + ( + 'arn:aws:apigateway:${AWS::Region}:lambda:path' + '/2015-03-31/functions/arn:aws:lambda:' + '${AWS::Region}:' '${AWS::AccountId}:function:' + '${WebsocketHandler}/invocations' + ), + {'WebsocketHandler': {'Ref': handler}} + ], + } + } + } + + # Route for the handler. + assert resources['%sRoute' % handler] == { + 'Type': 'AWS::ApiGatewayV2::Route', + 'Properties': { + 'ApiId': { + 'Ref': 'WebsocketAPI' + }, + 'RouteKey': route, + 'Target': { + 'Fn::Join': [ + '/', + [ + 'integrations', + {'Ref': '%sAPIIntegration' % handler}, + ] + ] + } + } + } + + # Ensure the deployment is created. It must manually depend on + # the routes since it cannot be created for WebsocketAPI that has no + # routes. The API has no such implicit contract so CloudFormation can + # deploy things out of order without the explicit DependsOn. + depends_on = set(resources['WebsocketAPIDeployment'].pop('DependsOn')) + assert set(['WebsocketConnectRoute', + 'WebsocketMessageRoute', + 'WebsocketDisconnectRoute']) == depends_on + assert resources['WebsocketAPIDeployment'] == { + 'Type': 'AWS::ApiGatewayV2::Deployment', + 'Properties': { + 'ApiId': { + 'Ref': 'WebsocketAPI' + } + } + } + + # Ensure the stage is created. + resources['WebsocketAPIStage'] = { + 'Type': 'AWS::ApiGatewayV2::Stage', + 'Properties': { + 'ApiId': { + 'Ref': 'WebsocketAPI' + }, + 'DeploymentId': {'Ref': 'WebsocketAPIDeployment'}, + 'StageName': 'api', + } + } + + # Ensure the outputs are created + assert template['Outputs'] == { + 'WebsocketConnectHandlerArn': { + 'Value': { + 'Fn::GetAtt': ['WebsocketConnect', 'Arn'] + } + }, + 'WebsocketConnectHandlerName': {'Value': {'Ref': + 'WebsocketConnect'}}, + 'WebsocketMessageHandlerArn': { + 'Value': { + 'Fn::GetAtt': ['WebsocketMessage', 'Arn'] + } + }, + 'WebsocketMessageHandlerName': {'Value': {'Ref': + 'WebsocketMessage'}}, + 'WebsocketDisconnectHandlerArn': { + 'Value': { + 'Fn::GetAtt': ['WebsocketDisconnect', 'Arn'] + } + }, + 'WebsocketDisconnectHandlerName': {'Value': { + 'Ref': 'WebsocketDisconnect'}}, + 'WebsocketConnectEndpointURL': { + 'Value': { + 'Fn::Sub': ( + 'wss://${WebsocketAPI}.execute-api.' + '${AWS::Region}.amazonaws.com/api/' + ) + } + }, + 'WebsocketAPIId': {'Value': {'Ref': 'WebsocketAPI'}} + } + def test_managed_iam_role(self): role = models.ManagedIAMRole( resource_name='default_role',