diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 39c99ed..30052ea 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,16 +7,14 @@ assignees: '' --- If you have an active AWS support contract, please open a case with AWS Premium Support team using the below documentation to report the issue: - https://docs.aws.amazon.com/awssupport/latest/user/case-management.html -Before submitting a new issue, please search through open GitHub Issues (https://github.com/aws/res/issues) and check out the troubleshooting documentation. - Please make sure to add the following data in order to facilitate the root cause detection. + **Describe the bug** -A clear and concise description of what the bug is. +A clear and concise description of what the bug is. Before submitting a new issue, please search through open [GitHub Issues](https://github.com/aws/res/issues) and check out the [troubleshooting documentation](https://github.com/aws/res/wiki/Troubleshooting). -**To Reproduce** +**Steps to reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' @@ -26,13 +24,18 @@ Steps to reproduce the behavior: **Expected behavior** A clear and concise description of what you expected to happen. -**Screenshots** -If applicable, add screenshots to help explain your problem. +**Actual behavior** +A clear and concise description of what actually happened. -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] - - Version [e.g. 22] +**Screenshots/Video** +If applicable, add screenshots and/or a video to help explain your problem. + +**Environment (please complete the following information):** + - RES Version: [e.g. 2023.11] + - Software Stack AMI ID: [e.g. ami-0fceec18b58bfda68] + - Software Stack OS: [e.g. 
Windows, Amazon Linux 2, CentOS 7, RedHat Enterprise Linux 7] **Additional context** -Add any other context about the problem here. \ No newline at end of file +Add any other context about the problem here. + + diff --git a/CHANGELOG.md b/CHANGELOG.md index 412f90e..1c39ace 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,30 @@ # Change Log -All notable changes to this project will be documented in this file. +This file is used to list changes made in each release of Research and Engineering Studio (RES). -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +2024.01 +------ + +**ENHANCEMENTS** + +- Add support for snapshots that enable migration between versions of RES. + - The migration process involves taking a snapshot, deploying the new version (e.g. 2024.01), and applying the snapshot from the previous version (e.g. 2023.11) on the new version. This allows admins to confirm that the new version works before transferring users. +- Add support for private subnets. + - Enable deployments of RES infrastructure hosts in private subnets with internet access. + - RES infrastructure hosts refer to the Cluster Manager, Bastion Host, VDC Gateway, VDC Controller, and VDC Broker. +- Deprecation of the analytics stack. + - Removed required OpenSearch dependency. + - Reduces RES environment deployment and deletion time by approximately 30 minutes. +- Add support for use of a custom Amazon Linux 2 EC2 AMI for RES infrastructure hosts. + - Enable specifying an AL2 EC2 AMI to use for RES infrastructure hosts, for users that require specific software or updates installed on their hosts. + - RES infrastructure hosts refer to the Cluster Manager, Bastion Host, VDC Gateway, VDC Controller, and VDC Broker. +- Add support for ldap_id_mapping “True” in SSSD.
+ - Previously the AD sync code required groups and users to have the POSIX attributes uidNumber and gidNumber in order to sync with RES. This conflicted with the IDs generated by SSSD, potentially preventing RES users from accessing filesystems if they were using SSSD with other systems (e.g. ParallelCluster). +- Add support for four new regions: Asia Pacific (Tokyo), Asia Pacific (Seoul), Canada (Central), and Europe (Milan). +- Add ability to add users to projects. Previously only groups could be added to project permissions. + +**BUG FIXES** + +- Added validation for FSx ONTAP filesystem creation +- Narrowed installation IAM permissions +- Skipped deletion of batteries-included related resources +- VDI no longer tries to mount filesystems after a filesystem is removed from a project \ No newline at end of file diff --git a/README.md b/README.md index 8756a16..65e2d91 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ If you have a [support plan](https://aws.amazon.com/premiumsupport/) with AWS Su You can also [open an issue](https://github.com/aws/res/issues/new/choose) and choose from one of our templates for guidance, bug reports, or feature requests. Please check for open [similar issues](https://github.com/aws/res/issues) before opening another one.
-##More Resources +## More Resources * [Changelog](https://github.com/aws/res/blob/mainline/CHANGELOG.md) * [Amazon Web Services Discussion Forums](https://repost.aws/) diff --git a/RES_VERSION.txt b/RES_VERSION.txt index 38160ca..0969a86 100644 --- a/RES_VERSION.txt +++ b/RES_VERSION.txt @@ -1 +1 @@ -2023.11 +2024.01 diff --git a/codeCoverage.yml b/codeCoverage.yml index cf50c82..9543164 100644 --- a/codeCoverage.yml +++ b/codeCoverage.yml @@ -5,7 +5,6 @@ source/idea/idea-data-model/src/ideadatamodel/aws: 0.8 source/idea/idea-data-model/src/ideadatamodel/cluster_resources: 0.8 source/idea/idea-data-model/src/ideadatamodel/common: 0.8 source/idea/idea-data-model/src/ideadatamodel: 0.8 -source/idea/idea-sdk/src/ideasdk/analytics: 0.8 source/idea/idea-sdk/src/ideasdk/auth: 0.8 source/idea/idea-sdk/src/ideasdk/aws: 0.8 source/idea/idea-sdk/src/ideasdk/utils: 0.8 diff --git a/deployment/ecr/idea-administrator/Dockerfile b/deployment/ecr/idea-administrator/Dockerfile deleted file mode 100644 index 6f40aa7..0000000 --- a/deployment/ecr/idea-administrator/Dockerfile +++ /dev/null @@ -1,55 +0,0 @@ -FROM public.ecr.aws/docker/library/python:3.9.16-slim - -WORKDIR /root - -RUN apt-get update && \ - apt-get -y install \ - curl \ - tar \ - unzip \ - locales \ - && apt-get clean - - -ENV DEBIAN_FRONTEND=noninteractive -ENV LC_ALL="en_US.UTF-8" \ - LC_CTYPE="en_US.UTF-8" \ - LANG="en_US.UTF-8" - -RUN sed -i -e "s/# $LANG.*/$LANG UTF-8/" /etc/locale.gen \ - && locale-gen "en_US.UTF-8" \ - && dpkg-reconfigure locales - -# install aws cli -RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \ - unzip -qq awscliv2.zip && \ - ./aws/install && \ - rm -rf ./aws awscliv2.zip - -# install nvm and node -RUN set -uex && \ - apt-get update && \ - apt-get install -y ca-certificates curl gnupg && \ - mkdir -p /etc/apt/keyrings && \ - curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key \ - | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg 
&& \ - NODE_MAJOR=18 && \ - echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_$NODE_MAJOR.x nodistro main" \ - > /etc/apt/sources.list.d/nodesource.list && \ - apt-get update && \ - apt-get install nodejs -y - -# add all packaged artifacts to container -ARG PUBLIC_ECR_TAG -ENV PUBLIC_ECR_TAG=${PUBLIC_ECR_TAG} -ADD all-*.tar.gz cfn_params_2_values.sh /root/.idea/downloads/ - -# install administrator app -RUN mkdir -p /root/.idea/downloads/idea-administrator-${PUBLIC_ECR_TAG} && \ - tar -xvf /root/.idea/downloads/idea-administrator-*.tar.gz -C /root/.idea/downloads/idea-administrator-${PUBLIC_ECR_TAG} && \ - /bin/bash /root/.idea/downloads/idea-administrator-${PUBLIC_ECR_TAG}/install.sh && \ - rm -rf /root/.idea/downloads/idea-administrator-${PUBLIC_ECR_TAG} - -CMD ["bash"] - - diff --git a/deployment/ecr/idea-administrator/cfn_params_2_values.sh b/deployment/ecr/idea-administrator/cfn_params_2_values.sh deleted file mode 100755 index 90f7e86..0000000 --- a/deployment/ecr/idea-administrator/cfn_params_2_values.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash - -aws_partition=${1} -aws_region=${2} -aws_account_id=${3} -aws_dns_suffix=${4} -cluster_name=${5} -administrator_email=${6} -ssh_key_pair_name=${7} -client_ip1=${8} -client_ip2=${9} -vpc_id=${10} -pub_subnets=${11} -pvt_subnets=${12} -storage_home_provider=${13} -home_fs_id=${14} - -values_file="/root/.idea/clusters/${5}/${2}/values.yml" - -prt_subnets(){ - for sn in $(echo $1| tr ',' ' ') - do - echo "- ${sn}" - done -} - -dir_name=$(dirname ${values_file}) - -mkdir -p ${dir_name} - -rm -f ${values_file} -cat << EOF1 > ${values_file} -_regenerate: false -aws_partition: ${aws_partition} -aws_region: ${aws_region} -aws_account_id: ${aws_account_id} -aws_dns_suffix: ${aws_dns_suffix} -cluster_name: ${cluster_name} -administrator_email: ${administrator_email} -ssh_key_pair_name: ${ssh_key_pair_name} -client_ip: -- ${client_ip1} -- ${client_ip2} -alb_public: true 
-use_vpc_endpoints: true -directory_service_provider: aws_managed_activedirectory -enable_aws_backup: true -kms_key_type: aws-managed -use_existing_vpc: true -vpc_id: ${vpc_id} -existing_resources: -- subnets:public -- subnets:private -- shared-storage:home -public_subnet_ids: -EOF1 -prt_subnets ${pub_subnets} >> ${values_file} -cat << EOF2 >> ${values_file} -private_subnet_ids: -EOF2 -prt_subnets ${pvt_subnets} >> ${values_file} -cat << EOF3 >> ${values_file} -storage_home_provider: ${storage_home_provider} -use_existing_home_fs: true -existing_home_fs_id: ${home_fs_id} -enabled_modules: -- metrics -- virtual-desktop-controller -- bastion-host -metrics_provider: cloudwatch -base_os: amazonlinux2 -instance_type: m5.large -volume_size: '200' -EOF3 diff --git a/pyproject.toml b/pyproject.toml index fbe76a4..ae0b1e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,12 +39,13 @@ cdk-app = "idea.app:main" [project.optional-dependencies] dev = [ - "black", + "black~=24.1.0", "tox", "pytest", "pytest-cov", "boto3-stubs-lite[essential]", "boto3-stubs-lite[cloudformation]", + "types-requests", ] [tool.black] diff --git a/requirements/dev.txt b/requirements/dev.txt index 50155d9..2e2c73c 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,4 +1,4 @@ -aiofiles==0.8.0 +aiofiles==23.2.1 alembic==1.7.7 arrow==1.2.1 astroid==2.12.11 @@ -33,10 +33,10 @@ dill==0.3.5.1 exceptiongroup==1.0.0rc8 fastcounter==1.1.0 ghp-import==2.1.0 -greenlet==1.1.2 +greenlet>=1.1.2 httptools==0.4.0 idna==3.3 -importlib-metadata==4.11.3 +importlib-metadata>=4.11.3 iniconfig==1.1.1 invoke==1.7.1 ipaddress==1.0.23 @@ -59,7 +59,6 @@ multidict==6.0.2 mypy==0.950 mypy-extensions==0.4.3 openapi-schema-pydantic==1.2.4 -opensearch-py==2.3.1 orjson==3.6.5 packaging==21.3 pep517==0.12.0 @@ -102,6 +101,7 @@ rich==12.4.1 s3transfer==0.6.0 sanic==23.6.0 sanic-routing==23.6.0 +selenium==4.16.0 semver==2.13.0 sh==1.14.2 shortuuid==1.0.9 diff --git a/requirements/doc.txt b/requirements/doc.txt 
index 5cbc3e6..d4d8b1e 100644 --- a/requirements/doc.txt +++ b/requirements/doc.txt @@ -1,6 +1,6 @@ click==8.1.3 ghp-import==2.1.0 -importlib-metadata==4.11.3 +importlib-metadata>=4.11.3 jinja2==3.1.2 markdown==3.3.7 markupsafe==2.1.1 diff --git a/requirements/idea-administrator.txt b/requirements/idea-administrator.txt index d1494a3..f10fb34 100644 --- a/requirements/idea-administrator.txt +++ b/requirements/idea-administrator.txt @@ -1,4 +1,4 @@ -aiofiles==0.8.0 +aiofiles==23.2.1 alembic==1.8.0 arrow==1.2.1 attrs==21.4.0 @@ -27,7 +27,7 @@ decorator==5.1.1 defusedxml==0.7.1 exceptiongroup==1.0.0rc8 fastcounter==1.1.0 -greenlet==1.1.2 +greenlet>=1.1.2 httptools==0.4.0 idna==3.3 ipaddress==1.0.23 @@ -39,7 +39,6 @@ markupsafe==2.1.1 multidict==6.0.2 mypy==0.950 mypy-extensions==0.4.3 -opensearch-py==2.3.1 orjson==3.6.5 prettytable==3.3.0 prometheus-client==0.14.1 diff --git a/requirements/idea-cluster-manager.txt b/requirements/idea-cluster-manager.txt index 1c220e0..050e654 100644 --- a/requirements/idea-cluster-manager.txt +++ b/requirements/idea-cluster-manager.txt @@ -1,4 +1,4 @@ -aiofiles==0.8.0 +aiofiles==23.2.1 alembic==1.8.0 arrow==1.2.1 banal==1.0.6 @@ -16,7 +16,7 @@ cryptography==41.0.4 dataset==1.5.2 decorator==5.1.1 fastcounter==1.1.0 -greenlet==1.1.2 +greenlet>=1.1.2 httptools==0.4.0 idna==3.3 jinja2==3.1.2 @@ -27,7 +27,6 @@ markupsafe==2.1.1 multidict==6.0.2 mypy==0.950 mypy-extensions==0.4.3 -opensearch-py==2.3.1 orjson==3.6.5 prettytable==3.3.0 prometheus-client==0.14.1 diff --git a/requirements/idea-dev-lambda.in b/requirements/idea-dev-lambda.in index 98d751a..0d38bc5 100644 --- a/requirements/idea-dev-lambda.in +++ b/requirements/idea-dev-lambda.in @@ -1,2 +1 @@ -opensearch-py cryptography diff --git a/requirements/idea-dev-lambda.txt b/requirements/idea-dev-lambda.txt index 11b6fc3..45c7625 100644 --- a/requirements/idea-dev-lambda.txt +++ b/requirements/idea-dev-lambda.txt @@ -3,7 +3,6 @@ cffi==1.15.1 charset-normalizer==2.1.1 
cryptography==41.0.4 idna==3.4 -opensearch-py==2.3.1 pycparser==2.21 requests==2.31.0 urllib3==1.26.18 diff --git a/requirements/idea-scheduler.txt b/requirements/idea-scheduler.txt index 009173c..ea70a58 100644 --- a/requirements/idea-scheduler.txt +++ b/requirements/idea-scheduler.txt @@ -1,4 +1,4 @@ -aiofiles==0.8.0 +aiofiles==23.2.1 alembic==1.7.7 arrow==1.2.1 banal==1.0.6 @@ -17,7 +17,7 @@ cryptography==41.0.4 dataset==1.5.2 decorator==5.1.1 fastcounter==1.1.0 -greenlet==1.1.2 +greenlet>=1.1.2 httptools==0.4.0 idna==3.3 jinja2==3.1.2 @@ -27,7 +27,6 @@ markupsafe==2.1.1 multidict==6.0.2 mypy==0.950 mypy-extensions==0.4.3 -opensearch-py==2.3.1 orjson==3.6.5 prettytable==3.3.0 prometheus-client==0.14.1 diff --git a/requirements/idea-sdk.in b/requirements/idea-sdk.in index 8b49317..b190515 100644 --- a/requirements/idea-sdk.in +++ b/requirements/idea-sdk.in @@ -23,7 +23,6 @@ fastcounter psutil blinker troposphere -opensearch-py Jinja2 cryptography PyJWT diff --git a/requirements/idea-virtual-desktop-controller.txt b/requirements/idea-virtual-desktop-controller.txt index 9a7b0e7..f229643 100644 --- a/requirements/idea-virtual-desktop-controller.txt +++ b/requirements/idea-virtual-desktop-controller.txt @@ -1,4 +1,4 @@ -aiofiles==0.8.0 +aiofiles==23.2.1 alembic==1.8.0 arrow==1.2.1 banal==1.0.6 @@ -16,7 +16,7 @@ cryptography==41.0.4 dataset==1.5.2 decorator==5.1.1 fastcounter==1.1.0 -greenlet==1.1.2 +greenlet>=1.1.2 httptools==0.4.0 idna==3.3 jinja2==3.1.2 @@ -26,7 +26,6 @@ markupsafe==2.1.1 multidict==6.0.2 mypy==0.950 mypy-extensions==0.4.3 -opensearch-py==2.3.1 orjson==3.6.5 prometheus-client==0.14.1 prompt-toolkit==3.0.29 diff --git a/requirements/tests.in b/requirements/tests.in index c13095c..3ea2b36 100644 --- a/requirements/tests.in +++ b/requirements/tests.in @@ -3,3 +3,4 @@ pytest-mock pytest-cov memory_profiler defusedxml +selenium diff --git a/requirements/tests.txt b/requirements/tests.txt index 2e05124..602a21b 100644 --- a/requirements/tests.txt +++ 
b/requirements/tests.txt @@ -11,3 +11,4 @@ pytest==7.2.0 pytest-cov==4.0.0 pytest-mock==3.10.0 tomli==2.0.1 +selenium==4.16.0 diff --git a/source/idea/README.md b/source/idea/README.md index 47982b4..3d745d3 100644 --- a/source/idea/README.md +++ b/source/idea/README.md @@ -45,3 +45,16 @@ Contains various utilities needed for tests. Contains everything needed for the vdc module (server set-up, api, etc). +## regional_pipeline_deployment.sh + +Bash script to deploy the BI+RES stack to regional accounts. It takes the path to a JSON file containing account configs and a local .cdk.json file, iterates over them, and deploys the stack. The script updates the PortalDomainName, CustomDomainNameforWebApp, and CustomDomainNameforVDI values based on the region of the account before deploying the stack. The [json file](https://code.amazon.com/packages/RESRegionalTestingConfig/blobs/b888adf06fce77066496a48b92cc12b2d1b4ad66/--/configuration/regional_account_config.json#L1) is stored in the RESRegionalTestingConfig package. Make sure you have `jq` and `ada` installed locally. Run `mwinit` before running the script so that it can fetch credentials for each account and deploy the stack. + +Steps to run the script: + +* Install `jq` and `ada` if you don’t have them already. `ada` can be installed by running `toolbox install ada` +* Run `mwinit --aea` +* Run the following command to trigger the deployment.
Pass the path to the config file and cdk json file to the script +``` +./regional_pipeline_deployment.sh +``` + diff --git a/source/idea/app.py b/source/idea/app.py index 42fadad..ccff236 100644 --- a/source/idea/app.py +++ b/source/idea/app.py @@ -4,8 +4,11 @@ import aws_cdk as cdk from cdk_bootstrapless_synthesizer import BootstraplessStackSynthesizer +from idea.batteries_included.parameters.parameters import BIParameters +from idea.batteries_included.stack import BiStack from idea.constants import ( ARTIFACTS_BUCKET_PREFIX_NAME, + BATTERIES_INCLUDED_STACK_NAME, INSTALL_STACK_NAME, PIPELINE_STACK_NAME, ) @@ -46,5 +49,14 @@ def main() -> None: registry_name=registry_name, synthesizer=install_synthesizer, ) + context_bi = app.node.try_get_context("batteries_included") + if context_bi and context_bi.lower() == "true": + bi_stack_template_url = app.node.try_get_context("BIStackTemplateURL") + BiStack( + app, + BATTERIES_INCLUDED_STACK_NAME, + template_url=bi_stack_template_url, + parameters=BIParameters.from_context(app), + ) app.synth() diff --git a/source/idea/batteries_included/__init__.py b/source/idea/batteries_included/__init__.py new file mode 100644 index 0000000..7b229b1 --- /dev/null +++ b/source/idea/batteries_included/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +__name__ = "idea.batteries_included" +__version__ = "2023.12" diff --git a/source/idea/batteries_included/parameters/__init__.py b/source/idea/batteries_included/parameters/__init__.py new file mode 100644 index 0000000..0bae1b8 --- /dev/null +++ b/source/idea/batteries_included/parameters/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +__name__ = "idea.batteries_included.parameters" +__version__ = "2023.12" diff --git a/source/idea/batteries_included/parameters/common.py b/source/idea/batteries_included/parameters/common.py new file mode 100644 index 0000000..da640e9 --- /dev/null +++ b/source/idea/batteries_included/parameters/common.py @@ -0,0 +1,167 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any + +from idea.infrastructure.install.parameters.base import Attributes, Base, Key + + +class CommonKey(Key): + CLUSTER_NAME = "EnvironmentName" + ADMIN_EMAIL = "AdministratorEmail" + INFRASTRUCTURE_HOST_AMI = "InfrastructureHostAMI" + SSH_KEY_PAIR = "SSHKeyPair" + CLIENT_IP = "ClientIp" + CLIENT_PREFIX_LIST = "ClientPrefixList" + VPC_ID = "VpcId" + LOAD_BALANCER_SUBNETS = "LoadBalancerSubnets" + INFRASTRUCTURE_HOST_SUBNETS = "InfrastructureHostSubnets" + VDI_SUBNETS = "VdiSubnets" + IS_LOAD_BALANCER_INTERNET_FACING = "IsLoadBalancerInternetFacing" + + +@dataclass +class CommonParameters(Base): + ssh_key_pair_name: str = Base.parameter( + Attributes( + id=CommonKey.SSH_KEY_PAIR, + type="AWS::EC2::KeyPair::KeyName", + description=( + "Default SSH keys, registered in EC2 that can be used to " + "SSH into environment instances." + ), + allowed_pattern=".+", + ) + ) + + client_ip: str = Base.parameter( + Attributes( + id=CommonKey.CLIENT_IP, + type="String", + description=( + "Default IP(s) allowed to directly access the Web UI and SSH " + "into the bastion host. We recommend that you restrict it with " + "your own IP/subnet (x.x.x.x/32 for your own ip or x.x.x.x/24 " + "for range. Replace x.x.x.x with your own PUBLIC IP. You can get " + "your public IP using tools such as https://ifconfig.co/)." + ), + allowed_pattern="(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})/(\d{1,2})", + constraint_description=( + "ClientIP must be a valid IP or network range of the form x.x.x.x/x. " + "specify your IP/NETMASK (e.g x.x.x/32 or x.x.x.x/24 for subnet range)" + ), + ) + ) + + client_prefix_list: str = Base.parameter( + Attributes( + id=CommonKey.CLIENT_PREFIX_LIST, + type="String", + description=( + "(Optional) A prefix list that covers IPs allowed to directly access the Web UI and SSH " + "into the bastion host." 
+ ), + allowed_pattern="^(pl-[a-z0-9]{8,20})?$", + constraint_description=( + "Must be a valid prefix list ID, which starts with 'pl-'. These can be " + "found either by navigating to the VPC console, or by calling ec2:DescribePrefixLists" + ), + ) + ) + + cluster_name: str = Base.parameter( + Attributes( + id=CommonKey.CLUSTER_NAME, + type="String", + description='Provide name of the Environment, the name of the environment must start with "res-" and should be less than or equal to 11 characters.', + allowed_pattern="res-[A-Za-z\-\_0-9]{0,7}", + constraint_description='The name of the environment must start with "res-" and should be less than or equal to 11 characters.', + ) + ) + + administrator_email: str = Base.parameter( + Attributes( + id=CommonKey.ADMIN_EMAIL, + allowed_pattern=r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$|^$)", + constraint_description="AdministratorEmail must be a valid email id", + ) + ) + + infrastructure_host_ami: str = Base.parameter( + Attributes( + id=CommonKey.INFRASTRUCTURE_HOST_AMI, + type="String", + allowed_pattern="^(ami-[0-9a-f]{8,17})?$", + description="(Optional) You may provide a custom AMI id to use for all the infrastructure hosts. The current supported base OS is Amazon Linux 2.", + constraint_description="The AMI id must begin with 'ami-' followed by only letters (a-f) or numbers(0-9).", + ) + ) + + vpc_id: str = Base.parameter( + Attributes( + id=CommonKey.VPC_ID, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain VpcId.", + ) + ) + + load_balancer_subnets: list[str] = Base.parameter( + Attributes( + id=CommonKey.LOAD_BALANCER_SUBNETS, + type="AWS::SSM::Parameter::Value>", + description="Provide parameter store path to contain at least 2 subnet IDs. Select at least 2 subnets from different Availability Zones. For deployments that need restricted internet access, select private subnets. 
For deployments that need internet access, select public subnets.", + allowed_pattern=".+", + ) + ) + + infrastructure_host_subnets: list[str] = Base.parameter( + Attributes( + id=CommonKey.INFRASTRUCTURE_HOST_SUBNETS, + type="AWS::SSM::Parameter::Value>", + description="Provide parameter store path to contain at least 2 subnet IDs. Select at least 2 private subnets from different Availability Zones.", + allowed_pattern=".+", + ) + ) + + vdi_subnets: list[str] = Base.parameter( + Attributes( + id=CommonKey.VDI_SUBNETS, + type="AWS::SSM::Parameter::Value>", + description="Provide parameter store path to contain at least 2 subnet IDs. Select at least 2 subnets from different Availability Zones. For deployments that need restricted internet access, select private subnets. For deployments that need internet access, select public subnets", + allowed_pattern=".+", + ) + ) + + is_load_balancer_internet_facing: str = Base.parameter( + Attributes( + id=CommonKey.IS_LOAD_BALANCER_INTERNET_FACING, + type="String", + description="Select true to deploy internet facing load balancer (Requires public subnets for load balancer). 
For deployments that need restricted internet access, select false.", + allowed_values=["true", "false"], + ) + ) + + +class CommonParameterGroups: + parameter_group_for_environment_and_installer_details: dict[str, Any] = { + "Label": {"default": "Environment and installer details"}, + "Parameters": [ + CommonKey.CLUSTER_NAME, + CommonKey.ADMIN_EMAIL, + CommonKey.INFRASTRUCTURE_HOST_AMI, + CommonKey.SSH_KEY_PAIR, + CommonKey.CLIENT_IP, + CommonKey.CLIENT_PREFIX_LIST, + ], + } + + parameter_group_for_network_configuration: dict[str, Any] = { + "Label": {"default": "Network configuration for the RES environment"}, + "Parameters": [ + CommonKey.VPC_ID, + CommonKey.IS_LOAD_BALANCER_INTERNET_FACING, + CommonKey.LOAD_BALANCER_SUBNETS, + CommonKey.INFRASTRUCTURE_HOST_SUBNETS, + CommonKey.VDI_SUBNETS, + ], + } diff --git a/source/idea/batteries_included/parameters/customdomain.py b/source/idea/batteries_included/parameters/customdomain.py new file mode 100644 index 0000000..dae838e --- /dev/null +++ b/source/idea/batteries_included/parameters/customdomain.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any + +from idea.infrastructure.install.parameters.base import Attributes, Base, Key + + +class CustomDomainKey(Key): + PORTAL_DOMAIN_NAME = "PortalDomainName" + CUSTOM_DOMAIN_NAME_FOR_WEB_APP = "CustomDomainNameforWebApp" + CUSTOM_DOMAIN_NAME_FOR_VDI = "CustomDomainNameforVDI" + ACM_CERTIFICATE_ARN_FOR_WEB_APP = "ACMCertificateARNforWebApp" + CERTIFICATE_SECRET_ARN_FOR_VDI = "CertificateSecretARNforVDI" + PRIVATE_KEY_SECRET_ARN_FOR_VDI = "PrivateKeySecretARNforVDI" + + +@dataclass +class CustomDomainParameters(Base): + portal_domain_name: str = Base.parameter( + Attributes( + id=CustomDomainKey.PORTAL_DOMAIN_NAME, + type="String", + description=( + "Domain name for the web portal domain that lives in Route53 in the account " + "(may be different from the Active Directory domain). " + "Used to generate certificates." + ), + allowed_pattern="^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$", + ) + ) + + custom_domain_name_for_web_ui: str = Base.parameter( + Attributes( + id=CustomDomainKey.CUSTOM_DOMAIN_NAME_FOR_WEB_APP, + type="String", + description=( + "You may provide a custom domain name for the web user interface, instead of the default under amazonaws.com." + ), + ) + ) + + custom_domain_name_for_vdi: str = Base.parameter( + Attributes( + id=CustomDomainKey.CUSTOM_DOMAIN_NAME_FOR_VDI, + type="String", + description=( + "You may provide a custom domain name for VDI, instead of the default under amazonaws.com." + ), + ) + ) + + acm_certificate_arn_for_web_ui: str = Base.parameter( + Attributes( + id=CustomDomainKey.ACM_CERTIFICATE_ARN_FOR_WEB_APP, + type="AWS::SSM::Parameter::Value", + description=( + ( + "If you have provided a custom domain name for Web UI above then " + "please provide parameter store path to contain the Amazon Resource Name (ARN) for the " + "corresponding certificate stored in AWS Certificate Manager (ACM)."
+ ) + ), + ) + ) + + certificate_secret_arn_for_vdi_domain_name: str = Base.parameter( + Attributes( + id=CustomDomainKey.CERTIFICATE_SECRET_ARN_FOR_VDI, + type="AWS::SSM::Parameter::Value", + description=( + ( + "If you have provided a custom domain name for VDI above then " + "please provide parameter store path to contain the Amazon Resource Name (ARN) for the " + "certificate secret stored in AWS Secrets Manager." + ) + ), + ) + ) + + private_key_secret_arn_for_vdi_domain_name: str = Base.parameter( + Attributes( + id=CustomDomainKey.PRIVATE_KEY_SECRET_ARN_FOR_VDI, + type="AWS::SSM::Parameter::Value", + description=( + ( + "If you have provided a custom domain name for VDI above then " + "please provide parameter store path to contain the Amazon Resource Name (ARN) for the " + "private key of the certificate stored in AWS Secrets Manager." + ) + ), + ) + ) + + +class CustomDomainParameterGroups: + parameter_group_for_custom_domain: dict[str, Any] = { + "Label": { + "default": "Custom domain details, only needed if you would like to use a custom domain" + }, + "Parameters": [ + CustomDomainKey.PORTAL_DOMAIN_NAME, + CustomDomainKey.CUSTOM_DOMAIN_NAME_FOR_WEB_APP, + CustomDomainKey.CUSTOM_DOMAIN_NAME_FOR_VDI, + CustomDomainKey.ACM_CERTIFICATE_ARN_FOR_WEB_APP, + CustomDomainKey.CERTIFICATE_SECRET_ARN_FOR_VDI, + CustomDomainKey.PRIVATE_KEY_SECRET_ARN_FOR_VDI, + ], + } diff --git a/source/idea/batteries_included/parameters/directoryservice.py b/source/idea/batteries_included/parameters/directoryservice.py new file mode 100644 index 0000000..be66641 --- /dev/null +++ b/source/idea/batteries_included/parameters/directoryservice.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Code changes made to this file must be replicated in 'source/idea/infrastructure/install/parameters/directoryservice' too + +from dataclasses import dataclass +from typing import Any, Optional + +from idea.infrastructure.install.parameters.base import Attributes, Base, Key + + +class DirectoryServiceKey(Key): + NAME = "ActiveDirectoryName" + LDAP_BASE = "LDAPBase" + AD_SHORT_NAME = "ADShortName" + LDAP_CONNECTION_URI = "LDAPConnectionURI" + USERS_OU = "UsersOU" + GROUPS_OU = "GroupsOU" + COMPUTERS_OU = "ComputersOU" + SUDOERS_OU = "SudoersOU" + SUDOERS_GROUP_NAME = "SudoersGroupName" + ROOT_USERNAME = "ServiceAccountUsername" + ROOT_PASSWORD = "ServiceAccountPassword" + DOMAIN_TLS_CERTIFICATE_SECRET_ARN = "DomainTLSCertificateSecretArn" + ENABLE_LDAP_ID_MAPPING = "EnableLdapIDMapping" + + +@dataclass +class DirectoryServiceParameters(Base): + name: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.NAME, + type="AWS::SSM::Parameter::Value", + description=( + "Please provide parameter store path to contain the Fully Qualified Domain Name (FQDN) for your Active Directory. " + ), + ) + ) + ldap_base: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.LDAP_BASE, + type="AWS::SSM::Parameter::Value", + description=( + "Please provide parameter store path to contain the Active Directory base string Distinguished Name (DN). " + ), + ) + ) + ad_short_name: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.AD_SHORT_NAME, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain the short name in Active directory", + ) + ) + ldap_connection_uri: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.LDAP_CONNECTION_URI, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain the active directory connection URI (e.g. 
ldap://www.example.com)", + ) + ) + users_ou: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.USERS_OU, + type="AWS::SSM::Parameter::Value", + description=( + "Please provide parameter store path to contain Users Organization Unit in your active directory " + ), + ) + ) + groups_ou: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.GROUPS_OU, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain user groups Organization Unit in your active directory", + ) + ) + computers_ou: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.COMPUTERS_OU, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain Organization Unit for compute and storage servers in your active directory", + ) + ) + sudoers_ou: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.SUDOERS_OU, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain Organization Unit for users who will be able to sudo in your active directory", + ) + ) + sudoers_group_name: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.SUDOERS_GROUP_NAME, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain group name of users who will be able to sudo in your active directory", + ) + ) + root_username: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.ROOT_USERNAME, + type="AWS::SSM::Parameter::Value", + description="Please provide parameter store path to contain Directory Service Root (Service Account) username", + no_echo=True, + ) + ) + root_password: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.ROOT_PASSWORD, + type="String", + description="Please provide Directory Service Root (Service Account) password", + no_echo=True, + ) + ) + domain_tls_certificate_secret_arn: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.DOMAIN_TLS_CERTIFICATE_SECRET_ARN, + type="String", + 
description="(Optional) Domain TLS Certificate Secret ARN", + ) + ) + enable_ldap_id_mapping: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.ENABLE_LDAP_ID_MAPPING, + type="String", + description="Set to False to use the uidNumbers and gidNumbers for users and groups from the provided AD. Otherwise set to True.", + allowed_values=["True", "False"], + ) + ) + + # These will be populated after the secrets are created from the above parameters + root_username_secret_arn: Optional[str] = None + root_password_secret_arn: Optional[str] = None + + +class DirectoryServiceParameterGroups: + parameter_group_for_directory_service: dict[str, Any] = { + "Label": {"default": "Active Directory details"}, + "Parameters": [ + DirectoryServiceKey.NAME, + DirectoryServiceKey.AD_SHORT_NAME, + DirectoryServiceKey.LDAP_BASE, + DirectoryServiceKey.LDAP_CONNECTION_URI, + DirectoryServiceKey.ROOT_USERNAME, + DirectoryServiceKey.ROOT_PASSWORD, + DirectoryServiceKey.USERS_OU, + DirectoryServiceKey.GROUPS_OU, + DirectoryServiceKey.SUDOERS_OU, + DirectoryServiceKey.SUDOERS_GROUP_NAME, + DirectoryServiceKey.COMPUTERS_OU, + DirectoryServiceKey.DOMAIN_TLS_CERTIFICATE_SECRET_ARN, + DirectoryServiceKey.ENABLE_LDAP_ID_MAPPING, + ], + } diff --git a/source/idea/batteries_included/parameters/parameters.py b/source/idea/batteries_included/parameters/parameters.py new file mode 100644 index 0000000..a4119f4 --- /dev/null +++ b/source/idea/batteries_included/parameters/parameters.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any + +from idea.batteries_included.parameters import ( + common, + customdomain, + directoryservice, + shared_storage, +) + + +@dataclass +class BIParameters( + common.CommonParameters, + customdomain.CustomDomainParameters, + directoryservice.DirectoryServiceParameters, + shared_storage.SharedStorageParameters, +): + """ + This is where all the different categories of parameters are combined + using inheritance. + """ + + pass + + +class AllBIParameterGroups( + common.CommonParameterGroups, + customdomain.CustomDomainParameterGroups, + directoryservice.DirectoryServiceParameterGroups, + shared_storage.SharedStorageParameterGroups, +): + """ + All the parameter groups are collated here + """ + + @classmethod + def template_metadata(cls) -> dict[str, Any]: + return { + "AWS::CloudFormation::Interface": { + "ParameterGroups": [ + common.CommonParameterGroups.parameter_group_for_environment_and_installer_details, + common.CommonParameterGroups.parameter_group_for_network_configuration, + directoryservice.DirectoryServiceParameterGroups.parameter_group_for_directory_service, + shared_storage.SharedStorageParameterGroups.parameter_group_for_shared_storage, + customdomain.CustomDomainParameterGroups.parameter_group_for_custom_domain, + ] + } + } diff --git a/source/idea/batteries_included/parameters/shared_storage.py b/source/idea/batteries_included/parameters/shared_storage.py new file mode 100644 index 0000000..83dbafa --- /dev/null +++ b/source/idea/batteries_included/parameters/shared_storage.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any + +from idea.infrastructure.install.parameters.base import Attributes, Base, Key + + +class SharedStorageKey(Key): + SHARED_HOME_FILESYSTEM_ID = "SharedHomeFileSystemId" + + +@dataclass +class SharedStorageParameters(Base): + existing_home_fs_id: str = Base.parameter( + Attributes( + id=SharedStorageKey.SHARED_HOME_FILESYSTEM_ID, + type="AWS::SSM::Parameter::Value", + description=( + "Please provide parameter store path to contain id of a home file system to be mounted on all VDI instances." + ), + ) + ) + + +class SharedStorageParameterGroups: + parameter_group_for_shared_storage: dict[str, Any] = { + "Label": {"default": "Shared Storage details"}, + "Parameters": [ + SharedStorageKey.SHARED_HOME_FILESYSTEM_ID, + ], + } diff --git a/source/idea/batteries_included/stack.py b/source/idea/batteries_included/stack.py new file mode 100644 index 0000000..840b9d6 --- /dev/null +++ b/source/idea/batteries_included/stack.py @@ -0,0 +1,216 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from aws_cdk import CfnStack, Stack, aws_ssm +from constructs import Construct + +from idea.batteries_included.parameters.parameters import BIParameters + + +class BiStack(Stack): + def __init__( + self, + scope: Construct, + stack_id: str, + template_url: str, + parameters: BIParameters = BIParameters(), + ) -> None: + super().__init__(scope, stack_id) + + self.bi_stack = CfnStack( + self, + "RESExternal", + parameters={ + "PortalDomainName": str(parameters.portal_domain_name), + "Keypair": str(parameters.ssh_key_pair_name), + "EnvironmentName": str(parameters.cluster_name), + "AdminPassword": str(parameters.root_password), + "ServiceAccountPassword": str(parameters.root_password), + "ClientIpCidr": str(parameters.client_ip), + "ClientPrefixList": str(parameters.client_prefix_list), + }, + template_url=template_url, + ) + + self.vpc_id = aws_ssm.StringParameter( + self, + id=str(parameters.vpc_id), + parameter_name=str(parameters.vpc_id), + string_value=self.bi_stack.get_att("Outputs.VpcId").to_string(), + ) + + self.vpc_id.node.add_dependency(self.bi_stack) + + self.load_balancer_subnets = aws_ssm.StringListParameter( + self, + id=str(parameters.load_balancer_subnets), + parameter_name=str(parameters.load_balancer_subnets), + string_list_value=self.bi_stack.get_att("Outputs.PublicSubnets") + .to_string() + .split(","), + ) + + self.load_balancer_subnets.node.add_dependency(self.bi_stack) + + self.infrastructure_host_subnets = aws_ssm.StringListParameter( + self, + id=str(parameters.infrastructure_host_subnets), + parameter_name=str(parameters.infrastructure_host_subnets), + string_list_value=self.bi_stack.get_att("Outputs.PrivateSubnets") + .to_string() + .split(","), + ) + + self.infrastructure_host_subnets.node.add_dependency(self.bi_stack) + + self.vdi_subnets = aws_ssm.StringListParameter( + self, + id=str(parameters.vdi_subnets), + parameter_name=str(parameters.vdi_subnets), + 
string_list_value=self.bi_stack.get_att("Outputs.PrivateSubnets") + .to_string() + .split(","), + ) + + self.vdi_subnets.node.add_dependency(self.bi_stack) + + self.active_directory_name = aws_ssm.StringParameter( + self, + id=str(parameters.name), + parameter_name=str(parameters.name), + string_value=self.bi_stack.get_att( + "Outputs.ActiveDirectoryName" + ).to_string(), + ) + + self.active_directory_name.node.add_dependency(self.bi_stack) + + self.ad_short_name = aws_ssm.StringParameter( + self, + id=str(parameters.ad_short_name), + parameter_name=str(parameters.ad_short_name), + string_value=self.bi_stack.get_att("Outputs.ADShortName").to_string(), + ) + + self.ad_short_name.node.add_dependency(self.bi_stack) + + self.ldap_base = aws_ssm.StringParameter( + self, + id=str(parameters.ldap_base), + parameter_name=str(parameters.ldap_base), + string_value=self.bi_stack.get_att("Outputs.LDAPBase").to_string(), + ) + + self.ldap_base.node.add_dependency(self.bi_stack) + + self.ldap_connection_uri = aws_ssm.StringParameter( + self, + id=str(parameters.ldap_connection_uri), + parameter_name=str(parameters.ldap_connection_uri), + string_value=self.bi_stack.get_att("Outputs.LDAPConnectionURI").to_string(), + ) + + self.ldap_connection_uri.node.add_dependency(self.bi_stack) + + self.acm_certificate_arn_for_web_ui = aws_ssm.StringParameter( + self, + id=str(parameters.acm_certificate_arn_for_web_ui), + parameter_name=str(parameters.acm_certificate_arn_for_web_ui), + string_value=self.bi_stack.get_att( + "Outputs.ACMCertificateARNforWebApp" + ).to_string(), + ) + + self.acm_certificate_arn_for_web_ui.node.add_dependency(self.bi_stack) + + self.private_key_secret_arn_for_vdi_domain_name = aws_ssm.StringParameter( + self, + id=str(parameters.private_key_secret_arn_for_vdi_domain_name), + parameter_name=str(parameters.private_key_secret_arn_for_vdi_domain_name), + string_value=self.bi_stack.get_att( + "Outputs.PrivateKeySecretArn" + ).to_string(), + ) + + 
self.private_key_secret_arn_for_vdi_domain_name.node.add_dependency( + self.bi_stack + ) + + self.certificate_secret_arn_for_vdi_domain_name = aws_ssm.StringParameter( + self, + id=str(parameters.certificate_secret_arn_for_vdi_domain_name), + parameter_name=str(parameters.certificate_secret_arn_for_vdi_domain_name), + string_value=self.bi_stack.get_att( + "Outputs.CertificateSecretArn" + ).to_string(), + ) + + self.certificate_secret_arn_for_vdi_domain_name.node.add_dependency( + self.bi_stack + ) + + self.root_username = aws_ssm.StringParameter( + self, + id=str(parameters.root_username), + parameter_name=str(parameters.root_username), + string_value=self.bi_stack.get_att( + "Outputs.ServiceAccountUsername" + ).to_string(), + ) + + self.root_username.node.add_dependency(self.bi_stack) + + self.users_ou = aws_ssm.StringParameter( + self, + id=str(parameters.users_ou), + parameter_name=str(parameters.users_ou), + string_value=self.bi_stack.get_att("Outputs.UsersOU").to_string(), + ) + + self.users_ou.node.add_dependency(self.bi_stack) + + self.groups_ou = aws_ssm.StringParameter( + self, + id=str(parameters.groups_ou), + parameter_name=str(parameters.groups_ou), + string_value=self.bi_stack.get_att("Outputs.GroupsOU").to_string(), + ) + + self.groups_ou.node.add_dependency(self.bi_stack) + + self.sudoers_ou = aws_ssm.StringParameter( + self, + id=str(parameters.sudoers_ou), + parameter_name=str(parameters.sudoers_ou), + string_value=self.bi_stack.get_att("Outputs.SudoersOU").to_string(), + ) + + self.sudoers_ou.node.add_dependency(self.bi_stack) + + self.sudoers_group_name = aws_ssm.StringParameter( + self, + id=str(parameters.sudoers_group_name), + parameter_name=str(parameters.sudoers_group_name), + string_value="RESAdministrators", + ) + + self.sudoers_group_name.node.add_dependency(self.bi_stack) + + self.computers_ou = aws_ssm.StringParameter( + self, + id=str(parameters.computers_ou), + parameter_name=str(parameters.computers_ou), + 
string_value=self.bi_stack.get_att("Outputs.ComputersOU").to_string(), + ) + + self.computers_ou.node.add_dependency(self.bi_stack) + + self.existing_home_fs_id = aws_ssm.StringParameter( + self, + id=str(parameters.existing_home_fs_id), + parameter_name=str(parameters.existing_home_fs_id), + string_value=self.bi_stack.get_att( + "Outputs.SharedHomeFilesystemId" + ).to_string(), + ) + + self.existing_home_fs_id.node.add_dependency(self.bi_stack) diff --git a/source/idea/constants.py b/source/idea/constants.py index 832a465..74d86a8 100644 --- a/source/idea/constants.py +++ b/source/idea/constants.py @@ -6,3 +6,4 @@ INSTALL_STACK_NAME = "ResearchAndEngineeringStudio" ARTIFACTS_BUCKET_PREFIX_NAME = "research-engineering-studio" DEFAULT_ECR_REPOSITORY_NAME = "RESBuildImageRepository" +BATTERIES_INCLUDED_STACK_NAME = "BatteriesIncluded" diff --git a/source/idea/idea-administrator/resources/config/templates/analytics/settings.yml b/source/idea/idea-administrator/resources/config/templates/analytics/settings.yml deleted file mode 100644 index fe9483f..0000000 --- a/source/idea/idea-administrator/resources/config/templates/analytics/settings.yml +++ /dev/null @@ -1,29 +0,0 @@ -# Configure your AWS OpenSearch/Kibana options below -opensearch: - use_existing: {{use_existing_opensearch_cluster | lower}} - {%- if use_existing_opensearch_cluster %} - domain_vpc_endpoint_url: "{{opensearch_domain_endpoint}}" - {%- else %} - data_node_instance_type: "m5.large.search" # instance type for opensearch data nodes - data_nodes: 2 # number of data nodes for elasticsearch - ebs_volume_size: 100 # ebs volume size attached to data nodes - removal_policy: "DESTROY" # RETAIN will preserve the cluster even if you delete the stack. - node_to_node_encryption: true - logging: - app_log_enabled: true # Specify if Amazon OpenSearch Service application logging should be set up. 
- slow_index_log_enabled: true # Log Amazon OpenSearch Service audit logs to this log group - slow_search_log_enabled: true # Specify if slow search logging should be set up. - {%- endif %} - kms_key_id: {{ kms_key_id if kms_key_id else '~' }} # Specify your own CMK to encrypt OpenSearch domain. If set to ~ encryption will be managed by the default AWS key - default_number_of_shards: 2 - default_number_of_replicas: 1 - - endpoints: - external: - priority: 16 - path_patterns: ['/_dashboards*'] - -kinesis: - shard_count: 2 - stream_mode: PROVISIONED - kms_key_id: {{ kms_key_id if kms_key_id else '~' }} # Specify your own CMK to encrypt Kinesis stream. If set to ~ encryption will be managed by the default AWS key diff --git a/source/idea/idea-administrator/resources/config/templates/cluster-manager/settings.yml b/source/idea/idea-administrator/resources/config/templates/cluster-manager/settings.yml index de28b6f..2149984 100644 --- a/source/idea/idea-administrator/resources/config/templates/cluster-manager/settings.yml +++ b/source/idea/idea-administrator/resources/config/templates/cluster-manager/settings.yml @@ -74,6 +74,7 @@ ec2: min_capacity: 1 max_capacity: 3 cooldown_minutes: 5 + default_instance_warmup: 15 new_instances_protected_from_scale_in: false elb_healthcheck: # Specifies the time in minutes Auto Scaling waits before checking the health status of an EC2 instance that has come into service. 
diff --git a/source/idea/idea-administrator/resources/config/templates/cluster/settings.yml b/source/idea/idea-administrator/resources/config/templates/cluster/settings.yml index 78d45f7..0b53534 100644 --- a/source/idea/idea-administrator/resources/config/templates/cluster/settings.yml +++ b/source/idea/idea-administrator/resources/config/templates/cluster/settings.yml @@ -44,6 +44,10 @@ network: {{ utils.to_yaml(private_subnet_ids) | indent(4) }} public_subnets: {{ utils.to_yaml(public_subnet_ids) | indent(4) }} + load_balancer_subnets: + {{ utils.to_yaml(load_balancer_subnet_ids) | indent(4) }} + infrastructure_host_subnets: + {{ utils.to_yaml(infrastructure_host_subnet_ids) | indent(4) }} {%- else %} max_azs: 3 nat_gateways: 1 @@ -143,12 +147,12 @@ load_balancers: custom_dns_name: {{alb_custom_dns_name if alb_custom_dns_name else '~'}} # SSL/TLS Policy on External Load Balancer # For a list of policies - consult the documentation at https://docs.aws.amazon.com/elasticloadbalancing/latest/application/create-https-listener.html#describe-ssl-policies - ssl_policy: ELBSecurityPolicy-FS-1-2-Res-2020-10 + ssl_policy: ELBSecurityPolicy-TLS13-1-2-2021-06 internal_alb: access_logs: true # SSL/TLS Policy on Internal Load Balancer # For a list of policies - consult the documentation at https://docs.aws.amazon.com/elasticloadbalancing/latest/application/create-https-listener.html#describe-ssl-policies - ssl_policy: ELBSecurityPolicy-FS-1-2-Res-2020-10 + ssl_policy: ELBSecurityPolicy-TLS13-1-2-2021-06 cloudwatch_logs: # enable or disable publishing logs to cloudwatch across the cluster. 
diff --git a/source/idea/idea-administrator/resources/config/templates/global-settings/settings.yml b/source/idea/idea-administrator/resources/config/templates/global-settings/settings.yml index 04a701c..3792307 100644 --- a/source/idea/idea-administrator/resources/config/templates/global-settings/settings.yml +++ b/source/idea/idea-administrator/resources/config/templates/global-settings/settings.yml @@ -2,8 +2,6 @@ module_sets: default: cluster: module_id: cluster - analytics: - module_id: analytics identity-provider: module_id: identity-provider directoryservice: @@ -543,78 +541,66 @@ package_config: {%- if 'amazonlinux2' in supported_base_os or 'rhel7' in supported_base_os or 'centos7' in supported_base_os %} al2_rhel_centos7: version: 2023.0.531-1.el7.x86_64 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el7.x86_64.rpm - sha256sum: 7abc061b94807510c8284849fd2545b170dc8d47954f1afa9f38966f47ce15ae + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el7.x86_64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el7.x86_64.rpm.sha256sum {%- endif %} {%- if 'rhel8' in supported_base_os or 'centos8' in supported_base_os or 'rocky8' in supported_base_os %} rhel_centos_rocky8: version: 2023.0.531-1.el8.x86_64 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el8.x86_64.rpm - sha256sum: 67a71866996f1bd32b22d9cadb695ca357d165270e2d4373e5d4c9e2a56cb974 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.x86_64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.x86_64.rpm.sha256sum {%- endif %} {%- if 'rhel9' in supported_base_os or 'centos9' in supported_base_os or 'rocky9' in supported_base_os %} rhel_centos_rocky9: version: 2023.0.531-1.el9.x86_64 - url: 
https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el9.x86_64.rpm - sha256sum: b58748fc3f726c6073c3c6d82688d5360e05a0d42244f420ba3ba9bdb92b0655 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el9.x86_64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el9.x86_64.rpm.sha256sum {%- endif %} ubuntu: - {%- if 'ubuntu1804' in supported_base_os %} - ubuntu1804: - version: 2023.0.531-1_amd64.ubuntu1804 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_amd64.ubuntu1804.deb - sha256sum: b710ecb96e350ea3f8d53926c60f8001441c1f4b0db4f476ca282c720d6a265e - {%- endif %} {%- if 'ubuntu2004' in supported_base_os %} ubuntu2004: version: 2023.0.531-1_amd64.ubuntu2004 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_amd64.ubuntu2004.deb - sha256sum: 87d8e3269f3bcfc648f80f7117e8d486ee7f6605449dc8c73f0a394a1c1035be + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_amd64.ubuntu2004.deb + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_amd64.ubuntu2004.deb.sha256sum {%- endif %} {%- if 'ubuntu2204' in supported_base_os %} ubuntu2204: version: 2023.0.531-1_amd64.ubuntu2204 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_amd64.ubuntu2204.deb - sha256sum: b5c48779cec33a9fe26829ca0f28aa15702151b8ebc42d354f95254db5dedcb3 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_amd64.ubuntu2204.deb + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_amd64.ubuntu2204.deb.sha256sum {%- endif %} aarch64: linux: {%- if 'amazonlinux2' in supported_base_os or 'rhel7' in supported_base_os or 'centos7' in supported_base_os %} al2_rhel_centos7: version: 2023.0.531-1.el7.aarch64 - url: 
https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el7.aarch64.rpm - sha256sum: 77d29efa59e08d669058037a2a8aba6b0d51822b8771e2abdb42a0c9f425685f + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el7.aarch64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el7.aarch64.rpm.sha256sum {%- endif %} {%- if 'rhel8' in supported_base_os or 'centos8' in supported_base_os or 'rocky8' in supported_base_os %} rhel_centos_rocky8: version: 2023.0.531-1.el8.aarch64 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el8.aarch64.rpm - sha256sum: b55f344eff9ed07e3d706da24f369f7e93d2e106edc4b3589569c06308571664 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.aarch64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.aarch64.rpm.sha256sum {%- endif %} {%- if 'rhel9' in supported_base_os or 'centos9' in supported_base_os or 'rocky9' in supported_base_os %} rhel_centos_rocky9: version: 2023.0.531-1.el8.aarch64 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway-2023.0.531-1.el8.aarch64.rpm - sha256sum: b55f344eff9ed07e3d706da24f369f7e93d2e106edc4b3589569c06308571664 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.aarch64.rpm + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway-el8.aarch64.rpm.sha256sum {%- endif %} ubuntu: - {%- if 'ubuntu1804' in supported_base_os %} - ubuntu1804: - version: 2023.0.531-1_arm64.ubuntu1804 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_arm64.ubuntu1804.deb - sha256sum: 1e7936fc95e9d4e7afb66ed044e1030ac51bc6bdb6ae2023e7389f3aa0890681 - {%- endif %} {%- if 'ubuntu2004' in supported_base_os %} ubuntu2004: version: 2023.0-531-1_arm64.ubuntu2004 - url: 
https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_arm64.ubuntu2004.deb - sha256sum: 2f940ed9f345370f1b46754d6ff3fd3348dad4030d162dff0d34cb56110c4d8a + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_arm64.ubuntu2004.deb + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_arm64.ubuntu2004.deb.sha256sum {%- endif %} {%- if 'ubuntu2204' in supported_base_os %} ubuntu2204: version: 2023.0-531-1_arm64.ubuntu2204 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Gateway/nice-dcv-connection-gateway_2023.0.531-1_arm64.ubuntu2204.deb - sha256sum: 727c7ff0882edafa061bf18e9d7840c55cadc3ca1f716618bff7f70e4cf0919d + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_arm64.ubuntu2204.deb + sha256sum: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-connection-gateway_arm64.ubuntu2204.deb.sha256sum {%- endif %} broker: linux: @@ -659,40 +645,37 @@ package_config: windows: msi: label: MSI - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-client-Release-2023.0-8655.msi + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-client-Release.msi zip: - label: ZIP - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-client-Release-portable-2023.0-8655.zip + label: Portable + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-client-Release-portable.zip macos: m1: - label: M1 Chip - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388.arm64.dmg + label: ARM 64 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer.arm64.dmg intel: - label: Intel Chip - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388.x86_64.dmg + label: x86_64 + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer.x86_64.dmg linux: rhel_centos7: label: RHEL 7 | Cent OS 7 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388-1.el7.x86_64.rpm + url: 
https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer-el7.x86_64.rpm rhel_centos_rocky8: label: RHEL 8 | Cent OS 8 | Rocky Linux 8 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388-1.el8.x86_64.rpm + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer-el8.x86_64.rpm rhel_centos_rocky9: label: RHEL 9 | Cent OS 9 | Rocky Linux 9 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388-1.el9.x86_64.rpm + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer-el9.x86_64.rpm suse15: label: SUSE Enterprise Linux 15 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer-2023.0.5388-1.sles15.x86_64.rpm + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer-sles15.x86_64.rpm ubuntu: - ubuntu1804: - label: Ubuntu 18.04 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer_2023.0.5388-1_amd64.ubuntu1804.deb ubuntu2004: label: Ubuntu 20.04 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer_2023.0.5388-1_amd64.ubuntu2004.deb + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer_amd64.ubuntu2004.deb ubuntu2204: label: Ubuntu 22.04 - url: https://d1uj6qtbmh3dt5.cloudfront.net/2023.0/Clients/nice-dcv-viewer_2023.0.5388-1_amd64.ubuntu2204.deb + url: https://d1uj6qtbmh3dt5.cloudfront.net/nice-dcv-viewer_amd64.ubuntu2204.deb {%- endif %} gpu_settings: diff --git a/source/idea/idea-administrator/resources/config/templates/idea.yml b/source/idea/idea-administrator/resources/config/templates/idea.yml index 8321a05..dfb7ddf 100644 --- a/source/idea/idea-administrator/resources/config/templates/idea.yml +++ b/source/idea/idea-administrator/resources/config/templates/idea.yml @@ -13,12 +13,6 @@ modules: - settings.yml - logging.yml - - name: analytics - id: analytics - type: stack - config_files: - - settings.yml - - name: identity-provider id: identity-provider type: stack diff --git 
a/source/idea/idea-administrator/resources/config/templates/virtual-desktop-controller/settings.yml b/source/idea/idea-administrator/resources/config/templates/virtual-desktop-controller/settings.yml index fbe71ef..f5a6836 100644 --- a/source/idea/idea-administrator/resources/config/templates/virtual-desktop-controller/settings.yml +++ b/source/idea/idea-administrator/resources/config/templates/virtual-desktop-controller/settings.yml @@ -32,6 +32,7 @@ controller: min_capacity: 1 max_capacity: 3 cooldown_minutes: 5 + default_instance_warmup: 15 new_instances_protected_from_scale_in: false elb_healthcheck: # Specifies the time in minutes Auto Scaling waits before checking the health status of an EC2 instance that has come into service. @@ -67,6 +68,7 @@ dcv_broker: min_capacity: 1 max_capacity: 3 cooldown_minutes: 5 + default_instance_warmup: 15 new_instances_protected_from_scale_in: false elb_healthcheck: # Specifies the time in minutes Auto Scaling waits before checking the health status of an EC2 instance that has come into service. @@ -85,7 +87,7 @@ dcv_broker: gateway_communication_port: 8446 # DO NOT CHANGE # SSL/TLS Policy on VDC HTTPS listeners # For a list of policies - consult the documentation at https://docs.aws.amazon.com/elasticloadbalancing/latest/application/create-https-listener.html#describe-ssl-policies - ssl_policy: ELBSecurityPolicy-FS-1-2-Res-2020-10 + ssl_policy: ELBSecurityPolicy-TLS13-1-2-2021-06 session_token_validity: 1440 # in minutes dynamodb_table: autoscaling: @@ -114,6 +116,7 @@ dcv_connection_gateway: min_capacity: 1 max_capacity: 3 cooldown_minutes: 5 + default_instance_warmup: 15 new_instances_protected_from_scale_in: false elb_healthcheck: # Specifies the time in minutes Auto Scaling waits before checking the health status of an EC2 instance that has come into service. 
@@ -137,14 +140,6 @@ external_nlb: # external NLB access logs are enabled by default access_logs: true -opensearch: - dcv_session: - alias: {{cluster_name}}_{{module_id}}_user_sessions - software_stack: - alias: {{cluster_name}}_{{module_id}}_software_stacks - session_permission: - alias: {{cluster_name}}_{{module_id}}_session_permission - dcv_session: idle_timeout: 1440 # in minutes idle_timeout_warning: 300 # in seconds @@ -163,12 +158,12 @@ dcv_session: quic_support: {{ dcv_session_quic_support | lower }} # - # By default, the eVDI subnets match the cluster private subnets. + # eVDI subnets will match the provided dcv_session_private_subnet_ids; if not provided, they will default to the cluster private subnets. # Here you can specify eVDI-specific subnets as an alternative to using the cluster.network.private_subnets # network: private_subnets: - {{ utils.to_yaml(private_subnet_ids) | indent(6) }} + {{ utils.to_yaml(dcv_session_private_subnet_ids) | indent(6) if dcv_session_private_subnet_ids else utils.to_yaml(private_subnet_ids) | indent(6) }} # Supported eVDI randomize_subnets settings: # True - Randomize the subnets (dcv_session.network.private_subnets or cluster.network.private_subnets) for deployment diff --git a/source/idea/idea-administrator/resources/config/values.yml b/source/idea/idea-administrator/resources/config/values.yml index 578bf57..2c0e3f6 100644 --- a/source/idea/idea-administrator/resources/config/values.yml +++ b/source/idea/idea-administrator/resources/config/values.yml @@ -52,7 +52,7 @@ base_os: ~ instance_ami: ~ volume_size: 200 -# Provide a list of module names to be enabled. cluster, directoryservice, cluster-manager and analytics are mandatory modules for IDEA to operate and will always be deployed +# Provide a list of module names to be enabled. 
cluster, directoryservice and cluster-manager are mandatory modules for RES to operate and will always be deployed enabled_modules: [ ] # metrics @@ -89,11 +89,16 @@ dcv_connection_gateway_custom_certificate_certificate_secret_arn: ~ dcv_connection_gateway_custom_certificate_private_key_secret_arn: ~ dcv_connection_gateway_custom_dns_hostname: ~ +# Virtual Desktop Controller - DCV Session Subnets +dcv_session_private_subnet_ids: [] + # Build cluster using existing resources use_existing_vpc: false vpc_id: ~ # value is required when use_existing_vpc == true private_subnet_ids: [ ] # value is required when use_existing_vpc == true public_subnet_ids: [ ] # value is required when use_existing_vpc == true +load_balancer_subnet_ids: [] # value is required when use_existing_vpc == true +infrastructure_host_subnet_ids: [] # value is required when use_existing_vpc == true use_existing_internal_fs: false # use_existing_vpc should be true when use_existing_internal_fs == true existing_internal_fs_id: ~ # value is required when use_existing_internal_fs == true @@ -101,9 +106,6 @@ existing_internal_fs_id: ~ # value is required when use_existing_internal_fs == use_existing_home_fs: false # use_existing_vpc should be true when use_existing_home_fs == true existing_home_fs_id: ~ # value is required when use_existing_home_fs == true -use_existing_opensearch_cluster: false # use_existing_vpc should be true when use_existing_opensearch_cluster == true -opensearch_domain_endpoint: ~ # value is required when use_existing_opensearch_cluster == true - use_existing_directory_service: false # use_existing_vpc should be true when use_existing_directory_service == true directory_id: ~ # value is required when use_existing_directory_service == true directory_service_root_username_secret_arn: ~ # value is required when use_existing_directory_service == true diff --git a/source/idea/idea-administrator/resources/input_params/install_params.yml 
b/source/idea/idea-administrator/resources/input_params/install_params.yml index f69f770..4a35955 100644 --- a/source/idea/idea-administrator/resources/input_params/install_params.yml +++ b/source/idea/idea-administrator/resources/input_params/install_params.yml @@ -89,6 +89,9 @@ SocaInputParamSpec: - name: existing_resources - name: public_subnet_ids - name: private_subnet_ids + - name: load_balancer_subnet_ids + - name: infrastructure_host_subnet_ids + - name: dcv_session_private_subnet_ids - name: directory_id - name: directory_service_root_username_secret_arn - name: directory_service_root_password_secret_arn @@ -96,7 +99,6 @@ SocaInputParamSpec: - name: existing_internal_fs_id - name: storage_home_provider - name: existing_home_fs_id - - name: opensearch_domain_endpoint - name: module-settings title: "Module Settings" required: yes @@ -499,6 +501,51 @@ SocaInputParamSpec: param: existing_resources contains: 'subnets:public' + - name: load_balancer_subnet_ids + title: "Existing External Load Balancer Subnets" + description: "Select existing external load balancer subnets" + param_type: checkbox + data_type: str + multiple: true + help_text: ~ + tag: default + markdown: ~ + validate: + required: yes + when: + param: existing_resources + contains: 'subnets:external_load_balancer' + + - name: infrastructure_host_subnet_ids + title: "Existing Infrastructure Hosts Subnets" + description: "Select existing infrastructure hosts subnets" + param_type: checkbox + data_type: str + multiple: true + help_text: ~ + tag: default + markdown: ~ + validate: + required: yes + when: + param: existing_resources + contains: 'subnets:infrastructure_hosts' + + - name: dcv_session_private_subnet_ids + title: "Existing DCV Session Private Subnets" + description: "Select existing dcv session private subnets" + param_type: checkbox + data_type: str + multiple: true + help_text: ~ + tag: default + markdown: ~ + validate: + required: yes + when: + param: existing_resources + contains: 
'subnets:dcv_session' + - name: storage_internal_provider title: "Storage Provider: Internal" description: "Select internal storage provider" @@ -565,20 +612,6 @@ SocaInputParamSpec: param: existing_resources contains: 'shared-storage:home' - - name: opensearch_domain_endpoint - title: "Existing OpenSearch Service Domain" - description: "Select existing OpenSearch Service Domain" - param_type: select - data_type: str - help_text: ~ - tag: default - markdown: ~ - validate: - required: yes - when: - param: existing_resources - contains: 'analytics:opensearch' - - name: directory_id title: "Existing Directory" description: "Select existing AWS Managed Microsoft AD" diff --git a/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/__init__.py b/source/idea/idea-administrator/resources/lambda_functions/add_to_user_pool_client_scopes/__init__.py similarity index 100% rename from source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/__init__.py rename to source/idea/idea-administrator/resources/lambda_functions/add_to_user_pool_client_scopes/__init__.py diff --git a/source/idea/idea-administrator/resources/lambda_functions/add_to_user_pool_client_scopes/handler.py b/source/idea/idea-administrator/resources/lambda_functions/add_to_user_pool_client_scopes/handler.py new file mode 100644 index 0000000..c468848 --- /dev/null +++ b/source/idea/idea-administrator/resources/lambda_functions/add_to_user_pool_client_scopes/handler.py @@ -0,0 +1,95 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +from idea_lambda_commons import HttpClient, CfnResponse, CfnResponseStatus +import boto3 +import logging + +PHYSICAL_RESOURCE_ID = 'user-pool-client' + + +def handler(event: dict, context): + """ + Add OAuth scopes to an existing user pool client. + """ + request_type = event.get('RequestType', None) + resource_properties = event.get('ResourceProperties', {}) + cluster_name = resource_properties.get('cluster_name') + module_id = resource_properties.get('module_id') + stack_name = f'{cluster_name}-{module_id}' + client_id_secret_name = f'{stack_name}-client-id' + user_pool_id = resource_properties.get('user_pool_id') + o_auth_scopes_to_add = resource_properties.get('o_auth_scopes_to_add') + client = HttpClient() + + if request_type == 'Delete': + client.send_cfn_response(CfnResponse( + context=context, + event=event, + status=CfnResponseStatus.SUCCESS, + data={}, + physical_resource_id=PHYSICAL_RESOURCE_ID + )) + return + + try: + # Retrieve the client ID from Secrets Manager + secretsmanager = boto3.client('secretsmanager') + res = secretsmanager.get_secret_value(SecretId=client_id_secret_name) + client_id = res.get('SecretString', '') + + # Read the current configuration of the user pool client + idp_client = boto3.client('cognito-idp') + res = idp_client.describe_user_pool_client( + UserPoolId=user_pool_id, + ClientId=client_id + ) + user_pool_client = res.get('UserPoolClient', {}) + user_pool_client.pop('ClientSecret', None) + user_pool_client.pop('LastModifiedDate', None) + user_pool_client.pop('CreationDate', None) + + allowed_o_auth_scopes = user_pool_client.get('AllowedOAuthScopes', []) + allowed_o_auth_scopes_is_updated = False + for o_auth_scope in o_auth_scopes_to_add: + if o_auth_scope not in allowed_o_auth_scopes: + allowed_o_auth_scopes.append(o_auth_scope) + allowed_o_auth_scopes_is_updated = True + + if allowed_o_auth_scopes_is_updated: + # Only
update the allowed OAuth scopes of the client and keep all the other attributes unchanged. + user_pool_client['AllowedOAuthScopes'] = allowed_o_auth_scopes + idp_client.update_user_pool_client( + **user_pool_client, + ) + + logging.info('added OAuth scopes to user pool client successfully') + client.send_cfn_response(CfnResponse( + context=context, + event=event, + status=CfnResponseStatus.SUCCESS, + data={}, + physical_resource_id=PHYSICAL_RESOURCE_ID + )) + except Exception as e: + error_message = f'failed to add OAuth scopes to user pool client: {e}' + logging.exception(error_message) + + client.send_cfn_response(CfnResponse( + context=context, + event=event, + status=CfnResponseStatus.FAILED, + data={}, + physical_resource_id=PHYSICAL_RESOURCE_ID, + reason=error_message, + )) + finally: + client.destroy() diff --git a/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/handler.py b/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/handler.py deleted file mode 100644 index 085e31f..0000000 --- a/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/handler.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License.
-""" -Analytics Sink Lambda: -This function is triggered by the analytics kinesis stream when any module submits an analytics event -""" -import base64 -import os -import json -import logging -from opensearchpy import OpenSearch, helpers - -opensearch_endpoint = os.environ.get('opensearch_endpoint') -if not opensearch_endpoint.startswith('https://'): - opensearch_endpoint = f'https://{opensearch_endpoint}' - -os_client = OpenSearch( - hosts=[opensearch_endpoint], - port=443, - use_ssl=True, - verify_certs=True -) - -logger = logging.getLogger() -logger.setLevel(logging.INFO) - - -def handler(event, _): - try: - bulk_request_with_timestamp = [] - for record in event['Records']: - analytics_entry = base64.b64decode(record["kinesis"]["data"]) - analytics_entry = json.loads(analytics_entry.decode()) - request = { - "_id": analytics_entry["document_id"], - "_index": analytics_entry["index_id"] - } - - if analytics_entry["action"] == 'CREATE_ENTRY': - # create new - request["_op_type"] = "create" - request["_source"] = analytics_entry["entry"] - - elif analytics_entry["action"] == 'UPDATE_ENTRY': - # update existing - request["_op_type"] = "update" - request["doc"] = analytics_entry["entry"] - else: - # delete - request["_op_type"] = "delete" - - bulk_request_with_timestamp.append((analytics_entry["timestamp"], request)) - - bulk_request_with_timestamp_sorted = sorted(bulk_request_with_timestamp, key=lambda x: x[0]) - bulk_request = [] - for entry in bulk_request_with_timestamp_sorted: - request = entry[1] - logger.info(f'Submitting request for action: {request["_op_type"]} for document_id: {request["_id"]} to index {request["_index"]}') - bulk_request.append(request) - - response = helpers.bulk( - client=os_client, - actions=bulk_request - ) - logger.info(response) - except Exception as e: - logger.exception(f'Error while processing analytics request for event: {json.dumps(event)}, error: {e}') diff --git 
a/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/requirements.txt b/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/requirements.txt deleted file mode 100644 index 63c4ec3..0000000 --- a/source/idea/idea-administrator/resources/lambda_functions/idea_analytics_sink/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -opensearch-py diff --git a/source/idea/idea-administrator/resources/lambda_functions/idea_custom_resource_opensearch_private_ips/handler.py b/source/idea/idea-administrator/resources/lambda_functions/idea_custom_resource_opensearch_private_ips/handler.py deleted file mode 100644 index e3db36a..0000000 --- a/source/idea/idea-administrator/resources/lambda_functions/idea_custom_resource_opensearch_private_ips/handler.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. 
- -from idea_lambda_commons import HttpClient, CfnResponse, CfnResponseStatus -import boto3 -import logging -import json - -logger = logging.getLogger() -logger.setLevel(logging.INFO) - -PHYSICAL_RESOURCE_ID = 'opensearch-private-ip-addresses' - - -def handler(event, context): - domain_name = None - http_client = HttpClient() - try: - logger.info(f'ReceivedEvent: {json.dumps(event)}') - request_type = event.get('RequestType') - - if request_type == 'Delete': - http_client.send_cfn_response(CfnResponse( - context=context, - event=event, - status=CfnResponseStatus.SUCCESS, - data={}, - physical_resource_id=PHYSICAL_RESOURCE_ID - )) - return - - resource_properties = event.get('ResourceProperties', {}) - domain_name = resource_properties.get('DomainName') - logger.info('OpenSearch DomainName: ' + domain_name) - - ec2_client = boto3.client('ec2') - response = ec2_client.describe_network_interfaces(Filters=[ - {'Name': 'description', 'Values': [f'ES {domain_name}']}, - {'Name': 'requester-id', 'Values': ['amazon-elasticsearch']} - ]) - - network_interfaces = response.get('NetworkInterfaces', []) - result = [] - for network_interface in network_interfaces: - logger.debug(network_interface) - private_ip_addresses = network_interface.get('PrivateIpAddresses', []) - for private_ip_address in private_ip_addresses: - ip_address = private_ip_address.get('PrivateIpAddress', None) - if ip_address is None: - continue - result.append(ip_address) - - if len(result) == 0: - msg = 'No IP addresses found' - logger.error(msg) - http_client.send_cfn_response(CfnResponse( - context=context, - event=event, - status=CfnResponseStatus.FAILED, - data={ - 'error': msg - }, - physical_resource_id=PHYSICAL_RESOURCE_ID - )) - else: - http_client.send_cfn_response(CfnResponse( - context=context, - event=event, - status=CfnResponseStatus.SUCCESS, - data={ - 'IpAddresses': ','.join(result) - }, - physical_resource_id=PHYSICAL_RESOURCE_ID - )) - except Exception as e: - logger.exception(f'Failed to 
get ES Private IP Address: {e}') - error_message = f'Exception getting private IP addresses for ES soca-{domain_name}' - http_client.send_cfn_response(CfnResponse( - context=context, - event=event, - status=CfnResponseStatus.FAILED, - data={ - 'error': error_message - }, - physical_resource_id=PHYSICAL_RESOURCE_ID - )) - finally: - http_client.destroy() diff --git a/source/idea/idea-administrator/resources/policies/add-to-user-pool-client-scopes.yml b/source/idea/idea-administrator/resources/policies/add-to-user-pool-client-scopes.yml new file mode 100644 index 0000000..71ab050 --- /dev/null +++ b/source/idea/idea-administrator/resources/policies/add-to-user-pool-client-scopes.yml @@ -0,0 +1,34 @@ +Version: '2012-10-17' +Statement: + - Action: + - logs:CreateLogGroup + - logs:CreateLogStream + - logs:DeleteLogStream + - logs:PutLogEvents + Resource: {{ context.arns.get_lambda_log_group_arn() }} + Effect: Allow + Sid: CloudWatchLogsPermissions + + - Action: + - cognito-idp:DescribeUserPoolClient + - cognito-idp:UpdateUserPoolClient + Resource: '*' + Effect: Allow + + - Action: + - secretsmanager:GetSecretValue + Condition: + StringEquals: + secretsmanager:ResourceTag/res:EnvironmentName: '{{ context.cluster_name }}' + secretsmanager:ResourceTag/res:ModuleId: '{{ context.config.get_module_id("cluster-manager") }}' + Resource: '*' + Effect: Allow + + {%- if context.config.get_string('cluster.secretsmanager.kms_key_id') %} + - Action: + - kms:GenerateDataKey + - kms:Decrypt + Resource: + - '{{ context.arns.kms_key_arn }}' + Effect: Allow + {%- endif %} diff --git a/source/idea/idea-administrator/resources/policies/analytics-sink-lambda.yml b/source/idea/idea-administrator/resources/policies/analytics-sink-lambda.yml deleted file mode 100644 index 4b42e0a..0000000 --- a/source/idea/idea-administrator/resources/policies/analytics-sink-lambda.yml +++ /dev/null @@ -1,53 +0,0 @@ -Version: '2012-10-17' -Statement: - - Action: - - kinesis:DescribeStream - - 
kinesis:DescribeStreamSummary - - kinesis:GetRecords - - kinesis:GetShardIterator - - kinesis:ListShards - - kinesis:ListStreams - - kinesis:SubscribeToShard - Resource: - - {{ context.arns.get_kinesis_arn() }} - Effect: Allow - - {%- if context.config.get_string('analytics.kinesis.kms_key_id') %} - - Action: - - kms:GenerateDataKey - - kms:Decrypt - Resource: - - '{{ context.arns.kms_kinesis_key_arn }}' - Effect: Allow - {%- endif %} - - - Action: - - logs:CreateLogGroup - - logs:CreateLogStream - - logs:PutLogEvents - Resource: '*' - Effect: Allow - - - Effect: Allow - Action: - - es:ESHttpPost - - es:ESHttpPut - Resource: '*' - - {%- if context.config.get_string('analytics.opensearch.kms_key_id') %} - - Action: - - kms:GenerateDataKey - - kms:Decrypt - Resource: - - '{{ context.arns.kms_opensearch_key_arn }}' - Effect: Allow - {%- endif %} - - - Effect: Allow - Action: - - ec2:CreateNetworkInterface - - ec2:DescribeNetworkInterfaces - - ec2:DeleteNetworkInterface - - ec2:AssignPrivateIpAddresses - - ec2:UnassignPrivateIpAddresses - Resource: '*' diff --git a/source/idea/idea-administrator/resources/policies/cluster-manager.yml b/source/idea/idea-administrator/resources/policies/cluster-manager.yml index be1732f..de7bc91 100644 --- a/source/idea/idea-administrator/resources/policies/cluster-manager.yml +++ b/source/idea/idea-administrator/resources/policies/cluster-manager.yml @@ -100,6 +100,16 @@ Statement: - '{{ context.arns.get_ddb_table_arn("ad-automation") }}' - '{{ context.arns.get_ddb_table_arn(context.module_id + ".distributed-lock") }}' - '{{ context.arns.get_ddb_table_arn("snapshots") }}' + - '{{ context.arns.get_ddb_table_arn("apply-snapshot") }}' + Effect: Allow + + - Action: + - dynamodb:ImportTable + - dynamodb:DescribeImport + - dynamodb:DeleteTable + - dynamodb:Scan + Resource: + - '{{ context.arns.get_ddb_table_arn("temp-*") }}' Effect: Allow - Action: diff --git a/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-broker.yml 
b/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-broker.yml index 13a5bac..816e358 100644 --- a/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-broker.yml +++ b/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-broker.yml @@ -66,7 +66,6 @@ Statement: - tag:GetResources - tag:GetTagValues - tag:GetTagKeys - - iam:PassRole - ssm:ListDocuments - ssm:ListDocumentVersions - ssm:DescribeDocument @@ -85,5 +84,10 @@ Statement: - logs:PutRetentionPolicy Resource: '*' Effect: Allow + - Action: + - iam:PassRole + Resource: + - '{{ context.arns.get_iam_arn("vdc-broker-role")}}' + Effect: Allow {% include '_templates/custom-kms-key.yml' %} diff --git a/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-connection-gateway.yml b/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-connection-gateway.yml index a0ce2ae..d750a2c 100644 --- a/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-connection-gateway.yml +++ b/source/idea/idea-administrator/resources/policies/virtual-desktop-dcv-connection-gateway.yml @@ -18,10 +18,12 @@ Statement: - tag:GetResources - tag:GetTagValues - tag:GetTagKeys - - iam:PassRole Resource: '*' Effect: Allow - + - Action: + - iam:PassRole + Resource: '{{ context.arns.get_iam_arn("vdc-gateway-role")}}' + Effect: Allow - Action: - s3:GetObject - s3:ListBucket diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/aws_service_availability_helper.py b/source/idea/idea-administrator/src/ideaadministrator/app/aws_service_availability_helper.py index 0a21671..575048b 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/aws_service_availability_helper.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/aws_service_availability_helper.py @@ -100,10 +100,6 @@ def __init__(self, aws_region: str = None, aws_profile: str = None, aws_secondar 'title': 'Amazon Elastic Load Balancing (ELB)', 'required': 
True }, - 'es': { - 'title': 'Amazon OpenSearch Service', - 'required': True - }, 'eventbridge': { 'title': 'Amazon EventBridge', 'required': True diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_app.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_app.py index f674552..ce53ff8 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_app.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_app.py @@ -20,7 +20,6 @@ ClusterManagerStack, SchedulerStack, BastionHostStack, - AnalyticsStack, VirtualDesktopControllerStack, MetricsStack ) @@ -160,18 +159,6 @@ def bastion_host_stack(self): env=self.cdk_env ) - def analytics_stack(self): - AnalyticsStack( - scope=self.cdk_app, - cluster_name=self.cluster_name, - aws_region=self.aws_region, - aws_profile=self.aws_profile, - module_id=self.module_id, - deployment_id=self.deployment_id, - termination_protection=self.termination_protection, - env=self.cdk_env - ) - def virtual_desktop_controller_stack(self): VirtualDesktopControllerStack( scope=self.cdk_app, @@ -213,8 +200,6 @@ def build_stack(self): self.scheduler_stack() elif self.module_name == constants.MODULE_BASTION_HOST: self.bastion_host_stack() - elif self.module_name == constants.MODULE_ANALYTICS: - self.analytics_stack() elif self.module_name == constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER: self.virtual_desktop_controller_stack() elif self.module_name == constants.MODULE_METRICS: diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_invoker.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_invoker.py index 4075e36..92c0656 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_invoker.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/cdk_invoker.py @@ -165,7 +165,6 @@ def __init__(self, constants.MODULE_SCHEDULER: self.invoke_scheduler, constants.MODULE_BASTION_HOST: self.invoke_bastion_host, 
constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER: self.invoke_virtual_desktop_controller, - constants.MODULE_ANALYTICS: self.invoke_analytics, constants.MODULE_METRICS: self.invoke_metrics } @@ -511,16 +510,6 @@ def invoke_shared_storage(self, **_): ]) self.exec_shell(cdk_cmd) - def invoke_analytics(self, **_): - outputs_file = os.path.join(self.deployment_dir, 'analytics-outputs.json') - cdk_app_cmd = self.get_cdk_app_cmd() - cdk_cmd = self.get_cdk_command('deploy', [ - f"--app '{cdk_app_cmd}' ", - f'--outputs-file {outputs_file} ', - '--require-approval never' - ]) - self.exec_shell(cdk_cmd) - def invoke_metrics(self, **_): outputs_file = os.path.join(self.deployment_dir, 'metrics-outputs.json') cdk_app_cmd = self.get_cdk_app_cmd() diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/__init__.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/__init__.py index b3e1fc3..4d54d62 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/__init__.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/__init__.py @@ -17,4 +17,3 @@ from .dns import * from .directory_service import * from .storage import * -from .analytics import OpenSearch diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/analytics.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/analytics.py deleted file mode 100644 index 42f2295..0000000 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/analytics.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. 
This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. - -__all__ = ('OpenSearch') # noqa - -from ideaadministrator.app.cdk.constructs import SocaBaseConstruct, ExistingSocaCluster, IdeaNagSuppression -from ideasdk.context import ArnBuilder -from ideaadministrator.app_context import AdministratorContext -from ideasdk.utils import Utils - -from typing import List, Dict, Optional - -import aws_cdk as cdk -import constructs -from aws_cdk import ( - aws_ec2 as ec2, - aws_iam as iam, - aws_opensearchservice as opensearch, - aws_kms as kms -) - - -class OpenSearch(SocaBaseConstruct, opensearch.Domain): - - def __init__( - self, context: AdministratorContext, name: str, scope: constructs.Construct, - cluster: ExistingSocaCluster, - security_groups: List[ec2.ISecurityGroup], - data_nodes: int, - data_node_instance_type: str, - ebs_volume_size: int, - removal_policy: cdk.RemovalPolicy, - version: Optional[opensearch.EngineVersion] = None, - create_service_linked_role: bool = True, - access_policies: Optional[List[iam.PolicyStatement]] = None, - advanced_options: Optional[Dict[str, str]] = None, - automated_snapshot_start_hour: Optional[int] = None, - capacity: Optional[opensearch.CapacityConfig] = None, - cognito_dashboards_auth: Optional[opensearch.CognitoOptions] = None, - custom_endpoint: Optional[opensearch.CustomEndpointOptions] = None, - domain_name: Optional[str] = None, - ebs: Optional[opensearch.EbsOptions] = None, - enable_version_upgrade: Optional[bool] = None, - encryption_at_rest: Optional[opensearch.EncryptionAtRestOptions] = None, - kms_key_arn: Optional[kms.IKey] = None, - enforce_https: Optional[bool] = None, - fine_grained_access_control: Optional[opensearch.AdvancedSecurityOptions] = None, - logging: Optional[opensearch.LoggingOptions] = None, - node_to_node_encryption: Optional[bool] = None, - 
tls_security_policy: Optional[opensearch.TLSSecurityPolicy] = None, - use_unsigned_basic_auth: Optional[bool] = None, - vpc_subnets: Optional[List[ec2.SubnetSelection]] = None, - zone_awareness: Optional[opensearch.ZoneAwarenessConfig] = None - ): - - self.context = context - - if version is None: - version = opensearch.EngineVersion.OPENSEARCH_2_3 - - if vpc_subnets is None: - vpc_subnets = [ec2.SubnetSelection( - subnets=cluster.private_subnets[0:data_nodes] - )] - - if zone_awareness is None: - if data_nodes > 1: - zone_awareness = opensearch.ZoneAwarenessConfig( - enabled=True, - availability_zone_count=min(3, data_nodes) - ) - else: - zone_awareness = opensearch.ZoneAwarenessConfig( - enabled=False - ) - - if domain_name is None: - domain_name = self.build_resource_name(name).lower() - - if enforce_https is None: - enforce_https = True - - if encryption_at_rest is None: - encryption_at_rest = opensearch.EncryptionAtRestOptions( - enabled=True, - kms_key=kms_key_arn - ) - - if ebs is None: - ebs = opensearch.EbsOptions( - volume_size=ebs_volume_size, - volume_type=ec2.EbsDeviceVolumeType.GP3 - ) - - if capacity is None: - capacity = opensearch.CapacityConfig( - data_node_instance_type=data_node_instance_type, - data_nodes=data_nodes - ) - - if automated_snapshot_start_hour is None: - automated_snapshot_start_hour = 0 - - if removal_policy is None: - removal_policy = cdk.RemovalPolicy.DESTROY - - if access_policies is None: - arn_builder = ArnBuilder(self.context.config()) - access_policies = [ - iam.PolicyStatement( - principals=[iam.AnyPrincipal()], - actions=['es:ESHttp*'], - resources=[ - arn_builder.get_arn( - 'es', - f'domain/{domain_name}/*' - ) - ] - ) - ] - - if advanced_options is None: - advanced_options = { - 'rest.action.multi.allow_explicit_index': 'true' - } - - super().__init__( - context, name, scope, - version=version, - access_policies=access_policies, - advanced_options=advanced_options, - 
automated_snapshot_start_hour=automated_snapshot_start_hour, - capacity=capacity, - cognito_dashboards_auth=cognito_dashboards_auth, - custom_endpoint=custom_endpoint, - domain_name=domain_name, - ebs=ebs, - enable_version_upgrade=enable_version_upgrade, - encryption_at_rest=encryption_at_rest, - enforce_https=enforce_https, - fine_grained_access_control=fine_grained_access_control, - logging=logging, - node_to_node_encryption=node_to_node_encryption, - removal_policy=removal_policy, - security_groups=security_groups, - tls_security_policy=tls_security_policy, - use_unsigned_basic_auth=use_unsigned_basic_auth, - vpc=cluster.vpc, - vpc_subnets=vpc_subnets, - zone_awareness=zone_awareness) - - if create_service_linked_role: - - aws_service_name = self.context.config().get_string('global-settings.opensearch.aws_service_name') - if Utils.is_empty(aws_service_name): - dns_suffix = self.context.config().get_string('cluster.aws.dns_suffix', required=True) - aws_service_name = f'es.{dns_suffix}' - - # DO NOT CHANGE THE DESCRIPTION OF THE ROLE. 
- service_linked_role = iam.CfnServiceLinkedRole( - self, - self.build_resource_name('es-service-linked-role'), - aws_service_name=aws_service_name, - description='Role for ES to access resources in the VPC' - ) - self.node.add_dependency(service_linked_role) - - self.add_nag_suppression(suppressions=[ - IdeaNagSuppression(rule_id='AwsSolutions-OS3', reason='Access to OpenSearch cluster is restricted within a VPC'), - IdeaNagSuppression(rule_id='AwsSolutions-OS4', reason='Use existing resources flow to provision an even more scalable OpenSearch cluster with dedicated master nodes'), - IdeaNagSuppression(rule_id='AwsSolutions-OS5', reason='Access to OpenSearch cluster is restricted within a VPC') - ]) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/common.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/common.py index e86e15a..d6cc730 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/common.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/common.py @@ -23,7 +23,6 @@ 'CloudWatchAlarm', 'Output', 'DynamoDBTable', - 'KinesisStream' ) from aws_cdk.aws_ec2 import IVpc, SubnetSelection, ISecurityGroup @@ -594,22 +593,3 @@ def __init__(self, context: AdministratorContext, name: str, scope: constructs.C write_capacity=write_capacity, partition_key=partition_key, sort_key=sort_key) - - -class KinesisStream(SocaBaseConstruct, kinesis.Stream): - def __init__(self, context: AdministratorContext, name: str, scope: constructs.Construct, stream_name: str, stream_mode: kinesis.StreamMode, shard_count: Optional[int]): - self.context = context - kms_key_id = self.context.config().get_string('analytics.kinesis.kms_key_id') - - if kms_key_id is not None: - kms_key_arn = self.get_kms_key_arn(kms_key_id) - kinesis_encryption_key = kms.Key.from_key_arn(scope=scope, id=f'kinesis-kms-key', key_arn=kms_key_arn) - else: - kinesis_encryption_key = 
kms.Alias.from_alias_name(scope=scope, id=f'kinesis-kms-key-default', alias_name='alias/aws/kinesis') - - super().__init__(context, name, scope, - stream_name=f'{context.cluster_name()}-{stream_name}', - stream_mode=stream_mode, - encryption=kinesis.StreamEncryption.KMS, - encryption_key=kinesis_encryption_key, - shard_count=shard_count) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/directory_service.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/directory_service.py index d771a05..b82ede2 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/directory_service.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/directory_service.py @@ -418,26 +418,6 @@ def build_user_pool(self, props: cognito.UserPoolProps): IdeaNagSuppression(rule_id='AwsSolutions-COG3', reason='suppress advanced security rule 1/to save cost, 2/Not supported in GovCloud') ]) - group_name_helper = GroupNameHelper(self.context) - - cognito.CfnUserPoolGroup( - scope=self.scope, - id=f'{user_pool_name}-administrators-group', - description='Administrators group (Sudo Users)', - group_name=group_name_helper.get_cluster_administrators_group(), - precedence=1, - user_pool_id=self.user_pool.user_pool_id - ) - - cognito.CfnUserPoolGroup( - scope=self.scope, - id=f'{user_pool_name}-managers-group', - description='Managers group with limited administration access.', - group_name=group_name_helper.get_cluster_managers_group(), - precedence=2, - user_pool_id=self.user_pool.user_pool_id - ) - domain_url = self.context.config().get_string('identity-provider.cognito.domain_url') if Utils.is_not_empty(domain_url): domain_prefix = domain_url.replace('https://', '').split('.')[0] diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/existing_resources.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/existing_resources.py index 
734f44d..846dc90 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/existing_resources.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/existing_resources.py @@ -11,7 +11,8 @@ __all__ = ( 'ExistingVpc', - 'ExistingSocaCluster' + 'ExistingSocaCluster', + 'SubnetFilterKeys' ) from ideasdk.utils import Utils @@ -21,6 +22,7 @@ from ideaadministrator.app.cdk.constructs import ( SocaBaseConstruct ) +from enum import Enum from typing import List, Optional, Dict @@ -30,6 +32,11 @@ aws_iam as iam ) +class SubnetFilterKeys(str, Enum): + LOAD_BALANCER_SUBNETS = "load_balancer_subnets" + INFRASTRUCTURE_HOST_SUBNETS = "infrastructure_host_subnets" + CLUSTER = "cluster" + class ExistingVpc(SocaBaseConstruct): """ @@ -43,8 +50,8 @@ def __init__(self, context: AdministratorContext, name: str, scope: constructs.C self.scope = scope self.vpc_id = self.context.config().get_string('cluster.network.vpc_id', required=True) self.vpc = ec2.Vpc.from_lookup(self.scope, 'vpc', vpc_id=self.vpc_id) - self._private_subnets: Optional[List[ec2.ISubnet]] = None - self._public_subnets: Optional[List[ec2.ISubnet]] = None + self._private_subnets: Dict[SubnetFilterKeys, List[ec2.ISubnet]] = {} + self._public_subnets: Dict[SubnetFilterKeys, List[ec2.ISubnet]] = {} def lookup_vpc(self): self.vpc = ec2.Vpc.from_lookup(self.scope, 'vpc', vpc_id=self.vpc_id) @@ -54,19 +61,27 @@ def get_public_subnet_ids(self) -> List[str]: def get_private_subnet_ids(self) -> List[str]: return self.context.config().get_list('cluster.network.private_subnets', []) + + def get_filter_subnets_ids(self, subnet_filter_key: SubnetFilterKeys) -> List[str]: + return self.context.config().get_list(f'cluster.network.{subnet_filter_key}', []) - def get_public_subnets(self) -> List[ec2.ISubnet]: + def get_public_subnets(self, subnet_filter_key: Optional[SubnetFilterKeys] = None ) -> List[ec2.ISubnet]: """ + Filter subnets by providing a subnet filter key. 
For example, this can be used to get public subnets specific to the external load balancers or filter subnets from Vpc based on subnet ids configured in `cluster.network.public_subnets` the result is sorted based on the order of subnet ids provided in the configuration. """ - if self._public_subnets is not None: - return self._public_subnets - public_subnet_ids = self.get_public_subnet_ids() + # default is to get external load balancer public subnets + subnet_filter_key = subnet_filter_key if subnet_filter_key else SubnetFilterKeys.LOAD_BALANCER_SUBNETS + result = self._public_subnets.get(subnet_filter_key) + if result is not None: + return result + + public_subnet_ids = self.get_public_subnet_ids() if subnet_filter_key == SubnetFilterKeys.CLUSTER else self.get_filter_subnets_ids(subnet_filter_key) if Utils.is_empty(public_subnet_ids): - self._public_subnets = [] - return self._public_subnets + self._public_subnets[subnet_filter_key] = [] + return self._public_subnets[subnet_filter_key] result = [] if self.vpc.public_subnets is not None: @@ -77,11 +92,12 @@ def get_public_subnets(self) -> List[ec2.ISubnet]: # sort based on index in public_subnets[] configuration result.sort(key=lambda x: public_subnet_ids.index(x.subnet_id)) - self._public_subnets = result - return self._public_subnets + self._public_subnets[subnet_filter_key] = result + return self._public_subnets[subnet_filter_key] - def get_private_subnets(self) -> List[ec2.ISubnet]: + def get_private_subnets(self, subnet_filter_key: Optional[SubnetFilterKeys] = None) -> List[ec2.ISubnet]: """ + Filter subnets by providing a subnet filter key. For example, this can be used to get private subnets specific to the infrastructure hosts or filter subnets from Vpc based on subnet ids configured in `cluster.network.private_subnets` the result is sorted based on the order of subnet ids provided in the configuration.
@@ -96,13 +112,17 @@ def get_private_subnets(self) -> List[ec2.ISubnet]: * After lookup, ec2.IVpc buckets the subnets under public_subnets, private_subnets and isolated_subnets. * To ensure all configured subnets are selected, both vpc.private_subnets and vpc.isolated_subnets are checked to resolve ec2.ISubnet """ - if self._private_subnets is not None: - return self._private_subnets - private_subnet_ids = self.get_private_subnet_ids() + # default is to get infrastructure host subnets + subnet_filter_key = subnet_filter_key if subnet_filter_key else SubnetFilterKeys.INFRASTRUCTURE_HOST_SUBNETS + result = self._private_subnets.get(subnet_filter_key) + if result is not None: + return result + + private_subnet_ids = self.get_private_subnet_ids() if subnet_filter_key == SubnetFilterKeys.CLUSTER else self.get_filter_subnets_ids(subnet_filter_key) if Utils.is_empty(private_subnet_ids): - self._private_subnets = [] - return self._private_subnets + self._private_subnets[subnet_filter_key] = [] + return self._private_subnets[subnet_filter_key] result = [] if self.vpc.private_subnets is not None: @@ -116,9 +136,8 @@ def get_private_subnets(self) -> List[ec2.ISubnet]: # sort based on index in private_subnets[] configuration result.sort(key=lambda x: private_subnet_ids.index(x.subnet_id)) - self._private_subnets = result - return self._private_subnets - + self._private_subnets[subnet_filter_key] = result + return self._private_subnets[subnet_filter_key] class ExistingSocaCluster(SocaBaseConstruct): diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/network.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/network.py index 13b9772..b3116a0 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/network.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/constructs/network.py @@ -24,7 +24,6 @@ 'VpcEndpointSecurityGroup', 'VpcGatewayEndpoint', 'VpcInterfaceEndpoint', - 
'OpenSearchSecurityGroup', 'DefaultClusterSecurityGroup', 'VirtualDesktopPublicLoadBalancerAccessSecurityGroup', 'VirtualDesktopBastionAccessSecurityGroup' @@ -635,25 +634,6 @@ def get_endpoint_url(self) -> str: return f'https://{dns}' -class OpenSearchSecurityGroup(SecurityGroup): - - def __init__(self, context: AdministratorContext, name: str, scope: constructs.Construct, - vpc: ec2.IVpc): - super().__init__(context, name, scope, vpc, description='OpenSearch security group') - self.setup_ingress() - self.setup_egress() - - def setup_ingress(self): - self.add_ingress_rule( - ec2.Peer.ipv4(self.vpc.vpc_cidr_block), - ec2.Port.tcp(443), - description='Allow HTTPS traffic from all VPC nodes to OpenSearch' - ) - - def setup_egress(self): - self.add_outbound_traffic_rule() - - class DefaultClusterSecurityGroup(SecurityGroup): """ Default Cluster Security Group with no inbound or outbound rules. @@ -679,7 +659,7 @@ def setup_ingress(self): self.add_ingress_rule( ec2.Peer.ipv4(self.vpc.vpc_cidr_block), ec2.Port.tcp(443), - description='Allow HTTPS traffic from all VPC nodes to OpenSearch' + description='Allow HTTPS traffic from all VPC nodes' ) def setup_egress(self): diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/__init__.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/__init__.py index e455fa7..d499f7f 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/__init__.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/__init__.py @@ -19,6 +19,5 @@ from .cluster_manager_stack import ClusterManagerStack from .scheduler_stack import SchedulerStack from .bastion_host_stack import BastionHostStack -from .analytics_stack import AnalyticsStack from .virtual_desktop_controller_stack import VirtualDesktopControllerStack from .metrics_stack import MetricsStack diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/analytics_stack.py 
b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/analytics_stack.py deleted file mode 100644 index 28ab322..0000000 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/analytics_stack.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. -from ideaadministrator.app.cdk.idea_code_asset import IdeaCodeAsset, SupportedLambdaPlatforms -from ideadatamodel import ( - constants -) - -import ideaadministrator -from ideaadministrator.app.cdk.stacks import IdeaBaseStack - -from ideaadministrator.app.cdk.constructs import ( - ExistingSocaCluster, - KinesisStream, - IdeaNagSuppression, - Role, - Policy, - LambdaFunction, - OpenSearchSecurityGroup, - OpenSearch, - CustomResource -) -from ideadatamodel import exceptions - -from ideasdk.utils import Utils - -from typing import Optional -import aws_cdk as cdk -import constructs -from aws_cdk import ( - aws_ec2 as ec2, - aws_opensearchservice as opensearch, - aws_elasticloadbalancingv2 as elbv2, - aws_kinesis as kinesis, - aws_lambda as lambda_, - aws_kms as kms, - aws_lambda_event_sources as lambda_event_sources, - aws_logs as logs -) - - -class AnalyticsStack(IdeaBaseStack): - - def __init__(self, scope: constructs.Construct, - cluster_name: str, - aws_region: str, - aws_profile: str, - module_id: str, - deployment_id: str, - termination_protection: bool = True, - env: cdk.Environment = None): - - super().__init__( - scope=scope, - 
cluster_name=cluster_name, - aws_region=aws_region, - aws_profile=aws_profile, - module_id=module_id, - deployment_id=deployment_id, - termination_protection=termination_protection, - description=f'ModuleId: {module_id}, Cluster: {cluster_name}, Version: {ideaadministrator.props.current_release_version}', - tags={ - constants.IDEA_TAG_MODULE_ID: module_id, - constants.IDEA_TAG_MODULE_NAME: constants.MODULE_ANALYTICS, - constants.IDEA_TAG_MODULE_VERSION: ideaadministrator.props.current_release_version - }, - env=env - ) - - self.cluster = ExistingSocaCluster(self.context, self.stack) - - self.security_group: Optional[ec2.ISecurityGroup] = None - self.opensearch: Optional[opensearch.Domain] = None - self.kinesis_stream: Optional[KinesisStream] = None - - self.build_security_group() - is_existing = self.context.config().get_bool('analytics.opensearch.use_existing', default=False) - if is_existing: - domain_vpc_endpoint_url = self.context.config().get_string('analytics.opensearch.domain_vpc_endpoint_url', required=True) - self.opensearch = opensearch.Domain.from_domain_endpoint( - scope=self.stack, - id='existing-opensearch', - domain_endpoint=f'https://{domain_vpc_endpoint_url}' - ) - self.build_dashboard_endpoints() - else: - self.build_opensearch() - self.build_dashboard_endpoints() - - self.add_nag_suppression( - construct=self.stack, - suppressions=[ - IdeaNagSuppression(rule_id='AwsSolutions-IAM5', reason='CDK L2 construct does not support custom LogGroup permissions'), - IdeaNagSuppression(rule_id='AwsSolutions-IAM4', reason='Usage is required for Service Linked Role'), - IdeaNagSuppression(rule_id='AwsSolutions-L1', reason='CDK L2 construct does not offer options to customize the Lambda runtime'), - IdeaNagSuppression(rule_id='AwsSolutions-KDS3', reason='Kinesis Data Stream is encrypted with customer-managed KMS key') - ] - ) - data_nodes = self.context.config().get_int('analytics.opensearch.data_nodes', required=True) - if data_nodes == 1: - 
self.add_nag_suppression( - construct=self.stack, - suppressions=[ - IdeaNagSuppression(rule_id='AwsSolutions-OS7', reason='OpenSearch domain has 1 data node disabling Zone Awareness') - ] - ) - - self.build_analytics_input_stream() - self.build_cluster_settings() - - def build_analytics_input_stream(self): - - stream_config = self.context.config().get_string('analytics.kinesis.stream_mode', required=True) - if stream_config not in {'PROVISIONED', 'ON_DEMAND'}: - raise exceptions.invalid_params('analytics.kinesis.stream_mode needs to be one of PROVISIONED or ON_DEMAND only') - - if stream_config == 'PROVISIONED': - stream_mode = kinesis.StreamMode.PROVISIONED - shard_count = self.context.config().get_int('analytics.kinesis.shard_count', required=True) - else: - stream_mode = kinesis.StreamMode.ON_DEMAND - shard_count = None - - self.kinesis_stream = KinesisStream( - context=self.context, - name=f'{self.module_id}-kinesis-stream', - scope=self.stack, - stream_name=f'{self.module_id}-kinesis-stream', - stream_mode=stream_mode, - shard_count=shard_count - ) - if self.aws_region in Utils.get_value_as_list('KINESIS_STREAMS_CLOUDFORMATION_UNSUPPORTED_STREAMMODEDETAILS_REGION_LIST', constants.CAVEATS, []): - self.kinesis_stream.node.default_child.add_property_deletion_override('StreamModeDetails') - - lambda_name = f'{self.module_id}-sink-lambda' - stream_processing_lambda_role = Role( - context=self.context, - name=f'{lambda_name}-role', - scope=self.stack, - description=f'Role for {lambda_name} function for Cluster: {self.cluster_name}', - assumed_by=['lambda']) - - stream_processing_lambda_role.attach_inline_policy(Policy( - context=self.context, - name=f'{lambda_name}-policy', - scope=self.stack, - policy_template_name='analytics-sink-lambda.yml' - )) - - stream_processing_lambda = LambdaFunction( - self.context, - lambda_name, - self.stack, - idea_code_asset=IdeaCodeAsset( - lambda_package_name='idea_analytics_sink', - lambda_platform=SupportedLambdaPlatforms.PYTHON 
- ), - description='Lambda to process analytics-kinesis-stream data', - timeout_seconds=900, - security_groups=[self.security_group], - role=stream_processing_lambda_role, - environment={ - 'opensearch_endpoint': self.opensearch.domain_endpoint - }, - vpc=self.cluster.vpc, - vpc_subnets=ec2.SubnetSelection( - subnets=self.cluster.private_subnets - ) - ) - - stream_processing_lambda.add_event_source(lambda_event_sources.KinesisEventSource( - self.kinesis_stream, - batch_size=100, - starting_position=lambda_.StartingPosition.LATEST - )) - - def build_security_group(self): - self.security_group = OpenSearchSecurityGroup( - context=self.context, - name=f'{self.module_id}-opensearch-security-group', - scope=self.stack, - vpc=self.cluster.vpc - ) - - def check_service_linked_role_exists(self) -> bool: - try: - aws_dns_suffix = self.context.config().get_string('cluster.aws.dns_suffix', required=True) - list_roles_result = self.context.aws().iam().list_roles( - PathPrefix=f'/aws-service-role/es.{aws_dns_suffix}') - roles = Utils.get_value_as_list('Roles', list_roles_result, default=[]) - - list_roles_result = self.context.aws().iam().list_roles( - PathPrefix=f'/aws-service-role/opensearchservice.{aws_dns_suffix}') - roles.extend(Utils.get_value_as_list('Roles', list_roles_result, default=[])) - - return Utils.is_not_empty(roles) - except Exception as e: - self.context.aws_util().handle_aws_exception(e) - - def build_opensearch(self): - create_service_linked_role = not self.check_service_linked_role_exists() - - data_nodes = self.context.config().get_int('analytics.opensearch.data_nodes', required=True) - data_node_instance_type = self.context.config().get_string('analytics.opensearch.data_node_instance_type', required=True) - ebs_volume_size = self.context.config().get_int('analytics.opensearch.ebs_volume_size', required=True) - node_to_node_encryption = self.context.config().get_bool('analytics.opensearch.node_to_node_encryption', required=True) - removal_policy = 
self.context.config().get_string('analytics.opensearch.removal_policy', required=True) - app_log_removal_policy = self.context.config().get_string('analytics.opensearch.logging.app_log_removal_policy', default='DESTROY') - search_log_removal_policy = self.context.config().get_string('analytics.opensearch.logging.search_log_removal_policy', default='DESTROY') - slow_index_log_removal_policy = self.context.config().get_string('analytics.opensearch.logging.slow_index_log_removal_policy', default='DESTROY') - kms_key_id = self.context.config().get_string('analytics.opensearch.kms_key_id') - kms_key_arn = None - if kms_key_id is not None: - kms_key_arn = kms.Key.from_key_arn(self.stack, 'opensearch-kms-key', self.get_kms_key_arn(key_id=kms_key_id)) - - self.opensearch = OpenSearch( - context=self.context, - name='analytics', - scope=self.stack, - cluster=self.cluster, - security_groups=[self.security_group], - data_nodes=data_nodes, - data_node_instance_type=data_node_instance_type, - ebs_volume_size=ebs_volume_size, - removal_policy=cdk.RemovalPolicy(removal_policy), - node_to_node_encryption=node_to_node_encryption, - kms_key_arn=kms_key_arn, - create_service_linked_role=create_service_linked_role, - logging=opensearch.LoggingOptions( - slow_search_log_enabled=self.context.config().get_bool('analytics.opensearch.logging.slow_search_log_enabled', required=True), - slow_search_log_group=logs.LogGroup( - scope=self.stack, - id='analytics-search-log-group', - log_group_name=f'/{self.cluster_name}/{self.module_id}/search-log', - removal_policy=cdk.RemovalPolicy(search_log_removal_policy) - ), - app_log_enabled=self.context.config().get_bool('analytics.opensearch.logging.app_log_enabled', required=True), - app_log_group=logs.LogGroup( - scope=self.stack, - id='analytics-app-log-group', - log_group_name=f'/{self.cluster_name}/{self.module_id}/app-log', - removal_policy=cdk.RemovalPolicy(app_log_removal_policy) - ), - 
slow_index_log_enabled=self.context.config().get_bool('analytics.opensearch.logging.slow_index_log_enabled', required=True), - slow_index_log_group=logs.LogGroup( - scope=self.stack, - id='analytics-slow-index-log-group', - log_group_name=f'/{self.cluster_name}/{self.module_id}/slow-index-log', - removal_policy=cdk.RemovalPolicy(slow_index_log_removal_policy) - ), - # Audit logs are not enabled by default and not supported as this setting require fine-grained access permissions. - # Manually provision OpenSearch cluster to enable audit logs and use existing resources flow to use the OpenSearch cluster - audit_log_enabled=False - )) - - def build_dashboard_endpoints(self): - - cluster_endpoints_lambda_arn = self.context.config().get_string('cluster.cluster_endpoints_lambda_arn', required=True) - external_https_listener_arn = self.context.config().get_string('cluster.load_balancers.external_alb.https_listener_arn', required=True) - dashboard_endpoint_path_patterns = self.context.config().get_list('analytics.opensearch.endpoints.external.path_patterns', required=True) - dashboard_endpoint_priority = self.context.config().get_int('analytics.opensearch.endpoints.external.priority', required=True) - - is_existing = self.context.config().get_bool('analytics.opensearch.use_existing', default=False) - if is_existing: - domain_name = self.opensearch.domain_name - if domain_name.startswith('vpc-'): - # existing lookup returns the domain name as vpc-idea-dev1-analytics - # when used in describe_domain, service returns error - Domain not found: vpc-idea-dev1-analytics - # hack - fix to replace vpc- and then perform look up - domain_name = domain_name.replace('vpc-', '', 1) - describe_domain_result = self.context.aws().opensearch().describe_domain(DomainName=domain_name) - domain_status = describe_domain_result['DomainStatus'] - domain_cluster_config = domain_status['ClusterConfig'] - data_nodes = domain_cluster_config['InstanceCount'] - else: - domain_name = 
self.opensearch.domain_name - data_nodes = self.context.config().get_int('analytics.opensearch.data_nodes', required=True) - - opensearch_private_ips = CustomResource( - context=self.context, - name='opensearch-private-ips', - scope=self.stack, - idea_code_asset=IdeaCodeAsset( - lambda_package_name='idea_custom_resource_opensearch_private_ips', - lambda_platform=SupportedLambdaPlatforms.PYTHON - ), - lambda_timeout_seconds=180, - policy_template_name='custom-resource-opensearch-private-ips.yml', - resource_type='OpenSearchPrivateIPAddresses' - ).invoke( - name='opensearch-private-ips', - properties={ - 'DomainName': domain_name - } - ) - - targets = [] - for i in range(data_nodes * 3): - targets.append(elbv2.CfnTargetGroup.TargetDescriptionProperty( - id=cdk.Fn.select( - index=i, - array=cdk.Fn.split( - delimiter=',', - source=opensearch_private_ips.get_att_string('IpAddresses') - ) - ) - )) - - dashboard_target_group = elbv2.CfnTargetGroup( - self.stack, - f'{self.cluster_name}-dashboard-target-group', - port=443, - protocol='HTTPS', - target_type='ip', - vpc_id=self.cluster.vpc.vpc_id, - name=self.get_target_group_name('dashboard'), - targets=targets, - health_check_path='/' - ) - - cdk.CustomResource( - self.stack, - 'dashboard-endpoint', - service_token=cluster_endpoints_lambda_arn, - properties={ - 'endpoint_name': f'{self.module_id}-dashboard-endpoint', - 'listener_arn': external_https_listener_arn, - 'priority': dashboard_endpoint_priority, - 'target_group_arn': dashboard_target_group.ref, - 'conditions': [ - { - 'Field': 'path-pattern', - 'Values': dashboard_endpoint_path_patterns - } - ], - 'actions': [ - { - 'Type': 'forward', - 'TargetGroupArn': dashboard_target_group.ref - } - ], - 'tags': { - constants.IDEA_TAG_ENVIRONMENT_NAME: self.cluster_name, - constants.IDEA_TAG_MODULE_ID: self.module_id, - constants.IDEA_TAG_MODULE_NAME: constants.MODULE_ANALYTICS - } - }, - resource_type='Custom::DashboardEndpointExternal' - ) - - def 
build_cluster_settings(self): - - cluster_settings = { - 'deployment_id': self.deployment_id, - 'opensearch.domain_name': self.opensearch.domain_name, - 'opensearch.domain_arn': self.opensearch.domain_arn, - 'opensearch.domain_endpoint': self.opensearch.domain_endpoint, - 'opensearch.dashboard_endpoint': f'{self.opensearch.domain_endpoint}/_dashboards', - 'kinesis.stream_name': self.kinesis_stream.stream_name, - 'kinesis.stream_arn': self.kinesis_stream.stream_arn - } - - if self.security_group is not None: - cluster_settings['opensearch.security_group_id'] = self.security_group.security_group_id - - self.update_cluster_settings(cluster_settings) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/base_stack.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/base_stack.py index 4a2471a..92c9133 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/base_stack.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/base_stack.py @@ -134,24 +134,3 @@ def lookup_user_pool(self) -> cognito.IUserPool: f'{self.cluster_name}-user-pool', self.context.config().get_string('identity-provider.cognito.user_pool_id', required=True) ) - - def build_access_control_groups(self, user_pool: cognito.IUserPool): - group_name_helper = GroupNameHelper(self.context) - # module administrators group - cognito.CfnUserPoolGroup( - scope=self.stack, - id=f'{self.module_id}-administrators-group', - description=f'Module administrators group for module id: {self.module_id}, cluster: {self.cluster_name}', - group_name=group_name_helper.get_module_administrators_group(self.module_id), - precedence=3, - user_pool_id=user_pool.user_pool_id - ) - # module users group - cognito.CfnUserPoolGroup( - scope=self.stack, - id=f'{self.module_id}-users-group', - description=f'Module user group for module id: {self.module_id}, cluster: {self.cluster_name}', - 
group_name=group_name_helper.get_module_users_group(self.module_id), - precedence=4, - user_pool_id=user_pool.user_pool_id - ) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_manager_stack.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_manager_stack.py index 2a2dad9..2518a52 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_manager_stack.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_manager_stack.py @@ -91,7 +91,6 @@ def __init__(self, scope: constructs.Construct, self.user_pool = self.lookup_user_pool() self.build_oauth2_client() - self.build_access_control_groups(user_pool=self.user_pool) self.build_sqs_queues() self.build_iam_roles() self.build_security_groups() @@ -246,6 +245,7 @@ def build_auto_scaling_group(self): min_capacity = self.context.config().get_int('cluster-manager.ec2.autoscaling.min_capacity', default=1) max_capacity = self.context.config().get_int('cluster-manager.ec2.autoscaling.max_capacity', default=3) cooldown_minutes = self.context.config().get_int('cluster-manager.ec2.autoscaling.cooldown_minutes', default=5) + default_instance_warmup = self.context.config().get_int('cluster-manager.ec2.autoscaling.default_instance_warmup', default=15) new_instances_protected_from_scale_in = self.context.config().get_bool('cluster-manager.ec2.autoscaling.new_instances_protected_from_scale_in', default=True) elb_healthcheck_grace_time_minutes = self.context.config().get_int('cluster-manager.ec2.autoscaling.elb_healthcheck.grace_time_minutes', default=15) scaling_policy_target_utilization_percent = self.context.config().get_int('cluster-manager.ec2.autoscaling.cpu_utilization_scaling_policy.target_utilization_percent', default=80) @@ -324,6 +324,7 @@ def build_auto_scaling_group(self): min_capacity=min_capacity, max_capacity=max_capacity, 
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in, + default_instance_warmup=cdk.Duration.minutes(default_instance_warmup), cooldown=cdk.Duration.minutes(cooldown_minutes), health_check=asg.HealthCheck.elb( grace=cdk.Duration.minutes(elb_healthcheck_grace_time_minutes) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_stack.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_stack.py index 64780c5..444157e 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_stack.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/cluster_stack.py @@ -20,6 +20,7 @@ from ideaadministrator.app.cdk.constructs import ( Vpc, ExistingVpc, + SubnetFilterKeys, VpcGatewayEndpoint, VpcInterfaceEndpoint, PrivateHostedZone, @@ -195,15 +196,15 @@ def vpc(self) -> ec2.IVpc: else: return self._vpc - def private_subnets(self) -> List[ec2.ISubnet]: + def private_subnets(self, subnet_filter_key: Optional[SubnetFilterKeys] = None) -> List[ec2.ISubnet]: if self.context.config().get_bool('cluster.network.use_existing_vpc', False): - return self._existing_vpc.get_private_subnets() + return self._existing_vpc.get_private_subnets(subnet_filter_key) else: return self.vpc.private_subnets - def public_subnets(self) -> List[ec2.ISubnet]: + def public_subnets(self, subnet_filter_key: Optional[SubnetFilterKeys] = None) -> List[ec2.ISubnet]: if self.context.config().get_bool('cluster.network.use_existing_vpc', False): - return self._existing_vpc.get_public_subnets() + return self._existing_vpc.get_public_subnets(subnet_filter_key) else: return self.vpc.public_subnets @@ -858,7 +859,7 @@ def build_cluster_endpoints(self): # external ALB - can be deployed in public or private subnets is_public = self.context.config().get_bool('cluster.load_balancers.external_alb.public', default=True) - external_alb_subnets = self.public_subnets() if is_public is True else 
self.private_subnets() + external_alb_subnets = self.public_subnets(SubnetFilterKeys.LOAD_BALANCER_SUBNETS) if is_public is True else self.private_subnets(SubnetFilterKeys.LOAD_BALANCER_SUBNETS) self.external_alb = elbv2.ApplicationLoadBalancer( self.stack, f'{self.cluster_name}-external-alb', @@ -934,7 +935,7 @@ def build_cluster_endpoints(self): self.external_alb, 'https-listener', port=443, - ssl_policy=self.context.config().get_string('cluster.load_balancers.external_alb.ssl_policy', default='ELBSecurityPolicy-FS-1-2-Res-2020-10'), + ssl_policy=self.context.config().get_string('cluster.load_balancers.external_alb.ssl_policy', default='ELBSecurityPolicy-TLS13-1-2-2021-06'), load_balancer_arn=self.external_alb.load_balancer_arn, protocol='HTTPS', certificates=[ @@ -954,7 +955,7 @@ def build_cluster_endpoints(self): self.internal_alb, 'https-listener', port=443, - ssl_policy=self.context.config().get_string('cluster.load_balancers.internal_alb.ssl_policy', default='ELBSecurityPolicy-FS-1-2-Res-2020-10'), + ssl_policy=self.context.config().get_string('cluster.load_balancers.internal_alb.ssl_policy', default='ELBSecurityPolicy-TLS13-1-2-2021-06'), load_balancer_arn=self.internal_alb.load_balancer_arn, protocol='HTTPS', certificates=[ @@ -1005,7 +1006,7 @@ def build_cluster_endpoints(self): self.internal_alb, 'dcv-broker-client-listener', port=dcv_broker_client_communication_port, - ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-FS-1-2-Res-2020-10'), + ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-TLS13-1-2-2021-06'), load_balancer_arn=self.internal_alb.load_balancer_arn, protocol='HTTPS', certificates=[ @@ -1029,7 +1030,7 @@ def build_cluster_endpoints(self): self.internal_alb, 'dcv-broker-agent-listener', port=dcv_broker_agent_communication_port, - 
ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-FS-1-2-Res-2020-10'), + ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-TLS13-1-2-2021-06'), load_balancer_arn=self.internal_alb.load_balancer_arn, protocol='HTTPS', certificates=[ @@ -1053,7 +1054,7 @@ def build_cluster_endpoints(self): self.internal_alb, 'dcv-broker-gateway-listener', port=dcv_broker_gateway_communication_port, - ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-FS-1-2-Res-2020-10'), + ssl_policy=self.context.config().get_string('virtual-desktop-controller.dcv_broker.ssl_policy', default='ELBSecurityPolicy-TLS13-1-2-2021-06'), load_balancer_arn=self.internal_alb.load_balancer_arn, protocol='HTTPS', certificates=[ @@ -1087,6 +1088,14 @@ def build_cluster_settings(self): public_subnets.append(subnet.subnet_id) cluster_settings['network.public_subnets'] = public_subnets + load_balancer_subnets = self.context.config().get_list("cluster.network.load_balancer_subnets", []) + if Utils.is_not_empty(load_balancer_subnets): + cluster_settings['network.load_balancer_subnets'] = load_balancer_subnets + + infrastructure_host_subnets = self.context.config().get_list("cluster.network.infrastructure_host_subnets", []) + if Utils.is_not_empty(infrastructure_host_subnets): + cluster_settings['network.infrastructure_host_subnets'] = infrastructure_host_subnets + private_subnets = self.context.config().get_list('cluster.network.private_subnets', []) if Utils.is_empty(private_subnets): for subnet in self.vpc.private_subnets: diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/scheduler_stack.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/scheduler_stack.py index f54f684..9b6581b 100644 --- 
a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/scheduler_stack.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/scheduler_stack.py @@ -97,7 +97,6 @@ def __init__(self, scope: constructs.Construct, self.user_pool = self.lookup_user_pool() self.build_oauth2_client() - self.build_access_control_groups(user_pool=self.user_pool) self.build_sqs_queue() self.build_iam_roles() self.build_security_groups() diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/virtual_desktop_controller_stack.py b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/virtual_desktop_controller_stack.py index d13c3b9..e2987c4 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/virtual_desktop_controller_stack.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/cdk/stacks/virtual_desktop_controller_stack.py @@ -13,6 +13,7 @@ from ideaadministrator.app.cdk.idea_code_asset import IdeaCodeAsset, SupportedLambdaPlatforms from ideaadministrator.app.cdk.constructs import ( ExistingSocaCluster, + SubnetFilterKeys, OAuthClientIdAndSecret, Role, InstanceProfile, @@ -138,8 +139,9 @@ def __init__(self, scope: constructs.Construct, cluster_name: str, aws_region: s self.user_pool = self.lookup_user_pool() + self.resource_server = None self.build_oauth2_client() - self.build_access_control_groups(user_pool=self.user_pool) + self.update_cluster_manager_client_scopes() self.build_sqs_queues() self.build_scheduled_event_notification_infra() @@ -187,12 +189,12 @@ def build_sqs_queues(self): resources=["*"]), ] ) - + self.sqs_kms_key = kms.Key(self.stack, "res-sqs-kms", enable_key_rotation=True, policy=sqs_kms_key_policy ) - + self.sqs_kms_key.add_alias(f'{self.cluster_name}/sqs') self.event_sqs_queue = SQSQueue( @@ -334,7 +336,7 @@ def build_scheduled_event_notification_infra(self): def build_oauth2_client(self): # add resource server - resource_server = 
self.user_pool.add_resource_server( + self.resource_server = self.user_pool.add_resource_server( id='resource-server', identifier=self.module_id, scopes=[ @@ -374,7 +376,7 @@ def build_oauth2_client(self): user_pool_client_name=self.module_id ) client.node.add_dependency(session_manager_resource_server) - client.node.add_dependency(resource_server) + client.node.add_dependency(self.resource_server) # read secret value by invoking custom resource oauth_credentials_lambda_arn = self.context.config().get_string('identity-provider.cognito.oauth_credentials_lambda_arn', required=True) @@ -399,6 +401,56 @@ def build_oauth2_client(self): client_secret=client_secret.get_att_string('ClientSecret') ) + def update_cluster_manager_client_scopes(self): + """ + Allow the cluster manager client to access VDC APIs via the VDC scopes + :return: + """ + lambda_name = f'{self.module_id}-update-cluster-manager-client-scope' + update_cluster_manager_client_scope_lambda_role = Role( + context=self.context, + name=f'{lambda_name}-role', + scope=self.stack, + assumed_by=['lambda'], + description=f'{lambda_name}-role' + ) + + update_cluster_manager_client_scope_lambda_role.attach_inline_policy(Policy( + context=self.context, + name=f'{lambda_name}-policy', + scope=self.stack, + policy_template_name='add-to-user-pool-client-scopes.yml' + )) + update_cluster_manager_client_scope_lambda = LambdaFunction( + context=self.context, + name=lambda_name, + scope=self.stack, + idea_code_asset=IdeaCodeAsset( + lambda_package_name='add_to_user_pool_client_scopes', + lambda_platform=SupportedLambdaPlatforms.PYTHON + ), + description='Update cluster manager client scope for accessing vdc', + timeout_seconds=180, + role=update_cluster_manager_client_scope_lambda_role, + ) + update_cluster_manager_client_scope_lambda.node.add_dependency(self.resource_server) + + cdk.CustomResource( + self.stack, + f'{self.cluster_name}-update-client-scope', + 
service_token=update_cluster_manager_client_scope_lambda.function_arn, + properties={ + 'cluster_name': self.cluster_name, + 'module_id': self.context.config().get_module_id(constants.MODULE_CLUSTER_MANAGER), + 'user_pool_id': self.user_pool.user_pool_id, + 'o_auth_scopes_to_add': [ + f'{self.context.config().get_module_id(constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER)}/read', + f'{self.context.config().get_module_id(constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER)}/write', + ] + }, + resource_type='Custom::UpdateClusterManagerClient' + ) + def build_dcv_host_infra(self): self.dcv_host_role = self._build_iam_role( role_description=f'IAM role assigned to virtual-desktop-{self.COMPONENT_DCV_HOST}', @@ -835,6 +887,7 @@ def _build_auto_scaling_group(self, component_name: str, security_group: Securit max_capacity=self.context.config().get_int(f'virtual-desktop-controller.{self.CONFIG_MAPPING[component_name]}.autoscaling.max_capacity', default=3), new_instances_protected_from_scale_in=self.context.config().get_bool(f'virtual-desktop-controller.{self.CONFIG_MAPPING[component_name]}.autoscaling.new_instances_protected_from_scale_in', default=True), cooldown=cdk.Duration.minutes(self.context.config().get_int(f'virtual-desktop-controller.{self.CONFIG_MAPPING[component_name]}.autoscaling.cooldown_minutes', default=5)), + default_instance_warmup=cdk.Duration.minutes(self.context.config().get_int(f'virtual-desktop-controller.{self.CONFIG_MAPPING[component_name]}.autoscaling.default_instance_warmup', default=15)), health_check=asg.HealthCheck.elb( grace=cdk.Duration.minutes(self.context.config().get_int(f'virtual-desktop-controller.{self.CONFIG_MAPPING[component_name]}.autoscaling.elb_healthcheck.grace_time_minutes', default=15)) ), @@ -940,7 +993,7 @@ def _build_dcv_connection_gateway_instance_infrastructure(self): def _build_dcv_connection_gateway_network_infrastructure(self): is_public = self.context.config().get_bool('cluster.load_balancers.external_alb.public', default=True) - 
external_nlb_subnets = self.cluster.public_subnets if is_public is True else self.cluster.private_subnets + external_nlb_subnets = self.cluster.existing_vpc.get_public_subnets(SubnetFilterKeys.LOAD_BALANCER_SUBNETS) if is_public is True else self.cluster.existing_vpc.get_private_subnets(SubnetFilterKeys.LOAD_BALANCER_SUBNETS) self.external_nlb = elbv2.NetworkLoadBalancer( self.stack, f'{self.cluster_name}-{self.module_id}-external-nlb', diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/config_generator.py b/source/idea/idea-administrator/src/ideaadministrator/app/config_generator.py index de61ad4..98a54b8 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/config_generator.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/config_generator.py @@ -262,18 +262,48 @@ def get_vpc_id(self) -> Optional[str]: return value def get_private_subnet_ids(self) -> Optional[str]: - private_subnet_ids = Utils.get_value_as_list('private_subnet_ids', self.user_values, []) - public_subnet_ids = Utils.get_value_as_list('public_subnet_ids', self.user_values, []) - if self.get_use_existing_vpc() and Utils.is_empty(private_subnet_ids) and Utils.is_empty(public_subnet_ids): + # by default all infrastructure host subnets and dcv session subnets are private + dcv_session_private_subnet_ids = self.user_values.get('dcv_session_private_subnet_ids', []) + infrastructure_host_subnet_ids = self.user_values.get('infrastructure_host_subnet_ids', []) + + private_subnet_ids = list(set([*infrastructure_host_subnet_ids, *dcv_session_private_subnet_ids])) + # if alb is not public, the external load balancer subnets are also considered private + if not self.get_alb_public(): + load_balancer_subnet_ids = self.user_values.get('load_balancer_subnet_ids', []) + private_subnet_ids = list(set([*private_subnet_ids, *load_balancer_subnet_ids])) + + if self.get_use_existing_vpc() and Utils.is_empty(private_subnet_ids): raise
exceptions.invalid_params('private_subnet_ids is required when use_existing_vpc = True') return private_subnet_ids def get_public_subnet_ids(self) -> Optional[str]: - private_subnet_ids = Utils.get_value_as_list('private_subnet_ids', self.user_values, []) - public_subnet_ids = Utils.get_value_as_list('public_subnet_ids', self.user_values, []) - if self.get_use_existing_vpc() and Utils.is_empty(private_subnet_ids) and Utils.is_empty(public_subnet_ids): + public_subnet_ids = [] + # if alb is public, only the load_balancer_subnet_ids are public subnets + if self.get_alb_public(): + load_balancer_subnet_ids = self.user_values.get('load_balancer_subnet_ids', []) + public_subnet_ids = load_balancer_subnet_ids + + if self.get_use_existing_vpc() and self.get_alb_public() and Utils.is_empty(public_subnet_ids): raise exceptions.invalid_params('load_balancer_subnet_ids is required when use_existing_vpc = True and alb_public = True') return public_subnet_ids + + def get_load_balancer_subnet_ids(self) -> Optional[str]: + load_balancer_subnet_ids = self.user_values.get('load_balancer_subnet_ids', []) + if self.get_use_existing_vpc() and Utils.is_empty(load_balancer_subnet_ids): + raise exceptions.invalid_params('load_balancer_subnet_ids is required when use_existing_vpc = True') + return load_balancer_subnet_ids + + def get_infrastructure_host_subnet_ids(self) -> Optional[str]: + infrastructure_host_subnet_ids = self.user_values.get('infrastructure_host_subnet_ids', []) + if self.get_use_existing_vpc() and Utils.is_empty(infrastructure_host_subnet_ids): + raise exceptions.invalid_params('infrastructure_host_subnet_ids is required when use_existing_vpc = True') + return infrastructure_host_subnet_ids + + def get_dcv_session_private_subnet_ids(self) -> Optional[str]: + dcv_session_private_subnet_ids = self.user_values.get('dcv_session_private_subnet_ids', []) + if self.get_use_existing_vpc() and Utils.is_empty(dcv_session_private_subnet_ids): + raise exceptions.invalid_params('dcv_session_private_subnet_ids
is required when use_existing_vpc = True') + return dcv_session_private_subnet_ids def get_use_existing_internal_fs(self) -> Optional[bool]: value = Utils.get_value_as_bool('use_existing_internal_fs', self.user_values, False) @@ -299,24 +329,6 @@ def get_existing_home_fs_id(self) -> Optional[str]: raise exceptions.invalid_params('existing_home_fs_id is required when use_existing_home_fs = True') return value - def get_use_existing_opensearch_cluster(self) -> Optional[str]: - value = Utils.get_value_as_bool('use_existing_opensearch_cluster', self.user_values, False) - if value and not self.get_use_existing_vpc(): - raise exceptions.invalid_params('use_existing_opensearch_cluster cannot be True if use_existing_vpc = False') - return value - - def get_opensearch_domain_arn(self) -> Optional[str]: - value = Utils.get_value_as_string('opensearch_domain_arn', self.user_values) - if Utils.is_empty(value) and self.get_use_existing_opensearch_cluster(): - raise exceptions.invalid_params('opensearch_domain_arn is required when use_existing_opensearch_cluster = True') - return value - - def get_opensearch_domain_endpoint(self) -> Optional[str]: - value = Utils.get_value_as_string('opensearch_domain_endpoint', self.user_values) - if Utils.is_empty(value) and self.get_use_existing_opensearch_cluster(): - raise exceptions.invalid_params('opensearch_domain_endpoint is required when use_existing_opensearch_cluster = True') - return value - def get_alb_custom_certificate_provided(self) -> Optional[bool]: return Utils.get_value_as_bool('alb_custom_certificate_provided', self.user_values, default=False) @@ -431,12 +443,13 @@ def generate_config_from_templates(self, temp=False, path=None): 'vpc_id': self.get_vpc_id(), 'private_subnet_ids': self.get_private_subnet_ids(), 'public_subnet_ids': self.get_public_subnet_ids(), + 'load_balancer_subnet_ids': self.get_load_balancer_subnet_ids(), + 'infrastructure_host_subnet_ids': self.get_infrastructure_host_subnet_ids(), + 
'dcv_session_private_subnet_ids': self.get_dcv_session_private_subnet_ids(), 'use_existing_internal_fs': self.get_use_existing_internal_fs(), 'existing_internal_fs_id': self.get_existing_internal_fs_id(), 'use_existing_home_fs': self.get_use_existing_home_fs(), 'existing_home_fs_id': self.get_existing_home_fs_id(), - 'use_existing_opensearch_cluster': self.get_use_existing_opensearch_cluster(), - 'opensearch_domain_endpoint': self.get_opensearch_domain_endpoint(), 'alb_public': self.get_alb_public(), 'alb_custom_certificate_provided': self.get_alb_custom_certificate_provided(), 'alb_custom_certificate_acm_certificate_arn': self.get_alb_custom_certificate_acm_certificate_arn(), diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/delete_cluster.py b/source/idea/idea-administrator/src/ideaadministrator/app/delete_cluster.py index 9bd9ed7..0b802d0 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/delete_cluster.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/delete_cluster.py @@ -14,7 +14,7 @@ from ideasdk.utils import Utils from ideadatamodel import constants, exceptions, errorcodes, EC2Instance, SocaMemory, SocaMemoryUnit -from typing import Optional, List +from typing import Optional, List, Mapping from prettytable import PrettyTable import time import botocore.exceptions @@ -100,6 +100,9 @@ def find_ec2_instances(self): if ec2_instance.state == 'terminated': continue + if ec2_instance.get_tag(constants.BI_TAG_DEPLOYMENT) == "true": + continue + # check termination protection instances describe_instance_attribute_result = self.context.aws().ec2().describe_instance_attribute( Attribute='disableApiTermination', @@ -280,6 +283,9 @@ def find_cloud_formation_stacks(self): if self.is_bootstrap_stack(stack_name): continue + if self.is_batteries_included_stack(stack): + continue + if self.is_cluster_stack(stack_name): cluster_stacks.append(stack) elif self.is_identity_provider_stack(stack_name): @@ -366,6 +372,14 @@ 
def delete_dynamo_table(self, table_name: str): def is_bootstrap_stack(self, stack_name: str) -> bool: return stack_name == self.get_bootstrap_stack_name() + def is_batteries_included_stack(self, stack: Mapping) -> bool: + stack_tags = Utils.get_value_as_list("Tags", stack) + for tag in stack_tags: + tag_dict = Utils.get_as_dict(tag, default={}) + if tag_dict["Key"] == constants.BI_TAG_DEPLOYMENT and tag_dict["Value"] == "true": + return True + return False + def is_cluster_stack(self, stack_name: str) -> bool: for module in self.cluster_modules: module_name = module['name'] @@ -377,17 +391,6 @@ def is_cluster_stack(self, stack_name: str) -> bool: return True return False - def is_analytics_stack(self, stack_name: str) -> bool: - for module in self.cluster_modules: - module_name = module['name'] - if module_name == constants.MODULE_ANALYTICS: - analytics_stack_name = module['stack_name'] - if analytics_stack_name == stack_name: - return True - if stack_name == f'{self.cluster_name}-analytics': - return True - return False - def is_identity_provider_stack(self, stack_name: str) -> bool: for module in self.cluster_modules: module_name = module['name'] @@ -450,8 +453,6 @@ def check_stack_deletion_status(self, stack_names: List[str]) -> bool: stacks_deleted.append(stack_name) elif stack_status == 'DELETE_FAILED': if self.delete_failed_attempt < self.delete_failed_max_attempts: - if self.is_analytics_stack(stack_name): - self.try_delete_vpc_lambda_enis() self.context.warning(f'stack: {stack_name}, status: {stack_status}, submitting a new delete_cloud_formation_stack request. 
[Loop {self.delete_failed_attempt}/{self.delete_failed_max_attempts}]') self.delete_cloud_formation_stack(stack_name) self.delete_failed_attempt += 1 @@ -461,8 +462,6 @@ def check_stack_deletion_status(self, stack_names: List[str]) -> bool: delete_failed += 1 else: print(f'stack: {stack_name}, status: {stack_status}') - if self.is_analytics_stack(stack_name): - self.try_delete_vpc_lambda_enis() except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] == 'ValidationError': diff --git a/source/idea/idea-administrator/src/ideaadministrator/app/installer_params.py b/source/idea/idea-administrator/src/ideaadministrator/app/installer_params.py index e529e2a..64c2245 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app/installer_params.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app/installer_params.py @@ -23,7 +23,6 @@ 'ExistingResourcesPrompt', 'SubnetIdsPrompt', 'FileSystemIdPrompt', - 'OpenSearchDomainEndpointPrompt', 'DirectoryServiceIdPrompt', 'PrefixListIdPrompt', 'EnabledModulesPrompt', @@ -379,7 +378,6 @@ def __init__(self, factory: 'InstallerPromptFactory'): param=factory.args.get_meta('existing_resources'), default=self.get_default) self.existing_file_systems = False - self.existing_opensearch = False self.existing_subnets = False self.existing_directories = False @@ -393,8 +391,6 @@ def get_default(self, reset: bool = False) -> Optional[List[str]]: if self.existing_file_systems: result.append('shared-storage:internal') result.append('shared-storage:home') - if self.existing_opensearch: - result.append('analytics:opensearch') if self.existing_directories: result.append('directoryservice:aws_managed_activedirectory') return result @@ -429,9 +425,6 @@ def get_choices(self, refresh: bool = False) -> List[SocaUserInputChoice]: with self.context.spinner('search for existing file systems ...'): file_systems = self.context.get_aws_resources().get_file_systems(vpc_id=vpc_id, refresh=refresh) 
self.existing_file_systems = len(file_systems) > 0 - with self.context.spinner('search for existing opensearch clusters ...'): - opensearch_clusters = self.context.get_aws_resources().get_opensearch_clusters(vpc_id=vpc_id, refresh=refresh) - self.existing_opensearch = len(opensearch_clusters) > 0 with self.context.spinner('search for existing directories ...'): directories = self.context.get_aws_resources().get_directories(vpc_id=vpc_id, refresh=refresh) directory_service_provider = self.args.get('directory_service_provider') @@ -457,11 +450,6 @@ def get_choices(self, refresh: bool = False) -> List[SocaUserInputChoice]: title='Shared Storage: Home', value='shared-storage:home' )) - if self.existing_opensearch: - choices.append(SocaUserInputChoice( - title='Analytics: OpenSearch Clusters', - value='analytics:opensearch' - )) if self.existing_directories: choices.append(SocaUserInputChoice( title='Directory: AWS Managed Microsoft AD', @@ -681,39 +669,6 @@ def get_choices(self, refresh: bool = False) -> List[SocaUserInputChoice]: return choices - -class OpenSearchDomainEndpointPrompt(DefaultPrompt[str]): - - def __init__(self, factory: 'InstallerPromptFactory'): - super().__init__(factory=factory, - param=factory.args.get_meta('opensearch_domain_endpoint')) - - def get_choices(self, refresh: bool = False) -> List[SocaUserInputChoice]: - - vpc_id = self.args.get('vpc_id') - if Utils.is_empty(vpc_id): - raise exceptions.general_exception('vpc_id is required to find existing resources') - - opensearch_clusters = self.context.get_aws_resources().get_opensearch_clusters(vpc_id=vpc_id, refresh=refresh) - - if len(opensearch_clusters) == 0: - raise exceptions.general_exception('Unable to find any existing opensearch clusters') - - choices = [] - - for opensearch_cluster in opensearch_clusters: - choices.append(SocaUserInputChoice( - title=opensearch_cluster.title, - value=opensearch_cluster.vpc_endpoint - )) - - return choices - - def filter(self, value) -> Optional[T]: - 
self.args.set('use_existing_opensearch_cluster', True) - return super().filter(value) - - class DirectoryServiceIdPrompt(DefaultPrompt[str]): def __init__(self, factory: 'InstallerPromptFactory'): @@ -795,12 +750,6 @@ def get_choices(self, refresh: bool = False) -> List[SocaUserInputChoice]: checked=True, disabled=True ), - SocaUserInputChoice( - title='Analytics (required)', - value=constants.MODULE_ANALYTICS, - checked=True, - disabled=True - ), SocaUserInputChoice( title='Identity Provider (required)', value=constants.MODULE_IDENTITY_PROVIDER, @@ -1141,7 +1090,6 @@ def initialize_overrides(self): existing_storage_flag_key='use_existing_home_fs', storage_provider_key='storage_home_provider' )) - self.register(OpenSearchDomainEndpointPrompt(factory=self)) self.register(DirectoryServiceIdPrompt(factory=self)) diff --git a/source/idea/idea-administrator/src/ideaadministrator/app_main.py b/source/idea/idea-administrator/src/ideaadministrator/app_main.py index e3d2d40..cf04ec9 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/app_main.py +++ b/source/idea/idea-administrator/src/ideaadministrator/app_main.py @@ -774,9 +774,6 @@ def delete_config(cluster_name: str, aws_profile: str, aws_region: str, config_k """ delete all configuration entries for a given config key prefix. - to delete all configuration entries for module id: analytics, run: - res-admin config delete analytics. - to delete all configuration entries for alb.listener_rules.*, run: res-admin config delete alb.listener_rules. 
""" @@ -1053,14 +1050,7 @@ def get_status() -> bool: module_id = cluster_module['module_id'] module_name = cluster_module['name'] module_type = cluster_module['type'] - if module_name == constants.MODULE_ANALYTICS: - url = f'{cluster_endpoint}/_dashboards/' - endpoints.append({ - 'name': 'OpenSearch Service Dashboard', - 'endpoint': url, - 'check_status': check_status(url) - }) - elif module_type == constants.MODULE_TYPE_APP: + if module_type == constants.MODULE_TYPE_APP: url = f'{cluster_endpoint}/{module_id}/healthcheck' endpoints.append({ 'name': module_metadata.get_module_title(module_name), @@ -1213,12 +1203,6 @@ def show_connection_info(cluster_name: str, aws_region: str, aws_profile: str, m 'value': cluster_endpoint, 'weight': 0 }) - elif module_name == constants.MODULE_ANALYTICS: - connection_info_entries.append({ - 'key': 'Analytics Dashboard', - 'value': f'{cluster_endpoint}/_dashboards', - 'weight': 3 - }) elif module_name == constants.MODULE_BASTION_HOST: key_pair_name = cluster_config.get_string('cluster.network.ssh_key_pair') ip_address = cluster_config.get_string(f'{module_id}.public_ip') diff --git a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/cluster_manager_tests.py b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/cluster_manager_tests.py index 5e4b9e3..e44f93e 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/cluster_manager_tests.py +++ b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/cluster_manager_tests.py @@ -302,6 +302,31 @@ def test_admin_disable_project(context: TestContext): assert e.error_code == 'SCHEDULER_HPC_PROJECT_NOT_FOUND' +def test_admin_delete_project(context: TestContext): + assert context.is_test_case_passed(test_constants.PROJECTS_CREATE_PROJECT) + + context.get_cluster_manager_client().invoke_alt( + namespace='Projects.DeleteProject', + payload=DeleteProjectRequest( + project_id=TEST_PROJECT_ID + ), + 
result_as=DeleteProjectResult, + access_token=context.get_admin_access_token() + ) + + try: + context.get_cluster_manager_client().invoke_alt( + namespace='Projects.GetProject', + payload=GetProjectRequest( + project_id=TEST_PROJECT_ID + ), + result_as=GetProjectResult, + access_token=context.get_admin_access_token() + ) + except exceptions.SocaException as e: + assert e.error_code == 'PROJECT_NOT_FOUND' + + def test_admin_disable_user(context: TestContext): context.get_cluster_manager_client().invoke_alt( namespace='Accounts.DisableUser', diff --git a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_constants.py b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_constants.py index 3345ada..60eefa5 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_constants.py +++ b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_constants.py @@ -23,6 +23,7 @@ PROJECTS_UPDATE_PROJECT = 'PROJECTS_UPDATE_PROJECT' PROJECTS_LIST_PROJECTS = 'PROJECTS_LIST_PROJECTS' PROJECTS_DISABLE_PROJECT = 'PROJECTS_DISABLE_PROJECT' +PROJECTS_DELETE_PROJECT = 'PROJECTS_DELETE_PROJECT' CLUSTER_SETTINGS_UPDATE_MODULE_SETTINGS = 'CLUSTER_SETTINGS_UPDATE_MODULE_SETTINGS' SCHEDULER_ADMIN_QUEUE_PROFILES = 'SCHEDULER_ADMIN_QUEUE_PROFILES' @@ -42,13 +43,12 @@ VIRTUAL_DESKTOP_TEST_ADMIN_UPDATE_SESSION_PERMISSIONS = 'VIRTUAL_DESKTOP_TEST_ADMIN_UPDATE_SESSION_PERMISSIONS' VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_SOFTWARE_STACK_FROM_SESSION = 'VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_SOFTWARE_STACK_FROM_SESSION' VIRTUAL_DESKTOP_TEST_ADMIN_GET_SESSION_CONNECTION_INFO = 'VIRTUAL_DESKTOP_TEST_ADMIN_GET_SESSION_CONNECTION_INFO' -VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_USER_SESSIONS = 'VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_USER_SESSIONS' -VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_SOFTWARE_STACKS = 'VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_SOFTWARE_STACKS' VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_PERMISSION_PROFILE = 
'VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_PERMISSION_PROFILE' VIRTUAL_DESKTOP_TEST_ADMIN_UPDATE_PERMISSION_PROFILE = 'VIRTUAL_DESKTOP_TEST_ADMIN_UPDATE_PERMISSION_PROFILE' VIRTUAL_DESKTOP_TEST_ADMIN_LIST_SESSION_PERMISSIONS = 'VIRTUAL_DESKTOP_TEST_ADMIN_LIST_SESSION_PERMISSIONS' VIRTUAL_DESKTOP_TEST_ADMIN_LIST_SHARED_PERMISSIONS = 'VIRTUAL_DESKTOP_TEST_ADMIN_LIST_SHARED_PERMISSIONS' VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SESSIONS = 'VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SESSIONS' +VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SOFTWARE_STACK = 'VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SOFTWARE_STACK' # VDC Utils VIRTUAL_DESKTOP_TEST_LIST_SUPPORTED_OS = 'VIRTUAL_DESKTOP_TEST_LIST_SUPPORTED_OS' diff --git a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_context.py b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_context.py index 6a190f4..b122c8c 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_context.py +++ b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/test_context.py @@ -135,7 +135,7 @@ def initialize_admin_auth(self): self.idea_context.info('Initializing Admin Authentication ...') result = self._admin_http_client.invoke_alt('Auth.InitiateAuth', InitiateAuthRequest( auth_flow='USER_PASSWORD_AUTH', - username=self.admin_username, + cognito_username=self.admin_username, password=self.admin_password ), result_as=InitiateAuthResult) admin_auth = result.auth @@ -147,7 +147,7 @@ def initialize_admin_auth(self): self.idea_context.info('Renewing Admin Authentication Access Token ...') result = self._admin_http_client.invoke_alt('Auth.InitiateAuth', InitiateAuthRequest( auth_flow='REFRESH_TOKEN_AUTH', - username=self.admin_username, + cognito_username=self.admin_username, refresh_token=self.admin_auth.refresh_token ), result_as=InitiateAuthResult) @@ -165,7 +165,7 @@ def initialize_non_admin_auth(self): self.idea_context.info('Initializing Non-Admin Authentication ...') result = 
self._admin_http_client.invoke_alt('Auth.InitiateAuth', InitiateAuthRequest( auth_flow='USER_PASSWORD_AUTH', - username=self.non_admin_username, + cognito_username=self.non_admin_username, password=self.non_admin_password ), result_as=InitiateAuthResult) non_admin_auth = result.auth @@ -177,7 +177,7 @@ def initialize_non_admin_auth(self): self.idea_context.info('Renewing Non-Admin Authentication Access Token ...') result = self._admin_http_client.invoke_alt('Auth.InitiateAuth', InitiateAuthRequest( auth_flow='REFRESH_TOKEN_AUTH', - username=self.non_admin_username, + cognito_username=self.non_admin_username, refresh_token=self.non_admin_auth.refresh_token ), result_as=InitiateAuthResult) diff --git a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_controller_tests.py b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_controller_tests.py index ae0aaf2..2fa8cd1 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_controller_tests.py +++ b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_controller_tests.py @@ -242,58 +242,6 @@ def test_admin_get_session_connection_info(context: TestContext): finally: vdc_test_helper.after_test(test_case_name, test_results_map, test_case_id) - -def test_admin_reindex_user_sessions(context: TestContext): - test_case_name = 'Test Admin Reindex User Sessions' - test_case_id = test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_USER_SESSIONS - admin_access_token = context.get_admin_access_token() - test_results_map = SessionsTestResultMap(test_case_name) - vdc_api_helper = VirtualDesktopApiHelper(context, admin_access_token, context.admin_username) - vdc_test_helper = VirtualDesktopTestHelper(context) - - try: - vdc_test_helper.before_test(test_case_name) - - response = vdc_api_helper.reindex_user_session() - - if response is not None: - vdc_test_helper.on_test_pass(test_case_name, 
test_results_map) - - else: - vdc_test_helper.on_test_fail(test_case_name, response, test_results_map) - - except exceptions.SocaException as error: - vdc_test_helper.on_test_exception(test_case_name, error, test_results_map) - - finally: - vdc_test_helper.after_test(test_case_name, test_results_map, test_case_id) - - -def test_admin_reindex_software_stacks(context: TestContext): - test_case_name = 'Test Admin Reindex Software Stacks' - test_case_id = test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_SOFTWARE_STACKS - admin_access_token = context.get_admin_access_token() - test_results_map = SessionsTestResultMap(test_case_name) - vdc_api_helper = VirtualDesktopApiHelper(context, admin_access_token, context.admin_username) - vdc_test_helper = VirtualDesktopTestHelper(context) - try: - vdc_test_helper.before_test(test_case_name) - - response = vdc_api_helper.reindex_software_stacks() - - if response is not None: - vdc_test_helper.on_test_pass(test_case_name, test_results_map) - - else: - vdc_test_helper.on_test_fail(test_case_name, response, test_results_map) - - except exceptions.SocaException as error: - vdc_test_helper.on_test_exception(test_case_name, error, test_results_map) - - finally: - vdc_test_helper.after_test(test_case_name, test_results_map, test_case_id) - - def test_admin_create_permission_profile(context: TestContext): test_case_name = 'Test Admin Create Permission Profile' test_case_id = test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_PERMISSION_PROFILE @@ -850,6 +798,33 @@ def test_describe_sessions(context: TestContext): vdc_test_helper.after_test(test_case_name, test_results_map, test_case_id) +def test_admin_delete_software_stack(context: TestContext): + test_case_name = 'Test Admin Delete Software Stack' + test_case_id = test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SOFTWARE_STACK + admin_access_token = context.get_admin_access_token() + test_results_map = SessionsTestResultMap(test_case_name) + vdc_test_helper = 
VirtualDesktopTestHelper(context) + try: + vdc_test_helper.before_test(test_case_name) + + if vdc_test_helper.is_new_session_created(): + session_helper = SessionsTestHelper(context, vdc_test_helper.get_new_session(), context.admin_username, admin_access_token) + response = session_helper.delete_software_stack(vdc_test_helper.get_new_software_stack()) + + vdc_test_helper.on_test_pass(test_case_name, test_results_map) + + else: + testcase_error_message = f'Created session is None. Skipping {test_case_name}. ' + test_results_map.update_test_result_map(VirtualDesktopSessionTestResults.FAILED, testcase_error_message) + context.error(testcase_error_message) + + except exceptions.SocaException as error: + vdc_test_helper.on_test_exception(test_case_name, error, test_results_map) + + finally: + vdc_test_helper.after_test(test_case_name, test_results_map, test_case_id) + + def test_admin_delete_sessions(context: TestContext): test_case_name = 'Test Admin Delete Sessions' test_case_id = test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SESSIONS @@ -1266,14 +1241,6 @@ def test_user_list_shared_permissions(context: TestContext): 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_GET_SESSION_CONNECTION_INFO, 'test_case': test_admin_get_session_connection_info }, - { - 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_USER_SESSIONS, - 'test_case': test_admin_reindex_user_sessions - }, - { - 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_REINDEX_SOFTWARE_STACKS, - 'test_case': test_admin_reindex_software_stacks - }, { 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_CREATE_PERMISSION_PROFILE, 'test_case': test_admin_create_permission_profile @@ -1394,5 +1361,9 @@ def test_user_list_shared_permissions(context: TestContext): { 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SESSIONS, 'test_case': test_admin_delete_sessions + }, + { + 'test_case_id': test_constants.VIRTUAL_DESKTOP_TEST_ADMIN_DELETE_SOFTWARE_STACK, + 
'test_case': test_admin_delete_software_stack } ] diff --git a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_tests_util.py b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_tests_util.py index c9a5e7e..54d9685 100644 --- a/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_tests_util.py +++ b/source/idea/idea-administrator/src/ideaadministrator/integration_tests/virtual_desktop_tests_util.py @@ -33,6 +33,8 @@ UpdateSessionPermissionResponse, CreateSoftwareStackFromSessionRequest, CreateSoftwareStackFromSessionResponse, + DeleteSoftwareStackRequest, + DeleteSoftwareStackResponse, UpdateSessionRequest, UpdateSessionResponse, StopSessionRequest, @@ -62,10 +64,6 @@ ListSoftwareStackResponse, ListSessionsRequest, ListSessionsResponse, - ReIndexUserSessionsRequest, - ReIndexUserSessionsResponse, - ReIndexSoftwareStacksRequest, - ReIndexSoftwareStacksResponse, ListSupportedOSRequest, ListSupportedOSResponse, ListSupportedGPURequest, @@ -331,6 +329,17 @@ def create_software_stack(self) -> CreateSoftwareStackResponse: except (exceptions.SocaException, Exception) as e: self.context.error(f'Failed to Create Software Stack. Error : {e}') + def delete_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> DeleteSoftwareStackResponse: + try: + response = self.context.get_virtual_desktop_controller_client(timeout=7200).invoke_alt( + namespace='VirtualDesktopAdmin.DeleteSoftwareStack', + payload=DeleteSoftwareStackRequest(software_stack=software_stack), + result_as=DeleteSoftwareStackResponse, + access_token=self.access_token) + return response + except (exceptions.SocaException, Exception) as e: + self.context.error(f'Failed to Delete Software Stack. 
Error : {e}') + def update_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> UpdateSoftwareStackResponse: try: @@ -620,28 +629,6 @@ def list_sessions(self, namespace: str) -> ListSessionsResponse: except (exceptions.SocaException, Exception) as e: self.context.error(f'Failed to List Sessions. Error : {e}') - def reindex_user_session(self) -> ReIndexUserSessionsResponse: - try: - response = self.context.get_virtual_desktop_controller_client(timeout=7200).invoke_alt( - namespace='VirtualDesktopAdmin.ReIndexUserSessions', - payload=ReIndexUserSessionsRequest(), - result_as=ReIndexUserSessionsResponse, - access_token=self.access_token) - return response - except (exceptions.SocaException, Exception) as e: - self.context.error(f'Failed to Reindex User Session. Error : {e}') - - def reindex_software_stacks(self) -> ReIndexSoftwareStacksResponse: - try: - response = self.context.get_virtual_desktop_controller_client(timeout=7200).invoke_alt( - namespace='VirtualDesktopAdmin.ReIndexSoftwareStacks', - payload=ReIndexSoftwareStacksRequest(), - result_as=ReIndexSoftwareStacksResponse, - access_token=self.access_token) - return response - except (exceptions.SocaException, Exception) as e: - self.context.error(f'Failed to Reindex Software Stacks. Error : {e}') - # VDC Utils def list_supported_os(self) -> ListSupportedOSResponse: try: diff --git a/source/idea/idea-administrator/src/ideaadministrator_meta/__init__.py b/source/idea/idea-administrator/src/ideaadministrator_meta/__init__.py index c34345e..b06cfb2 100644 --- a/source/idea/idea-administrator/src/ideaadministrator_meta/__init__.py +++ b/source/idea/idea-administrator/src/ideaadministrator_meta/__init__.py @@ -12,4 +12,4 @@ # pkg config for soca-admin. no dependencies. 
__name__ = 'idea-administrator' -__version__ = '2023.11' +__version__ = '2024.01' diff --git a/source/idea/idea-bootstrap/_templates/linux/join_activedirectory.jinja2 b/source/idea/idea-bootstrap/_templates/linux/join_activedirectory.jinja2 index 7b110d9..33844fb 100644 --- a/source/idea/idea-bootstrap/_templates/linux/join_activedirectory.jinja2 +++ b/source/idea/idea-bootstrap/_templates/linux/join_activedirectory.jinja2 @@ -10,7 +10,7 @@ AD_DOMAIN_NAME="{{ context.config.get_string('directoryservice.name', required=T AD_REALM_NAME="{{ context.config.get_string('directoryservice.name', required=True).upper() }}" AD_SUDOERS_GROUP_NAME="{{ context.config.get_string('directoryservice.sudoers.group_name', required=True) }}" AD_SUDOERS_GROUP_NAME_ESCAPED="{{ context.config.get_string('directoryservice.sudoers.group_name', required=True).replace(' ', '\ ') }}" -SSSD_LDAP_ID_MAPPING="{{ context.config.get_bool('directoryservice.sssd.ldap_id_mapping', default=False) | lower }}" +SSSD_LDAP_ID_MAPPING="{{ context.config.get_bool('directoryservice.sssd.ldap_id_mapping', default=True) | lower }}" AD_TLS_CERTIFICATE_SECRET_ARN="{{context.config.get_string('directoryservice.tls_certificate_secret_arn', default='')}}" AD_LDAP_BASE="{{context.config.get_string('directoryservice.ldap_base', required=True)}}" @@ -154,6 +154,12 @@ if [[ -f /etc/sssd/sssd.conf ]]; then cp /etc/sssd/sssd.conf /etc/sssd/sssd.conf.orig fi +if [[ ${IDEA_MODULE_NAME} == "cluster-manager" ]]; then + enumerate_value=True +else + enumerate_value=False +fi + if [[ "${AD_TLS_CERTIFICATE_SECRET_ARN}" == '' ]]; then echo -e "[sssd] domains = ${AD_DOMAIN_NAME} @@ -180,7 +186,7 @@ use_fully_qualified_names = false fallback_homedir = ${IDEA_CLUSTER_HOME_DIR}/%u # disable or set to false for very large environments -enumerate = true +enumerate = "${enumerate_value}" sudo_provider = none @@ -232,7 +238,7 @@ use_fully_qualified_names = false fallback_homedir = ${IDEA_CLUSTER_HOME_DIR}/%u # disable or set to false for 
very large environments -enumerate = true +enumerate = "${enumerate_value}" sudo_provider = none diff --git a/source/idea/idea-bootstrap/_templates/linux/join_openldap.jinja2 b/source/idea/idea-bootstrap/_templates/linux/join_openldap.jinja2 index 8841261..d670862 100644 --- a/source/idea/idea-bootstrap/_templates/linux/join_openldap.jinja2 +++ b/source/idea/idea-bootstrap/_templates/linux/join_openldap.jinja2 @@ -23,8 +23,14 @@ if [ -e /etc/sssd/sssd.conf ]; then cp /etc/sssd/sssd.conf /etc/sssd/sssd.conf.orig fi +if [[ ${IDEA_MODULE_NAME} == "cluster-manager" ]]; then + enumerate_value=True +else + enumerate_value=False +fi + echo -e "[domain/default] -enumerate = True +enumerate = "${enumerate_value}" autofs_provider = ldap cache_credentials = True ldap_search_base = ${IDEA_DS_LDAP_BASE} diff --git a/source/idea/idea-bootstrap/_templates/linux/tag_ebs_volumes.jinja2 b/source/idea/idea-bootstrap/_templates/linux/tag_ebs_volumes.jinja2 index 4523954..68a10e2 100644 --- a/source/idea/idea-bootstrap/_templates/linux/tag_ebs_volumes.jinja2 +++ b/source/idea/idea-bootstrap/_templates/linux/tag_ebs_volumes.jinja2 @@ -8,27 +8,31 @@ function tag_ebs_volumes () { --region "{{ context.aws_region }}" \ --query "Volumes[*].[VolumeId]" \ --out text) - local EBS_IDS=$(echo "${VOLUMES}" | tr "\n" " ") - $AWS ec2 create-tags \ - --resources "${EBS_IDS}" \ - --region "{{ context.aws_region }}" \ - --tags "${TAGS}" + if [ ! -z "$VOLUMES" ]; then + echo $VOLUMES | while read EBS_ID; do + $AWS ec2 create-tags \ + --resources "${EBS_ID}" \ + --region "{{ context.aws_region }}" \ + --tags "${TAGS}" + done - local MAX_RETRIES=5 - local RETRY_COUNT=0 - while [[ $? -ne 0 ]] && [[ ${RETRY_COUNT} -lt ${MAX_RETRIES} ]] - do - local SLEEP_TIME=$(( RANDOM % 33 + 8 )) # Minimum of 8 seconds sleeping - log_info "(${RETRY_COUNT}/${MAX_RETRIES}) ec2 tag failed due to EC2 API error, retrying in ${SLEEP_TIME} seconds ..." 
- sleep ${SLEEP_TIME} - ((RETRY_COUNT++)) - $AWS ec2 create-tags \ - --resources "${EBS_IDS}" \ - --region "{{ context.aws_region }}" \ - --tags "${TAGS}" - done + local MAX_RETRIES=5 + local RETRY_COUNT=0 + while [[ $? -ne 0 ]] && [[ ${RETRY_COUNT} -lt ${MAX_RETRIES} ]] + do + local SLEEP_TIME=$(( RANDOM % 33 + 8 )) # Minimum of 8 seconds sleeping + log_info "(${RETRY_COUNT}/${MAX_RETRIES}) ec2 tag failed due to EC2 API error, retrying in ${SLEEP_TIME} seconds ..." + sleep ${SLEEP_TIME} + ((RETRY_COUNT++)) + echo $VOLUMES | while read EBS_ID; do + $AWS ec2 create-tags \ + --resources "${EBS_ID}" \ + --region "{{ context.aws_region }}" \ + --tags "${TAGS}" + done + done + fi } tag_ebs_volumes # End: Tag EBS Volumes - diff --git a/source/idea/idea-bootstrap/_templates/linux/tag_network_interface.jinja2 b/source/idea/idea-bootstrap/_templates/linux/tag_network_interface.jinja2 index 25378b8..2e1959f 100644 --- a/source/idea/idea-bootstrap/_templates/linux/tag_network_interface.jinja2 +++ b/source/idea/idea-bootstrap/_templates/linux/tag_network_interface.jinja2 @@ -8,23 +8,30 @@ function tag_network_interface () { --region "{{ context.aws_region }}" \ --query "NetworkInterfaces[*].[NetworkInterfaceId]" \ --out text) - local ENI_IDS=$(echo "${INTERFACES}" | tr "\n" " ") - $AWS ec2 create-tags --resources "${ENI_IDS}" \ - --region "{{ context.aws_region }}" \ - --tags "${TAGS}" - - local MAX_RETRIES=5 - local RETRY_COUNT=0 - while [[ $? -ne 0 ]] && [[ ${RETRY_COUNT} -lt ${MAX_RETRIES} ]] - do - local SLEEP_TIME=$(( RANDOM % 33 + 8 )) # Sleep for 8-40 seconds - log_info "(${RETRY_COUNT}/${MAX_RETRIES}) ec2 tag failed due to EC2 API error, retrying in ${SLEEP_TIME} seconds ..." - sleep ${SLEEP_TIME} - ((RETRY_COUNT++)) - $AWS ec2 create-tags --resources "${ENI_IDS}" \ + if [ ! 
-z "$INTERFACES" ]; then + echo $INTERFACES | while read ENI_ID; do + $AWS ec2 create-tags \ + --resources "${ENI_ID}" \ --region "{{ context.aws_region }}" \ --tags "${TAGS}" - done + done + + local MAX_RETRIES=5 + local RETRY_COUNT=0 + while [[ $? -ne 0 ]] && [[ ${RETRY_COUNT} -lt ${MAX_RETRIES} ]] + do + local SLEEP_TIME=$(( RANDOM % 33 + 8 )) # Sleep for 8-40 seconds + log_info "(${RETRY_COUNT}/${MAX_RETRIES}) ec2 tag failed due to EC2 API error, retrying in ${SLEEP_TIME} seconds ..." + sleep ${SLEEP_TIME} + ((RETRY_COUNT++)) + echo $INTERFACES | while read ENI_ID; do + $AWS ec2 create-tags \ + --resources "${ENI_ID}" \ + --region "{{ context.aws_region }}" \ + --tags "${TAGS}" + done + done + fi } tag_network_interface # End: Tag Network Interface diff --git a/source/idea/idea-bootstrap/cluster-manager/install_app.sh.jinja2 b/source/idea/idea-bootstrap/cluster-manager/install_app.sh.jinja2 index d9cf3ba..89b7e02 100644 --- a/source/idea/idea-bootstrap/cluster-manager/install_app.sh.jinja2 +++ b/source/idea/idea-bootstrap/cluster-manager/install_app.sh.jinja2 @@ -56,16 +56,19 @@ cp -r ${PACKAGE_DIR}/resources ${IDEA_APP_DEPLOY_DIR}/${APP_NAME} {% include '_templates/linux/supervisord.jinja2' %} +ENVIRONMENT=" + res_test_mode=\"%(ENV_RES_TEST_MODE)s\", + RES_TEST_MODE=\"%(ENV_RES_TEST_MODE)s\"" if [[ ! 
-z "${IDEA_HTTPS_PROXY}" ]]; then - echo "[program:${APP_NAME}] -environment= + ENVIRONMENT+=", https_proxy=\"%(ENV_IDEA_HTTPS_PROXY)s\", HTTPS_PROXY=\"%(ENV_IDEA_HTTPS_PROXY)s\", no_proxy=\"%(ENV_IDEA_NO_PROXY)s\", - NO_PROXY=\"%(ENV_IDEA_NO_PROXY)s\"" > /etc/supervisord.d/${APP_NAME}.ini -else - echo "[program:${APP_NAME}]" > /etc/supervisord.d/${APP_NAME}.ini + NO_PROXY=\"%(ENV_IDEA_NO_PROXY)s\"" fi + +echo "[program:${APP_NAME}] +environment=${ENVIRONMENT}" > /etc/supervisord.d/${APP_NAME}.ini echo "command=/opt/idea/python/latest/bin/resserver process_name=${APP_NAME} redirect_stderr=true diff --git a/source/idea/idea-bootstrap/cluster-manager/setup.sh.jinja2 b/source/idea/idea-bootstrap/cluster-manager/setup.sh.jinja2 index 4b6de63..4d3eb2b 100644 --- a/source/idea/idea-bootstrap/cluster-manager/setup.sh.jinja2 +++ b/source/idea/idea-bootstrap/cluster-manager/setup.sh.jinja2 @@ -27,7 +27,9 @@ IDEA_CLUSTER_S3_BUCKET={{ context.cluster_s3_bucket }} IDEA_CLUSTER_NAME={{ context.cluster_name }} IDEA_CLUSTER_HOME={{ context.cluster_home_dir }} IDEA_APP_DEPLOY_DIR={{ context.app_deploy_dir }} -BOOTSTRAP_DIR=/root/bootstrap" > /etc/environment +BOOTSTRAP_DIR=/root/bootstrap +## Disable the RES test mode by default - Do not enable it in production as it will bypass API authorization +RES_TEST_MODE=False" > /etc/environment {% if context.https_proxy != '' %} echo -e "IDEA_HTTPS_PROXY={{ context.https_proxy }} diff --git a/source/idea/idea-bootstrap/dcv-connection-gateway/install_app.sh.jinja2 b/source/idea/idea-bootstrap/dcv-connection-gateway/install_app.sh.jinja2 index 885a2b1..43626c7 100644 --- a/source/idea/idea-bootstrap/dcv-connection-gateway/install_app.sh.jinja2 +++ b/source/idea/idea-bootstrap/dcv-connection-gateway/install_app.sh.jinja2 @@ -48,7 +48,7 @@ tar -xvf ${BOOTSTRAP_DIR}/${PACKAGE_ARCHIVE} -C ${APP_DIR} DCV_GPG_KEY="{{ context.config.get_string('global-settings.package_config.dcv.gpg_key', required=True) }}" DCV_CONNECTION_GATEWAY_VERSION="{{ 
context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.rhel_centos_rocky9.version', required=True) }}" DCV_CONNECTION_GATEWAY_URL="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.rhel_centos_rocky9.url', required=True) }}" - DCV_CONNECTION_GATEWAY_SHA256_HASH="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.rhel_centos_rocky9.sha256sum', required=True) }}" + DCV_CONNECTION_GATEWAY_SHA256_URL="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.rhel_centos_rocky9.sha256sum', required=True) }}" INTERNAL_ALB_ENDPOINT="{{ context.config.get_cluster_internal_endpoint() }}" GATEWAY_TO_BROKER_PORT="{{ context.config.get_string("virtual-desktop-controller.dcv_broker.gateway_communication_port", required=True) }}" @@ -71,7 +71,7 @@ tar -xvf ${BOOTSTRAP_DIR}/${PACKAGE_ARCHIVE} -C ${APP_DIR} DCV_GPG_KEY="{{ context.config.get_string('global-settings.package_config.dcv.gpg_key', required=True) }}" DCV_CONNECTION_GATEWAY_VERSION="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.al2_rhel_centos7.version', required=True) }}" DCV_CONNECTION_GATEWAY_URL="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.al2_rhel_centos7.url', required=True) }}" - DCV_CONNECTION_GATEWAY_SHA256_HASH="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.al2_rhel_centos7.sha256sum', required=True) }}" + DCV_CONNECTION_GATEWAY_SHA256_URL="{{ context.config.get_string('global-settings.package_config.dcv.connection_gateway.x86_64.linux.al2_rhel_centos7.sha256sum', required=True) }}" INTERNAL_ALB_ENDPOINT="{{ context.config.get_cluster_internal_endpoint() }}" GATEWAY_TO_BROKER_PORT="{{ context.config.get_string("virtual-desktop-controller.dcv_broker.gateway_communication_port", 
required=True) }}" @@ -92,30 +92,18 @@ DCV_WEB_VIEWER_INSTALL_LOCATION="/usr/share/dcv/www" timestamp=$(date +%s) -function setup_nginx() { - yum install nginx1 -y - yum install nginx -y -echo """ -server { - listen 80; - listen [::]:80; - root ${DCV_WEB_VIEWER_INSTALL_LOCATION}; -} -""" > /etc/nginx/conf.d/default.conf - systemctl enable nginx - systemctl start nginx -} - function install_dcv_connection_gateway() { yum install -y nc rpm --import ${DCV_GPG_KEY} wget ${DCV_CONNECTION_GATEWAY_URL} - if [[ $(sha256sum nice-dcv-connection-gateway-${DCV_CONNECTION_GATEWAY_VERSION}.rpm | awk '{print $1}') != ${DCV_CONNECTION_GATEWAY_SHA256_HASH} ]]; then + fileName=$(basename ${DCV_CONNECTION_GATEWAY_URL}) + urlSha256Sum=$(wget -O - ${DCV_CONNECTION_GATEWAY_SHA256_URL}) + if [[ $(sha256sum ${fileName} | awk '{print $1}') != ${urlSha256Sum} ]]; then echo -e "FATAL ERROR: Checksum for DCV Connection Gateway failed. File may be compromised." > /etc/motd exit 1 fi - yum install -y nice-dcv-connection-gateway-${DCV_CONNECTION_GATEWAY_VERSION}.rpm - rm -rf nice-dcv-connection-gateway-${DCV_CONNECTION_GATEWAY_VERSION}.rpm + yum install -y ${fileName} + rm -f ${fileName} } function install_dcv_web_viewer() { @@ -213,7 +201,7 @@ url = \"${INTERNAL_ALB_ENDPOINT}:${GATEWAY_TO_BROKER_PORT}\" tls-strict = false [web-resources] -url = \"http://localhost:80\" +local-resources-path = \"/usr/share/dcv/www\" tls-strict = false " > /etc/dcv-connection-gateway/dcv-connection-gateway.conf @@ -221,8 +209,7 @@ tls-strict = false systemctl start dcv-connection-gateway } -setup_nginx install_dcv_connection_gateway install_dcv_web_viewer configure_certificates -configure_dcv_connection_gateway +configure_dcv_connection_gateway \ No newline at end of file diff --git a/source/idea/idea-bootstrap/openldap-server/_templates/install_openldap.jinja2 b/source/idea/idea-bootstrap/openldap-server/_templates/install_openldap.jinja2 index 103dcbd..55721ca 100644 --- 
a/source/idea/idea-bootstrap/openldap-server/_templates/install_openldap.jinja2 +++ b/source/idea/idea-bootstrap/openldap-server/_templates/install_openldap.jinja2 @@ -188,9 +188,15 @@ if [[ -f /etc/sssd/sssd.conf ]]; then cp /etc/sssd/sssd.conf /etc/sssd/sssd.conf.orig fi +if [[ ${IDEA_MODULE_NAME} == "cluster-manager" ]]; then + enumerate_value=True +else + enumerate_value=False +fi + echo -e " [domain/default] -enumerate = True +enumerate = "${enumerate_value}" autofs_provider = ldap cache_credentials = True ldap_search_base = ${IDEA_DS_LDAP_BASE} diff --git a/source/idea/idea-bootstrap/virtual-desktop-controller/install_app.sh.jinja2 b/source/idea/idea-bootstrap/virtual-desktop-controller/install_app.sh.jinja2 index 4da76bb..8334193 100644 --- a/source/idea/idea-bootstrap/virtual-desktop-controller/install_app.sh.jinja2 +++ b/source/idea/idea-bootstrap/virtual-desktop-controller/install_app.sh.jinja2 @@ -49,16 +49,19 @@ cp -r ${PACKAGE_DIR}/resources ${IDEA_APP_DEPLOY_DIR}/${APP_NAME} {% include '_templates/linux/supervisord.jinja2' %} +ENVIRONMENT=" + res_test_mode=\"%(ENV_RES_TEST_MODE)s\", + RES_TEST_MODE=\"%(ENV_RES_TEST_MODE)s\"" if [[ ! 
-z "${IDEA_HTTPS_PROXY}" ]]; then - echo "[program:${APP_NAME}] -environment= + ENVIRONMENT+=", https_proxy=\"%(ENV_IDEA_HTTPS_PROXY)s\", HTTPS_PROXY=\"%(ENV_IDEA_HTTPS_PROXY)s\", no_proxy=\"%(ENV_IDEA_NO_PROXY)s\", - NO_PROXY=\"%(ENV_IDEA_NO_PROXY)s\"" > /etc/supervisord.d/${APP_NAME}.ini -else - echo "[program:${APP_NAME}]" > /etc/supervisord.d/${APP_NAME}.ini + NO_PROXY=\"%(ENV_IDEA_NO_PROXY)s\"" fi + +echo "[program:${APP_NAME}] +environment=${ENVIRONMENT}" > /etc/supervisord.d/${APP_NAME}.ini echo "command=/opt/idea/python/latest/bin/resserver process_name=${APP_NAME} redirect_stderr=true diff --git a/source/idea/idea-bootstrap/virtual-desktop-controller/setup.sh.jinja2 b/source/idea/idea-bootstrap/virtual-desktop-controller/setup.sh.jinja2 index 9ac6b01..45d7a2c 100644 --- a/source/idea/idea-bootstrap/virtual-desktop-controller/setup.sh.jinja2 +++ b/source/idea/idea-bootstrap/virtual-desktop-controller/setup.sh.jinja2 @@ -28,7 +28,9 @@ IDEA_CLUSTER_S3_BUCKET={{ context.cluster_s3_bucket }} IDEA_CLUSTER_NAME={{ context.cluster_name }} IDEA_CLUSTER_HOME={{ context.cluster_home_dir }} IDEA_APP_DEPLOY_DIR={{ context.app_deploy_dir }} -BOOTSTRAP_DIR=/root/bootstrap" > /etc/environment +BOOTSTRAP_DIR=/root/bootstrap +## Disable the RES test mode by default - Do not enable it in production as it will bypass API authorization +RES_TEST_MODE=False" > /etc/environment {% if context.https_proxy != '' %} echo -e "IDEA_HTTPS_PROXY={{ context.https_proxy }} diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/README.md b/source/idea/idea-cluster-manager/src/ideaclustermanager/README.md index 6eb6712..154578d 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/README.md +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/README.md @@ -59,10 +59,6 @@ Contains the various API files for each api. The files contain the acl to define Contains the API for creating and manipulating accounts and groups. 
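The supervisord refactor above (applied to both the cluster-manager and virtual-desktop-controller install scripts) builds a single `ENVIRONMENT` string and writes the `[program:...]` section once, instead of duplicating the `echo` in both proxy branches. A minimal standalone sketch of that pattern — `APP_NAME` and the output path are illustrative (it writes to the current directory rather than `/etc/supervisord.d` so it can run unprivileged), and the `%(ENV_...)s` placeholders are supervisord's environment-expansion syntax carried over from the diff:

```shell
#!/bin/bash
APP_NAME="cluster-manager"
IDEA_HTTPS_PROXY=""   # set to e.g. http://proxy.example:3128 to exercise the proxy branch

# Test-mode flags are always present in the program's environment.
ENVIRONMENT="
    res_test_mode=\"%(ENV_RES_TEST_MODE)s\",
    RES_TEST_MODE=\"%(ENV_RES_TEST_MODE)s\""

# Proxy variables are appended only when a proxy is configured, so the
# section header below is written exactly once instead of in two branches.
if [[ -n "${IDEA_HTTPS_PROXY}" ]]; then
  ENVIRONMENT+=",
    https_proxy=\"%(ENV_IDEA_HTTPS_PROXY)s\",
    HTTPS_PROXY=\"%(ENV_IDEA_HTTPS_PROXY)s\",
    no_proxy=\"%(ENV_IDEA_NO_PROXY)s\",
    NO_PROXY=\"%(ENV_IDEA_NO_PROXY)s\""
fi

echo "[program:${APP_NAME}]
environment=${ENVIRONMENT}" > "./${APP_NAME}.ini"
```

Appending to one string keeps the two configurations from drifting apart, which was the failure mode of the replaced two-branch version.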
The api functions all call the functions from the accounts service. -#### analytics_api.py - -Contains analytics API. The only action in this API is to query OpenSearch. The analytics service it uses is set in the SocaContext which is defined in the idea sdk and is the parent class of the app_context. - #### api_invoker.py Contains the entrypoint for the API. When a request is received on the api server, it calls the api invoker, which then invokes the respective API. diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/accounts_service.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/accounts_service.py index ddd85ab..9852641 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/accounts_service.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/accounts_service.py @@ -10,6 +10,7 @@ # and limitations under the License. from ideasdk.client.evdi_client import EvdiClient from ideasdk.context import SocaContext +from ideadatamodel import AuthResult from ideadatamodel.auth import ( User, Group, @@ -39,6 +40,7 @@ from ideaclustermanager.app.accounts.db.single_sign_on_state_dao import SingleSignOnStateDAO from ideaclustermanager.app.accounts.helpers.single_sign_on_helper import SingleSignOnHelper from ideaclustermanager.app.tasks.task_manager import TaskManager +from ideaclustermanager.app.accounts.helpers.sssd_helper import SSSD from typing import Optional, List import os @@ -79,6 +81,7 @@ def __init__(self, context: SocaContext, self.task_manager = task_manager self.token_service = token_service + self.sssd = SSSD(context) self.group_name_helper = GroupNameHelper(context) self.user_dao = UserDAO(context, user_pool=user_pool) self.group_dao = GroupDAO(context) @@ -93,6 +96,7 @@ def __init__(self, context: SocaContext, self.ds_automation_dir = self.context.config().get_string('directoryservice.automation_dir', required=True) + def is_cluster_administrator(self, 
username: str) -> bool: cluster_administrator = self.context.config().get_string('cluster.administrator_username', required=True) return username == cluster_administrator @@ -140,7 +144,6 @@ def create_group(self, group: Group) -> Group: db_existing_group = self.group_dao.get_group(group_name) if db_existing_group is not None: raise exceptions.invalid_params(f'group: {group_name} already exists') - # Perform AD validation only for groups identified as external groups. Internal grpups are RES specific groups and are not expected to be present in the AD. if group.type == constants.GROUP_TYPE_EXTERNAL and ds_readonly: if not group.ds_name: @@ -148,15 +151,14 @@ def create_group(self, group: Group) -> Group: group_in_ad = self.ldap_client.get_group(group.ds_name) if group_in_ad is None: raise exceptions.invalid_params(f'group with name {group.ds_name} is not found in Read-Only Directory Service: {constants.DIRECTORYSERVICE_ACTIVE_DIRECTORY}') - group_id_in_ad = group_in_ad['gid'] - if group_id_in_ad is None: - raise exceptions.invalid_params(f'Group id is not found in Directory Service: {constants.DIRECTORYSERVICE_ACTIVE_DIRECTORY}') - group.gid = group_id_in_ad group_name_in_ad = group_in_ad['name'] if group_name_in_ad is None: raise exceptions.invalid_params(f'Group name matching the provided name {group.ds_name} is not found in Directory Service: {constants.DIRECTORYSERVICE_ACTIVE_DIRECTORY}') if group_name_in_ad != group.ds_name: raise exceptions.invalid_params(f'group.ds_name {group.ds_name} does not match the value {group_name_in_ad} read from Directory Service: {constants.DIRECTORYSERVICE_ACTIVE_DIRECTORY}') + group.gid = self.sssd.get_gid_for_group(group_name_in_ad) + if group.gid is None: + raise exceptions.soca_exception(error_code=errorcodes.GID_NOT_FOUND, message=f'Unable to retrieve GID for Group: {group_name_in_ad}') group.enabled = True @@ -340,13 +342,9 @@ def add_users_to_group(self, usernames: List[str], group_name: str, bypass_activ 
'additional_groups': list(set(user.get('additional_groups', []) + [group_name])) }) - if bypass_active_user_check or user['is_active']: + if bypass_active_user_check or user['is_active'] or Utils.is_test_mode(): self.group_members_dao.create_membership(group_name, username) - # For group_types MODULE or CLUSTER, the user must be added to those groups in cognito too. This gives users admin/user access to various sections on the UI. - if group['group_type'] not in (constants.GROUP_TYPE_USER, constants.GROUP_TYPE_PROJECT): - self.user_pool.admin_add_user_to_group(username=username, group_name=group_name) - self.logger.info(f'add user projects for user: {username} in group: {group_name} - DAO ...') self.context.projects.user_projects_dao.group_member_added(group_name=group_name, username=username) @@ -398,10 +396,6 @@ def remove_user_from_groups(self, username: str, group_names: List[str]): self.group_members_dao.delete_membership(group_name, username) - # For group_types MODULE or CLUSTER, the user must be removed from those groups in cognito too. This removes user's admin/user access to various sections on the UI. - if group["group_type"] not in (constants.GROUP_TYPE_USER, constants.GROUP_TYPE_PROJECT): - self.user_pool.admin_remove_user_from_group(username=username, group_name=group_name) - self.user_dao.update_user({'username': username, 'additional_groups': additional_groups}) # removing user projects after updating the user object, to ensure updated user membership is accessible during user-project deletion. @@ -455,10 +449,6 @@ def remove_users_from_group(self, usernames: List[str], group_name: str, force: self.group_members_dao.delete_membership(group_name, username) - # For group_types MODULE or CLUSTER, the user must be removed from those groups in cognito too. This removes user's admin/user access to various sections on the UI. 
- if group['group_type'] not in (constants.GROUP_TYPE_USER, constants.GROUP_TYPE_PROJECT): - self.user_pool.admin_remove_user_from_group(username=username, group_name=group_name) - self.context.projects.user_projects_dao.group_member_removed( group_name=group_name, username=username @@ -502,8 +492,6 @@ def add_admin_user(self, username: str): 'sudo': True }) - self.user_pool.admin_add_user_as_admin(username) - def remove_admin_user(self, username: str): if Utils.is_empty(username): @@ -528,8 +516,6 @@ def remove_admin_user(self, username: str): 'sudo': False }) - self.user_pool.admin_remove_user_as_admin(username) - # user management methods def get_user(self, username: str) -> User: @@ -547,6 +533,24 @@ def get_user(self, username: str) -> User: return self.user_dao.convert_from_db(user) + def get_user_by_email(self, email: str) -> User: + email = AuthUtils.sanitize_email(email=email) + if not email: + raise exceptions.invalid_params('email is required') + + users = self.user_dao.get_user_by_email(email=email) + if len(users) > 1: + self.logger.warn(f'Multiple users found with email {email}') + raise exceptions.SocaException(error_code=errorcodes.AUTH_MULTIPLE_USERS_FOUND, + message=f'Multiple users found with email {email}') + user = users[0] if users else None + if not user: + raise exceptions.SocaException( + error_code=errorcodes.AUTH_USER_NOT_FOUND, + message=f'User not found with email: {email}' + ) + return self.user_dao.convert_from_db(user) + def create_user(self, user: User, email_verified: bool = False) -> User: """ create a new user @@ -577,31 +581,6 @@ def create_user(self, user: User, email_verified: bool = False) -> User: email = user.email email = AuthUtils.sanitize_email(email) - # password - password = user.password - if email_verified: - if password == None or len(password.strip()) == 0: - raise exceptions.invalid_params('Password is required') - - user_pool_password_policy = self.user_pool.describe_password_policy() - # Validate password 
compliance versus Cognito user pool password policy - # Cognito: https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-settings-policies.html - if len(password) < user_pool_password_policy.minimum_length: - raise exceptions.invalid_params(f'Password should be at least {user_pool_password_policy.minimum_length} characters long') - elif len(password) > 256: - raise exceptions.invalid_params(f'Password can be up to 256 characters') - elif user_pool_password_policy.require_numbers and re.search('[0-9]', password) is None: - raise exceptions.invalid_params('Password should include at least 1 number') - elif user_pool_password_policy.require_uppercase and re.search('[A-Z]', password) is None: - raise exceptions.invalid_params('Password should include at least 1 uppercase letter') - elif user_pool_password_policy.require_lowercase and re.search('[a-z]', password) is None: - raise exceptions.invalid_params('Password should include at least 1 lowercase letter') - elif user_pool_password_policy.require_symbols and re.search('[\^\$\*\.\[\]{}\(\)\?"!@#%&\/\\,><\':;\|_~`=\+\-]', password) is None: - raise exceptions.invalid_params('Password should include at least 1 of these special characters: ^ $ * . [ ] { } ( ) ? " ! @ # % & / \ , > < \' : ; | _ ~ ` = + -') - else: - self.logger.debug('create_user() - setting password to random value') - password = Utils.generate_password(8, 2, 2, 2, 2) - # login_shell login_shell = user.login_shell if login_shell == None or len(login_shell.strip()) == 0: @@ -615,14 +594,12 @@ def create_user(self, user: User, email_verified: bool = False) -> User: # note: no validations on uid / gid if existing uid/gid is provided. # ensuring uid/gid uniqueness is administrator's responsibility. 
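Per the note above, uid/gid uniqueness remains the administrator's responsibility; what the patch changes is that both IDs are now resolved through the new SSSD helper rather than trusted from caller-supplied user data. A minimal sketch of such a lookup, assuming the directory is visible through the host's standard NSS interfaces (which is how SSSD surfaces AD/LDAP identities) — the function names mirror the diff, but this implementation is illustrative, not the actual helper:

```python
import grp
import pwd
from typing import Optional, Tuple

def get_uid_and_gid_for_user(username: str) -> Tuple[Optional[int], Optional[int]]:
    """Resolve uid/gid through NSS, which SSSD backs on a domain-joined host.

    Returns (None, None) when the user is not visible yet (e.g. SSSD has not
    cached the directory entry), letting the caller raise a loud error such as
    the diff's UID_AND_GID_NOT_FOUND instead of creating a half-configured user.
    """
    try:
        entry = pwd.getpwnam(username)
        return entry.pw_uid, entry.pw_gid
    except KeyError:
        return None, None

def get_gid_for_group(group_name: str) -> Optional[int]:
    """Resolve a group's gid the same way; None means 'not found in NSS'."""
    try:
        return grp.getgrnam(group_name).gr_gid
    except KeyError:
        return None
```

Resolving IDs on the host keeps RES consistent with whatever `ldap_id_mapping` produces in SSSD, which is the point of flipping that setting's default to True.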
- # uid - uid = user.uid - if uid is None and not bootstrap_user: - raise exceptions.invalid_params('user.uid missing in AD user data') - - gid = user.gid - if gid is None and not bootstrap_user: - raise exceptions.invalid_params('user.gid missing in AD user data') + # uid and gid + uid = None + gid = None + uid, gid = self.sssd.get_uid_and_gid_for_user(username) + if (uid is None or gid is None) and not bootstrap_user: + raise exceptions.soca_exception(error_code=errorcodes.UID_AND_GID_NOT_FOUND, message=f'Unable to retrieve UID and GID for User: {username}') # sudo sudo = bool(user.sudo) @@ -633,14 +610,39 @@ def create_user(self, user: User, email_verified: bool = False) -> User: # is_active is_active = bool(user.is_active) - self.logger.info(f'creating Cognito user pool entry: {username}, Email: {email} , email_verified: {email_verified}') - - self.user_pool.admin_create_user( - username=username, - email=email, - password=password, - email_verified=email_verified - ) + # password + if bootstrap_user: + password = user.password + if email_verified: + if password == None or len(password.strip()) == 0: + raise exceptions.invalid_params('Password is required') + + user_pool_password_policy = self.user_pool.describe_password_policy() + # Validate password compliance versus Cognito user pool password policy + # Cognito: https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-settings-policies.html + if len(password) < user_pool_password_policy.minimum_length: + raise exceptions.invalid_params(f'Password should be at least {user_pool_password_policy.minimum_length} characters long') + elif len(password) > 256: + raise exceptions.invalid_params(f'Password can be up to 256 characters') + elif user_pool_password_policy.require_numbers and re.search('[0-9]', password) is None: + raise exceptions.invalid_params('Password should include at least 1 number') + elif user_pool_password_policy.require_uppercase and re.search('[A-Z]', password) is None: + raise 
exceptions.invalid_params('Password should include at least 1 uppercase letter') + elif user_pool_password_policy.require_lowercase and re.search('[a-z]', password) is None: + raise exceptions.invalid_params('Password should include at least 1 lowercase letter') + elif user_pool_password_policy.require_symbols and re.search('[\^\$\*\.\[\]{}\(\)\?"!@#%&\/\\,><\':;\|_~`=\+\-]', password) is None: + raise exceptions.invalid_params('Password should include at least 1 of these special characters: ^ $ * . [ ] { } ( ) ? " ! @ # % & / \ , > < \' : ; | _ ~ ` = + -') + else: + self.logger.debug('create_user() - setting password to random value') + password = Utils.generate_password(8, 2, 2, 2, 2) + + self.logger.info(f'creating Cognito user pool entry: {username}, Email: {email} , email_verified: {email_verified}') + self.user_pool.admin_create_user( + username=username, + email=email, + password=password, + email_verified=email_verified + ) # additional groups additional_groups = Utils.get_as_list(user.additional_groups, []) @@ -661,21 +663,6 @@ def create_user(self, user: User, email_verified: bool = False) -> User: 'is_active': is_active, }) - if self.is_sso_enabled(): - self.logger.debug(f'Performing IDP Link for {username} / {email}') - self.user_pool.admin_link_idp_for_user(username, email) - if admin: - self.logger.debug(f'Performing ADMIN for {username}') - self.user_pool.admin_add_user_as_admin(username) - - # todo: Remove once use of RES specific groups is removed - for user_group in constants.RES_USER_GROUPS: - try: - self.logger.info(f'Adding username {username} to RES_USER_GROUP: {user_group}') - self.add_users_to_group([username], user_group, bypass_active_user_check=True) - except Exception as e: - self.logger.debug(f"Could not add user {username} to RES_USER_GROUP: {user_group}") - for additional_group in additional_groups: self.logger.info(f'Adding username {username} to additional group: {additional_group}') self.add_users_to_group([username], 
additional_group) @@ -696,7 +683,7 @@ def modify_user(self, user: User, email_verified: bool = False) -> User: """ Modify User - Only ``email`` updates are supported at the moment. + ``email``, ``login_shell``, ``sudo``, ``uid``, ``gid``, ``is_active`` updates are supported at the moment. :param user: :param email_verified: @@ -723,9 +710,6 @@ existing_email = Utils.get_value_as_string('email', existing_user) if existing_email != new_email: user_updates['email'] = new_email - self.user_pool.admin_update_email(username, new_email, email_verified=email_verified) - if self.is_sso_enabled(): - self.user_pool.admin_link_idp_for_user(username, new_email) if Utils.is_not_empty(user.login_shell): user_updates['login_shell'] = user.login_shell @@ -740,8 +724,13 @@ user_updates['gid'] = user.gid updated_user = self.user_dao.update_user(user_updates) + updated_user = self.user_dao.convert_from_db(updated_user) + if user.is_active: + # Only handle user activation as the account service doesn't define the deactivation workflow currently.
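The Cognito password-policy checks that the patch relocates into the bootstrap-user branch of `create_user` can be distilled into a standalone validator. A sketch of those rules — `PasswordPolicy` here is a hypothetical stand-in for the user pool's `describe_password_policy()` response, not the real data model:

```python
import re
from dataclasses import dataclass

@dataclass
class PasswordPolicy:
    # Field names mirror the policy attributes referenced in the diff.
    minimum_length: int = 8
    require_numbers: bool = True
    require_uppercase: bool = True
    require_lowercase: bool = True
    require_symbols: bool = True

def validate_password(password: str, policy: PasswordPolicy) -> None:
    """Raise ValueError on the first violated rule, as the diffed code does."""
    if len(password) < policy.minimum_length:
        raise ValueError(f'Password should be at least {policy.minimum_length} characters long')
    if len(password) > 256:
        raise ValueError('Password can be up to 256 characters')
    if policy.require_numbers and re.search(r'[0-9]', password) is None:
        raise ValueError('Password should include at least 1 number')
    if policy.require_uppercase and re.search(r'[A-Z]', password) is None:
        raise ValueError('Password should include at least 1 uppercase letter')
    if policy.require_lowercase and re.search(r'[a-z]', password) is None:
        raise ValueError('Password should include at least 1 lowercase letter')
    # Symbol set taken from the regex in the diff.
    if policy.require_symbols and re.search(r'[\^\$\*\.\[\]{}\(\)\?"!@#%&/\\,><\':;\|_~`=\+\-]', password) is None:
        raise ValueError('Password should include at least 1 special character')
```

Gating these checks behind `bootstrap_user` matches the rest of the patch: directory-managed users no longer get Cognito passwords at all, so only the bootstrap administrator's password needs policy validation.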
+ self.activate_user(updated_user) + updated_user.is_active = True - return self.user_dao.convert_from_db(updated_user) + return updated_user def activate_user(self, existing_user: User): if not existing_user.is_active: @@ -754,6 +743,7 @@ def activate_user(self, existing_user: User): self.logger.warning(f'Could not add user {username} to group {additional_group}') self.user_dao.update_user({'username': username, 'is_active': True}) + def enable_user(self, username: str): if Utils.is_empty(username): raise exceptions.invalid_params('username is required') @@ -766,10 +756,9 @@ def enable_user(self, username: str): is_enabled = Utils.get_value_as_bool('enabled', existing_user, False) if is_enabled: return - - self.user_pool.admin_enable_user(username) self.user_dao.update_user({'username': username, 'enabled': True}) + def disable_user(self, username: str): if Utils.is_empty(username): raise exceptions.invalid_params('username is required') @@ -785,10 +774,8 @@ def disable_user(self, username: str): is_enabled = Utils.get_value_as_bool('enabled', existing_user, False) if not is_enabled: return - - self.user_pool.admin_disable_user(username) + self.user_dao.update_user({'username': username, 'enabled': False}) - self.evdi_client.publish_user_disabled_event(username=username) def delete_user(self, username: str): @@ -814,14 +801,6 @@ def delete_user(self, username: str): else: raise e - # disable user from db, user pool and delete from directory service - self.logger.info(f'{log_tag} disabling user') - self.disable_user(username=username) - - # delete user in user pool - self.logger.info(f'{log_tag} delete user from user pool') - self.user_pool.admin_delete_user(username=username) - # delete user from db self.logger.info(f'{log_tag} delete user in ddb') self.user_dao.delete_user(username=username) @@ -831,6 +810,9 @@ def reset_password(self, username: str): if Utils.is_empty(username): raise exceptions.invalid_params('username is required') + if not 
self.is_cluster_administrator(username): + raise AuthUtils.invalid_operation('Only Cluster Administrator password can be reset.') + # trigger reset password email self.user_pool.admin_reset_password(username) @@ -839,11 +821,14 @@ def list_users(self, request: ListUsersRequest) -> ListUsersResult: def change_password(self, access_token: str, username: str, old_password: str, new_password: str): """ - change password for given username in user pool and ldap + change password for given username in user pool this method expects an access token from an already logged-in user, who is trying to change their password. :return: """ + if not self.is_cluster_administrator(username): + raise AuthUtils.invalid_operation('Only Cluster Administrator password can be changed.') + # change password in user pool before changing in ldap self.user_pool.change_password( username=username, @@ -851,41 +836,61 @@ def change_password(self, access_token: str, username: str, old_password: str, n old_password=old_password, new_password=new_password ) - + + def get_user_from_access_token(self, access_token: str) -> Optional[User]: + decoded_token = self.token_service.decode_token(token=access_token) + token_username = decoded_token.get('username') + return self.get_user_from_token_username(token_username=token_username) + + def get_user_from_token_username(self, token_username: str) -> Optional[User]: + if not token_username: + raise exceptions.unauthorized_access() + email = self.token_service.get_email_from_token_username(token_username=token_username) + user = None + if email: + user = self.get_user_by_email(email=email) + else: + # This is for clusteradmin + user = self.get_user(username=token_username) + return user + + def add_role_dbusername_to_auth_result(self, authresult: InitiateAuthResult, ssoAuth: bool = False) -> Optional[InitiateAuthResult]: + access_token = authresult.auth.access_token + user = self.get_user_from_access_token(access_token=access_token) + if user.enabled: + 
authresult.role = user.role + authresult.db_username = user.username + return authresult + else: + self.sign_out(authresult.auth.refresh_token, sso_auth=ssoAuth) + self.logger.error(msg=f'User {user.username} is disabled. Denied login.') + raise exceptions.unauthorized_access() + # public API methods for user onboarding, login, forgot password flows. - def initiate_auth(self, request: InitiateAuthRequest) -> InitiateAuthResult: - auth_flow = request.auth_flow if Utils.is_empty(auth_flow): raise exceptions.invalid_params('auth_flow is required.') if auth_flow == 'USER_PASSWORD_AUTH': - - username = request.username - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required.') - + cognito_username = request.cognito_username password = request.password - if Utils.is_empty(password): - raise exceptions.invalid_params('password is required.') - - return self.user_pool.initiate_username_password_auth(request) - + if not self.is_cluster_administrator(cognito_username): + raise exceptions.unauthorized_access() + authresult = self.user_pool.initiate_username_password_auth(request) + if not authresult.challenge_name: + authresult = self.add_role_dbusername_to_auth_result(authresult=authresult) + return authresult elif auth_flow == 'REFRESH_TOKEN_AUTH': - - username = request.username - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required.') - + cognito_username = request.cognito_username refresh_token = request.refresh_token - if Utils.is_empty(refresh_token): - raise exceptions.invalid_params('refresh_token is required.') - - return self.user_pool.initiate_refresh_token_auth(username, refresh_token) - + if not self.is_cluster_administrator(cognito_username): + raise exceptions.unauthorized_access() + authresult = self.user_pool.initiate_refresh_token_auth( + username=cognito_username, refresh_token=refresh_token) + authresult = self.add_role_dbusername_to_auth_result(authresult=authresult) + return authresult elif 
auth_flow == 'SSO_AUTH': - if not self.is_sso_enabled(): raise exceptions.unauthorized_access() @@ -894,31 +899,30 @@ def initiate_auth(self, request: InitiateAuthRequest) -> InitiateAuthResult: raise exceptions.invalid_params('authorization_code is required.') db_sso_state = self.sso_state_dao.get_sso_state(authorization_code) - if db_sso_state is None: + if not db_sso_state: raise exceptions.unauthorized_access() - auth_result = self.sso_state_dao.convert_from_db(db_sso_state) - self.sso_state_dao.delete_sso_state(authorization_code) - - return InitiateAuthResult( - auth=auth_result + authresult = InitiateAuthResult( + auth=AuthResult( + access_token= db_sso_state.get('access_token'), + refresh_token= db_sso_state.get('refresh_token'), + id_token= db_sso_state.get('id_token'), + expires_in= db_sso_state.get('expires_in'), + token_type= db_sso_state.get('token_type'), + ) ) - + authresult = self.add_role_dbusername_to_auth_result(authresult=authresult, ssoAuth=True) + return authresult elif auth_flow == 'SSO_REFRESH_TOKEN_AUTH': - if not self.is_sso_enabled(): raise exceptions.unauthorized_access() - - username = request.username - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required.') - + cognito_username = request.cognito_username refresh_token = request.refresh_token - if Utils.is_empty(refresh_token): - raise exceptions.invalid_params('refresh_token is required.') - - return self.user_pool.initiate_refresh_token_auth(username, refresh_token, sso=True) + authresult = self.user_pool.initiate_refresh_token_auth( + username=cognito_username, refresh_token=refresh_token, sso=True) + authresult = self.add_role_dbusername_to_auth_result(authresult=authresult, ssoAuth=True) + return authresult def respond_to_auth_challenge(self, request: RespondToAuthChallengeRequest) -> RespondToAuthChallengeResult: @@ -1038,55 +1042,7 @@ def create_defaults(self): except Exception as e: self.logger.error(f"Error during clusteradmin user creation: 
{e}") - # create managers group - cluster_managers_group_name = self.group_name_helper.get_cluster_managers_group() - self.create_res_group(ds_readonly, group_name=cluster_managers_group_name, - title="Managers (Administrators without sudo access)", group_type=constants.GROUP_TYPE_CLUSTER, ref=None) - - # for all "app" modules in the cluster, create the module users and module administrators group to enable fine-grained access - # if an application module is added at a later point in time, a cluster-manager restart should fix the issue. - # ideally, an 'resctl initialize-defaults' command is warranted to address this scenario and will be taken up in a future release. - modules = self.context.get_cluster_modules() - for module in modules: - if module['type'] != constants.MODULE_TYPE_APP: - continue - module_id = module['module_id'] - module_name = module['name'] - - module_administrators_group_name = self.group_name_helper.get_module_administrators_group(module_id=module_id) - self.create_res_group(ds_readonly, group_name=module_administrators_group_name, - title=f"Administrators for Module: {module_name}, ModuleId: {module_id}, DS: {module_administrators_group_name}", group_type=constants.GROUP_TYPE_MODULE, ref=module_id) - - module_users_group_name = self.group_name_helper.get_module_users_group(module_id=module_id) - self.create_res_group(ds_readonly, group_name=module_users_group_name, - title=f"Users for Module: {module_name}, ModuleId: {module_id}", group_type=constants.GROUP_TYPE_MODULE, ref=module_id) - - def create_res_group(self, ds_readonly: bool, group_name: Optional[str], title: Optional[str], group_type: Optional[str], ref: str = None): - if ds_readonly: - self.logger.info(f'Skipping {group_name} sync with AD') - - group = self.group_dao.get_group(group_name) - - if group is None: - self.logger.info(f'Group {group_name} not found in RES DynamoDB') - - try: - self.create_group( - group=Group( - title=title, - name=group_name, - ds_name=group_name, - 
gid=None, - group_type=group_type, - ref=ref, - type=constants.GROUP_TYPE_INTERNAL - ), - ) - except Exception as e: - self.logger.error(f'Error {e}') - - self.logger.info(f'creating group: {group_name}') - def configure_sso(self, request: ConfigureSSORequest): self.single_sign_on_helper.configure_sso(request) self.context.ad_sync.sync_from_ad() # submit ad_sync task after configuring SSO + diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/cognito_user_pool.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/cognito_user_pool.py index 58b536d..1843d15 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/cognito_user_pool.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/cognito_user_pool.py @@ -231,83 +231,6 @@ def admin_disable_user(self, username: str): ) self._context.cache().short_term().delete(self.build_user_cache_key(username)) - def admin_add_user_as_admin(self, username: str): - self.admin_add_user_to_group(username=username, group_name=self.admin_group_name) - - # todo: Remove once use of RES specific groups is removed - for admin_group in constants.RES_ADMIN_GROUPS: - try: - self._logger.info(f'Adding username {username} to RES_ADMIN_GROUP: {admin_group}') - self._context.accounts.add_users_to_group([username], admin_group, bypass_active_user_check=True) - except Exception as e: - self._logger.debug(f"Could not add user {username} to RES_ADMIN_GROUP: {admin_group}") - - def admin_add_user_to_group(self, username: str, group_name: str): - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required') - if Utils.is_empty(group_name): - raise exceptions.invalid_params('username is required') - - self._context.aws().cognito_idp().admin_add_user_to_group( - UserPoolId=self.user_pool_id, - Username=username, - GroupName=group_name - ) - - def admin_link_idp_for_user(self, username: str, email: str): - - if 
Utils.is_empty(username): - raise exceptions.invalid_params('username is required') - - cluster_administrator = self._context.config().get_string('cluster.administrator_username', required=True) - if username in cluster_administrator or username.startswith('clusteradmin'): - self._logger.info(f'system administration user found: {username}. skip linking with IDP.') - return - - provider_name = self._context.config().get_string('identity-provider.cognito.sso_idp_provider_name', required=True) - provider_type = self._context.config().get_string('identity-provider.cognito.sso_idp_provider_type', required=True) - if provider_type == constants.SSO_IDP_PROVIDER_OIDC: - provider_email_attribute = 'email' - else: - provider_email_attribute = self._context.config().get_string('identity-provider.cognito.sso_idp_provider_email_attribute', required=True) - - self._context.aws().cognito_idp().admin_link_provider_for_user( - UserPoolId=self.user_pool_id, - DestinationUser={ - 'ProviderName': 'Cognito', - 'ProviderAttributeName': 'cognito:username', - 'ProviderAttributeValue': username - }, - SourceUser={ - 'ProviderName': provider_name, - 'ProviderAttributeName': provider_email_attribute, - 'ProviderAttributeValue': email - } - ) - - def admin_remove_user_as_admin(self, username: str): - self.admin_remove_user_from_group(username=username, group_name=self.admin_group_name) - - # todo: Remove once use of RES specific groups is removed - for admin_group in constants.RES_ADMIN_GROUPS: - try: - self._logger.info(f'Removing username {username} from RES_ADMIN_GROUP: {admin_group}') - self._context.accounts.remove_users_from_group([username], admin_group) - except Exception as e: - self._logger.debug(f"Could not add user {username} to RES_ADMIN_GROUP: {admin_group}") - - def admin_remove_user_from_group(self, username: str, group_name: str): - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required') - if Utils.is_empty(group_name): - raise 
exceptions.invalid_params('group_name is required') - - self._context.aws().cognito_idp().admin_remove_user_from_group( - UserPoolId=self.user_pool_id, - Username=username, - GroupName=group_name - ) - def password_updated(self, username: str): if not self.is_activedirectory(): return @@ -396,27 +319,27 @@ def admin_global_sign_out(self, username: str): def initiate_username_password_auth(self, request: InitiateAuthRequest) -> InitiateAuthResult: - username = request.username - if Utils.is_empty(username): - raise exceptions.invalid_params('username is required.') + cognito_username = request.cognito_username + if not cognito_username: + raise exceptions.invalid_params('cognito username is required.') password = request.password - if Utils.is_empty(password): + if not password: raise exceptions.invalid_params('password is required.') # In SSO-enabled mode - local auth is not allowed except for clusteradmin cluster_admin_username = self._context.config().get_string('cluster.administrator_username', required=True) sso_enabled = self._context.config().get_bool('identity-provider.cognito.sso_enabled', required=True) - if sso_enabled and (username != cluster_admin_username): - self._logger.error(f"Ignoring local authentication request with SSO enabled. Username: {username}") - raise exceptions.unauthorized_access(f"Ignoring local authentication request with SSO enabled. Username: {username}") + if sso_enabled and (cognito_username != cluster_admin_username): + self._logger.error(f"Ignoring local authentication request with SSO enabled. Username: {cognito_username}") + raise exceptions.unauthorized_access(f"Ignoring local authentication request with SSO enabled. 
Username: {cognito_username}") try: cognito_result = self._context.aws().cognito_idp().admin_initiate_auth( AuthFlow='ADMIN_USER_PASSWORD_AUTH', AuthParameters={ - 'USERNAME': username, + 'USERNAME': cognito_username, 'PASSWORD': password, - 'SECRET_HASH': self.get_secret_hash(username) + 'SECRET_HASH': self.get_secret_hash(cognito_username) }, UserPoolId=self.user_pool_id, ClientId=self.get_client_id() @@ -489,7 +412,7 @@ def respond_to_auth_challenge(self, request: RespondToAuthChallengeRequest) -> R auth=auth_result ) - def initiate_refresh_token_auth(self, username: str, refresh_token: str, sso: bool = False): + def initiate_refresh_token_auth(self, username: str, refresh_token: str, sso: bool = False) -> InitiateAuthResult: if Utils.is_empty(username): raise exceptions.invalid_params('username is required.') @@ -516,7 +439,7 @@ def initiate_refresh_token_auth(self, username: str, refresh_token: str, sso: bo cognito_auth_result = Utils.get_value_as_dict('AuthenticationResult', cognito_result) auth_result = self.build_auth_result(cognito_auth_result) - return RespondToAuthChallengeResult( + return InitiateAuthResult( auth=auth_result ) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/single_sign_on_state_dao.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/single_sign_on_state_dao.py index 049d436..3f2fa4f 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/single_sign_on_state_dao.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/single_sign_on_state_dao.py @@ -54,18 +54,6 @@ def initialize(self): ) self.table = self.context.aws().dynamodb_table().Table(self.get_table_name()) - @staticmethod - def convert_from_db(sso_state: Dict) -> AuthResult: - return AuthResult( - **{ - 'access_token': Utils.get_value_as_string('access_token', sso_state), - 'refresh_token': Utils.get_value_as_string('refresh_token', sso_state), - 'id_token': 
Utils.get_value_as_string('id_token', sso_state), - 'expires_in': Utils.get_value_as_int('expires_in', sso_state), - 'token_type': Utils.get_value_as_string('token_type', sso_state) - } - ) - def create_sso_state(self, sso_state: Dict) -> Dict: state = Utils.get_value_as_string('state', sso_state) @@ -76,7 +64,7 @@ def create_sso_state(self, sso_state: Dict) -> Dict: **sso_state, 'ttl': Utils.current_time_ms() + (10 * 60 * 1000), # 10 minutes 'created_on': Utils.current_time_ms(), - 'updated_on': Utils.current_time_ms() + 'updated_on': Utils.current_time_ms(), } self.table.put_item( Item=created_state, diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/user_dao.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/user_dao.py index 99cc195..b84e8b2 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/user_dao.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/db/user_dao.py @@ -16,9 +16,8 @@ from ideaclustermanager.app.accounts.auth_utils import AuthUtils from ideaclustermanager.app.accounts.cognito_user_pool import CognitoUserPool - -from typing import Optional, Dict -from boto3.dynamodb.conditions import Attr +from typing import Optional, Dict, List +from boto3.dynamodb.conditions import Attr, Key class UserDAO: @@ -47,6 +46,10 @@ def initialize(self): { 'AttributeName': 'role', 'AttributeType': 'S' + }, + { + 'AttributeName': 'email', + 'AttributeType': 'S' } ], 'KeySchema': [ @@ -71,6 +74,24 @@ def initialize(self): "username" ] }, + }, + { + "IndexName": "email-index", + "KeySchema": [ + { + "AttributeName": "email", + "KeyType": "HASH" + } + ], + "Projection": { + "ProjectionType": "INCLUDE", + "NonKeyAttributes": [ + "role", + "username", + "is_active", + "enabled" + ] + }, } ], 'BillingMode': 'PAY_PER_REQUEST' @@ -79,7 +100,8 @@ def initialize(self): ) self.table = self.context.aws().dynamodb_table().Table(self.get_table_name()) - def 
convert_from_db(self, user: Dict) -> User:
+    @staticmethod
+    def convert_from_db(user: Dict) -> User:
         user_entry = User(
             **{
                 'username': Utils.get_value_as_string('username', user),
@@ -174,6 +196,16 @@ def get_user(self, username: str) -> Optional[Dict]:
         self.logger.debug(f"user_lookup: {result} - {_lu_stop - _lu_start}ms")
         return Utils.get_value_as_dict('Item', result)
 
+    def get_user_by_email(self, email: str) -> Optional[List[Dict]]:
+        email = AuthUtils.sanitize_email(email)
+        result = self.table.query(
+            IndexName="email-index",
+            KeyConditionExpression=Key('email').eq(email),
+        )
+        items = result.get("Items")
+        if not items:
+            return None
+        return items
 
     def update_user(self, user: Dict) -> Dict:
         username = Utils.get_value_as_string('username', user)
diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/single_sign_on_helper.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/single_sign_on_helper.py
index 5cc045e..39802d7 100644
--- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/single_sign_on_helper.py
+++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/single_sign_on_helper.py
@@ -9,7 +9,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
-from ideasdk.utils import Utils +from ideasdk.utils import ApiUtils, Utils from ideadatamodel import exceptions, constants from ideasdk.config.cluster_config import ClusterConfig from ideasdk.context import SocaContext @@ -74,7 +74,7 @@ def get_callback_logout_urls(self) -> (List[str], List[str]): callback_urls = [ f'https://{load_balancer_dns_name}{sso_auth_callback_path}' ] - + if len(cluster_manager_web_context_path) > 0 and cluster_manager_web_context_path[-1] == '/': cluster_manager_web_context_path = cluster_manager_web_context_path[:-1] @@ -83,7 +83,7 @@ def get_callback_logout_urls(self) -> (List[str], List[str]): logout_urls = [ f'https://{load_balancer_dns_name}{cluster_manager_web_context_path}' - ] + ] if Utils.is_not_empty(custom_dns_name): callback_urls.append(f'https://{custom_dns_name}{sso_auth_callback_path}') logout_urls.append(f'https://{custom_dns_name}{cluster_manager_web_context_path}') @@ -235,7 +235,7 @@ def get_saml_provider_details(self, request) -> Dict: provider_details['MetadataFile'] = decoded_saml_metadata_file else: provider_details['MetadataURL'] = saml_metadata_url - + provider_details['IDPSignout'] = "true" return provider_details @@ -317,6 +317,9 @@ def create_or_update_identity_provider(self, request): if Utils.is_empty(request.provider_name): raise exceptions.invalid_params('provider_name is required') + ApiUtils.validate_input(request.provider_name, + constants.SSO_SOURCE_PROVIDER_NAME_REGEX, + constants.SSO_SOURCE_PROVIDER_NAME_ERROR_MESSAGE) if Utils.is_empty(request.provider_type): raise exceptions.invalid_params('provider_type is required') if Utils.is_empty(request.provider_email_attribute): diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/sssd_helper.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/sssd_helper.py new file mode 100644 index 0000000..9dff4b7 --- /dev/null +++ 
b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/helpers/sssd_helper.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from ideasdk.context import SocaContext
+import pwd
+import grp
+
+class SSSD:
+    def __init__(self, context: SocaContext):
+        self.context = context
+        self.logger = context.logger(self.get_name())
+
+    def get_name(self) -> str:
+        return 'sssd'
+
+    def ldap_id_mapping(self) -> str:
+        return self.context.config().get_string('directoryservice.sssd.ldap_id_mapping', required=True)
+
+    def get_uid_and_gid_for_user(self, username: str) -> tuple[int, int] | tuple[None, None]:
+        try:
+            user_info = pwd.getpwnam(username)
+
+            uid = user_info.pw_uid
+            gid = user_info.pw_gid
+
+            return uid, gid
+        except KeyError:
+            self.logger.warning(f"User: {username} not yet available")
+            return None, None
+
+    def get_gid_for_group(self, groupname: str) -> int | None:
+        try:
+            group_info = grp.getgrnam(groupname)
+
+            gid = group_info.gr_gid
+
+            return gid
+        except KeyError:
+            self.logger.warning(f"Group: {groupname} not yet available")
+            return None
diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/user_home_directory.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/user_home_directory.py
index 39ad334..a5299c6 100644
--- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/user_home_directory.py
+++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/accounts/user_home_directory.py
@@ -12,7 +12,7 @@
 from ideasdk.context import SocaContext
 from ideadatamodel.auth import User
 from ideasdk.utils import Utils
-from ideadatamodel import exceptions
+from ideadatamodel import exceptions, errorcodes
 
 import time
 import os
@@ -105,20 +105,10 @@ def initialize_home_dir(self):
         self.own_path(dest_file)
 
     def initialize(self):
-        # wait for system to sync the newly created user by using system libraries to resolve the user
-        # this happens on a fresh installation of auth-server, where
all system services have just started - # and a new clusteradmin user is created. - # although the user is created in directory services, it's not yet synced with the local system - # If you continue to see this log message it may indicate that the underlying cluster-manager - # host is not probably linked to the back-end directory service in some fashion. - while True: - try: - pwd.getpwnam(self.user.username) - break - except KeyError: - self._logger.info(f'{self.user.username} not available yet. waiting for user to be synced ...') - time.sleep(5) - + try: + pwd.getpwnam(self.user.username) + except KeyError: + raise exceptions.soca_exception(error_code=errorcodes.USER_NOT_AVAILABLE, message=f'{self.user.username} is not available.') self.initialize_home_dir() self.initialize_ssh_dir() diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/adsync/adsync_tasks.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/adsync/adsync_tasks.py index 3b107b4..f647a8f 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/adsync/adsync_tasks.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/adsync/adsync_tasks.py @@ -10,6 +10,7 @@ from typing import Dict import ideaclustermanager from ideaclustermanager.app.tasks.base_task import BaseTask +from ideasdk.utils import Utils import time @@ -61,6 +62,5 @@ def invoke(self, payload: Dict): start_time = time.time() ldap_groups = self.context.ad_sync.fetch_all_ldap_groups() group_addition_failures = self.context.ad_sync.sync_all_groups(ldap_groups) - if self.context.accounts.is_sso_enabled(): - self.context.ad_sync.sync_all_users(ldap_groups, group_addition_failures) + self.context.ad_sync.sync_all_users(ldap_groups, group_addition_failures) self.logger.info(f"-------------TIME: {time.time()-start_time}------------") diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/accounts_api.py 
b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/accounts_api.py index d09c73a..bb2ee8c 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/accounts_api.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/accounts_api.py @@ -13,6 +13,8 @@ from ideadatamodel.auth import ( GetUserRequest, GetUserResult, + GetUserByEmailRequest, + GetUserByEmailResult, ModifyUserRequest, ModifyUserResult, EnableUserRequest, @@ -58,6 +60,10 @@ def __init__(self, context: ideaclustermanager.AppContext): 'scope': self.SCOPE_READ, 'method': self.get_user }, + 'Accounts.GetUserByEmail': { + 'scope': self.SCOPE_READ, + 'method': self.get_user_by_email + }, 'Accounts.ModifyUser': { 'scope': self.SCOPE_WRITE, 'method': self.modify_user @@ -138,6 +144,11 @@ def get_user(self, context: ApiInvocationContext): user = self.context.accounts.get_user(request.username) context.success(GetUserResult(user=user)) + def get_user_by_email(self, context: ApiInvocationContext): + request = context.get_request_payload_as(GetUserByEmailRequest) + user = self.context.accounts.get_user_by_email(request.email) + context.success(GetUserByEmailResult(user=user)) + def modify_user(self, context: ApiInvocationContext): request = context.get_request_payload_as(ModifyUserRequest) email_verified = Utils.get_as_bool(request.email_verified, False) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/analytics_api.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/analytics_api.py deleted file mode 100644 index fe22d35..0000000 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/analytics_api.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. 
A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. - -import ideaclustermanager - -from ideasdk.api import ApiInvocationContext, BaseAPI -from ideadatamodel.analytics import ( - OpenSearchQueryRequest, - OpenSearchQueryResult -) -from ideadatamodel import exceptions -from ideasdk.utils import Utils - - -class AnalyticsAPI(BaseAPI): - """ - Analytics API to query opensearch cluster. - - this is a stop-gap API and ideally, each module should implement their own version of Analytics API, - scoped to the indices or aliases exposed by a particular module. - - any write or index settings apis must not be exposed via this class. - an AnalyticsAdminAPI should be exposed with elevated access to enable such functionality. 
- - invocation simply checks if the invocation is authenticated (valid token) - """ - - def __init__(self, context: ideaclustermanager.AppContext): - self.context = context - - def opensearch_query(self, context: ApiInvocationContext): - """ - Send Raw ElasticSearch Query and Search Request - """ - request = context.get_request_payload_as(OpenSearchQueryRequest) - if Utils.is_empty(request.data): - raise exceptions.invalid_params('data is required') - - result = self.context.analytics_service().os_client.os_client.search( - **request.data - ) - - context.success(OpenSearchQueryResult( - data=result - )) - - def invoke(self, context: ApiInvocationContext): - if not context.is_authenticated(): - raise exceptions.unauthorized_access() - - namespace = context.namespace - if namespace == 'Analytics.OpenSearchQuery': - self.opensearch_query(context) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/api_invoker.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/api_invoker.py index 3d79780..b4a65eb 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/api_invoker.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/api_invoker.py @@ -12,7 +12,7 @@ import ideaclustermanager from ideasdk.api import ApiInvocationContext from ideasdk.protocols import ApiInvokerProtocol -from ideasdk.auth import TokenService +from ideasdk.auth import TokenService, ApiAuthorizationServiceBase from ideasdk.utils import Utils from ideadatamodel.auth import ( CreateUserRequest, @@ -39,13 +39,11 @@ from ideasdk.filesystem.filebrowser_api import FileBrowserAPI from ideaclustermanager.app.api.cluster_settings_api import ClusterSettingsAPI from ideaclustermanager.app.api.filesystem_api import FileSystemAPI -from ideaclustermanager.app.api.analytics_api import AnalyticsAPI from ideaclustermanager.app.api.projects_api import ProjectsAPI from ideaclustermanager.app.api.accounts_api import AccountsAPI from 
ideaclustermanager.app.api.auth_api import AuthAPI from ideaclustermanager.app.api.email_templates_api import EmailTemplatesAPI from ideaclustermanager.app.api.snapshots_api import SnapshotsAPI - from typing import Optional, Dict @@ -56,7 +54,6 @@ def __init__(self, context: ideaclustermanager.AppContext): self.app_api = SocaAppAPI(context) self.file_browser_api = FileBrowserAPI(context) self.cluster_settings_api = ClusterSettingsAPI(context) - self.analytics_api = AnalyticsAPI(context) self.projects_api = ProjectsAPI(context) self.filesystem_api = FileSystemAPI(context) self.auth_api = AuthAPI(context) @@ -73,6 +70,9 @@ def __init__(self, context: ideaclustermanager.AppContext): def get_token_service(self) -> Optional[TokenService]: return self._context.token_service + def get_api_authorization_service(self) -> Optional[ApiAuthorizationServiceBase]: + return self._context.api_authorization_service + def get_request_logging_payload(self, context: ApiInvocationContext) -> Optional[Dict]: namespace = context.namespace @@ -190,8 +190,6 @@ def invoke(self, context: ApiInvocationContext): self.file_browser_api.invoke(context) elif namespace.startswith('ClusterSettings.'): self.cluster_settings_api.invoke(context) - elif namespace.startswith('Analytics.'): - self.analytics_api.invoke(context) elif namespace.startswith('Projects.'): self.projects_api.invoke(context) elif namespace.startswith('FileSystem.'): diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/filesystem_api.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/filesystem_api.py index c15f383..cae593b 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/filesystem_api.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/filesystem_api.py @@ -1,16 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import List, Union, Set -import botocore.exceptions from pydantic import ValidationError import ideaclustermanager -from ideaclustermanager.app.shared_filesystem.shared_filesystem_service import ( - SharedFilesystemService, -) -from ideadatamodel.exceptions import SocaException -from ideadatamodel.shared_filesystem.shared_filesystem_api import FSxONTAPDeploymentType from ideasdk.api import ApiInvocationContext, BaseAPI from ideadatamodel.shared_filesystem import ( @@ -18,34 +11,22 @@ RemoveFileSystemFromProjectRequest, AddFileSystemToProjectResult, RemoveFileSystemFromProjectResult, - CommonCreateFileSystemRequest, CreateFileSystemResult, - FileSystem, - FSxONTAPSVM, - FSxONTAPVolume, - ListFileSystemInVPCResult, - CommonOnboardFileSystemRequest, OnboardEFSFileSystemRequest, OnboardONTAPFileSystemRequest, OnboardFileSystemResult, CreateEFSFileSystemRequest, - CreateONTAPFileSystemRequest, - EFSFileSystem, - FSxONTAPFileSystem, + CreateONTAPFileSystemRequest ) from ideadatamodel import ( - exceptions, - constants, - errorcodes, + exceptions ) from ideasdk.utils import Utils - class FileSystemAPI(BaseAPI): def __init__(self, context: ideaclustermanager.AppContext): self.context = context self.config = self.context.config() - self.shared_filesystem_service = SharedFilesystemService(context=self.context) self.logger = self.context.logger("shared-filesystem") self.SCOPE_WRITE = f"{self.context.module_id()}/write" @@ -83,22 +64,13 @@ def __init__(self, context: ideaclustermanager.AppContext): } def create_filesystem(self, context: ApiInvocationContext): - request = context.get_request_payload_as(CommonCreateFileSystemRequest) - if Utils.is_any_empty( - request.filesystem_name, - request.filesystem_title, - ): - raise exceptions.soca_exception( - error_code=errorcodes.INVALID_PARAMS, - message="needed parameters cannot be empty", - ) try: if context.namespace == "FileSystem.CreateEFS": - self._create_efs( + 
self.context.shared_filesystem.create_efs( context.get_request_payload_as(CreateEFSFileSystemRequest) ) elif context.namespace == "FileSystem.CreateONTAP": - self._create_fsx_ontap( + self.context.shared_filesystem.create_fsx_ontap( context.get_request_payload_as(CreateONTAPFileSystemRequest) ) except ValidationError as e: @@ -107,365 +79,29 @@ def create_filesystem(self, context: ApiInvocationContext): context.success(CreateFileSystemResult()) - def _create_efs(self, request: CreateEFSFileSystemRequest): - self._validate_correct_subnet_selection({request.subnet_id_1, request.subnet_id_2}) - self._validate_filesystem_does_not_exist(request.filesystem_name) - try: - efs_client = self.context.aws().efs() - security_group_id = self.config.get_string( - f"{constants.MODULE_SHARED_STORAGE}.security_group_id" - ) - - tags_for_fs = self.shared_filesystem_service.create_tags( - request.filesystem_name - ) - - efs_create_response = efs_client.create_file_system( - ThroughputMode="elastic", Encrypted=True, Tags=tags_for_fs - ) - efs_filesystem = EFSFileSystem(efs=efs_create_response) - efs_filesystem_id = efs_filesystem.get_filesystem_id() - - self.shared_filesystem_service.efs_check_filesystem_exists( - filesystem_id=efs_filesystem_id, wait=True - ) - filesystem_policy = { - "Version": "2012-10-17", - "Id": "efs-prevent-anonymous-access-policy", - "Statement": [ - { - "Sid": "efs-statement", - "Effect": "Allow", - "Principal": {"AWS": "*"}, - "Action": [ - "elasticfilesystem:ClientRootAccess", - "elasticfilesystem:ClientWrite", - "elasticfilesystem:ClientMount", - ], - "Resource": f"arn:aws:elasticfilesystem:{self.context.aws().aws_region()}:{self.context.aws().aws_account_id()}:file-system/{efs_filesystem_id}", - "Condition": { - "Bool": {"elasticfilesystem:AccessedViaMountTarget": "true"} - }, - } - ], - } - efs_client.put_file_system_policy( - FileSystemId=efs_filesystem_id, Policy=Utils.to_json(filesystem_policy) - ) - - # Create mount targets - 
efs_client.create_mount_target( - FileSystemId=efs_filesystem_id, - SubnetId=request.subnet_id_1, - SecurityGroups=[security_group_id], - ) - efs_client.create_mount_target( - FileSystemId=efs_filesystem_id, - SubnetId=request.subnet_id_2, - SecurityGroups=[security_group_id], - ) - - # Sync cluster-settings ddb - config_entries = self.shared_filesystem_service.build_config_for_new_efs( - efs=efs_filesystem, request=request - ) - self.config.db.sync_cluster_settings_in_db( - config_entries=config_entries, overwrite=True - ) - except botocore.exceptions.ClientError as e: - error_message = e.response["Error"]["Message"] - raise exceptions.general_exception(error_message) - - def _create_fsx_ontap(self, request: CreateONTAPFileSystemRequest): - self._validate_correct_subnet_selection({request.primary_subnet, request.standby_subnet}) - self._validate_filesystem_does_not_exist(request.filesystem_name) - try: - fsx_client = self.context.aws().fsx() - security_group_id = self.config.get_string( - f"{constants.MODULE_SHARED_STORAGE}.security_group_id" - ) - - tags_for_fs = self.shared_filesystem_service.create_tags( - request.filesystem_name - ) - - _subnet_ids = [] - if request.deployment_type == FSxONTAPDeploymentType.SINGLE_AZ: - _subnet_ids = [request.primary_subnet] - else: - _subnet_ids = [request.primary_subnet, request.standby_subnet] - fs_create_response = fsx_client.create_file_system( - FileSystemType="ONTAP", - SecurityGroupIds=[ - security_group_id, - ], - Tags=tags_for_fs, - StorageCapacity=request.storage_capacity, - SubnetIds=_subnet_ids, - OntapConfiguration={ - "PreferredSubnetId": request.primary_subnet, - "DeploymentType": request.deployment_type, - "ThroughputCapacity": 128, # parameter required - }, - ) - - fsx_ontap_filesystem = FSxONTAPFileSystem( - filesystem=fs_create_response["FileSystem"] - ) - fs_id = fsx_ontap_filesystem.get_filesystem_id() - - svm = fsx_client.create_storage_virtual_machine( - FileSystemId=fs_id, - Name="fsx", - ) - - 
svm_volume = fsx_client.create_volume( - VolumeType="ONTAP", - Name="vol1", - OntapConfiguration={ - "JunctionPath": "/vol1", - "SecurityStyle": request.volume_security_style, - "SizeInMegabytes": 1024, - "StorageVirtualMachineId": svm["StorageVirtualMachine"][ - "StorageVirtualMachineId" - ], - "StorageEfficiencyEnabled": True, - }, - ) - - fsx_ontap_svm = FSxONTAPSVM( - storage_virtual_machine=svm["StorageVirtualMachine"] - ) - fsx_ontap_volume = FSxONTAPVolume(volume=svm_volume["Volume"]) - - # Update cluster-settings dynamodb - config_entries = self.shared_filesystem_service.build_config_for_new_ontap( - svm=fsx_ontap_svm, volume=fsx_ontap_volume, request=request - ) - self.config.db.sync_cluster_settings_in_db( - config_entries=config_entries, overwrite=True - ) - except botocore.exceptions.ClientError as e: - error_message = e.response["Error"]["Message"] - self.logger.error(error_message) - raise exceptions.general_exception(error_message) - def add_filesystem_to_project(self, context: ApiInvocationContext): request = context.get_request_payload_as(AddFileSystemToProjectRequest) - - filesystem_name = request.filesystem_name - project_name = request.project_name - self._check_required_parameters(request=request) - self.update_filesystem_to_project_mapping(filesystem_name, project_name) + self.context.shared_filesystem.add_filesystem_to_project(request) context.success(AddFileSystemToProjectResult()) - def update_filesystem_to_project_mapping(self, filesystem_name: str, project_name: str): - fs = self._get_filesystem(filesystem_name) - projects = fs.get_projects() - if Utils.is_empty(projects): - projects = [] - if project_name not in projects: - projects.append(project_name) - self._update_projects_for_filesystem(fs, projects) - def remove_filesystem_from_project(self, context: ApiInvocationContext): request = context.get_request_payload_as(RemoveFileSystemFromProjectRequest) - - filesystem_name = request.filesystem_name - project_name = request.project_name 
- self._check_required_parameters(request=request) - - fs = self._get_filesystem(filesystem_name) - projects = fs.get_projects() - if project_name in projects: - projects.remove(project_name) - self._update_projects_for_filesystem(fs, projects) - + self.context.shared_filesystem.remove_filesystem_from_project(request) context.success(RemoveFileSystemFromProjectResult()) def list_file_systems_in_vpc(self, context: ApiInvocationContext): - onboarded_filesystems = self._list_shared_filesystems() - onboarded_filesystem_ids = set([fs.get_filesystem_id() for fs in onboarded_filesystems]) - - efs_filesystems = self._list_unonboarded_efs_file_systems(onboarded_filesystem_ids) - fsx_filesystems = self._list_unonboarded_ontap_file_systems() - context.success(ListFileSystemInVPCResult(efs=efs_filesystems, fsx=fsx_filesystems)) + result = self.context.shared_filesystem.list_file_systems_in_vpc() + context.success(result) def onboard_filesystem(self, context: ApiInvocationContext): - request = context.get_request_payload_as(CommonOnboardFileSystemRequest) - if Utils.is_any_empty( - request.filesystem_id, - request.filesystem_name, - request.filesystem_title, - ): - raise exceptions.soca_exception( - error_code=errorcodes.INVALID_PARAMS, - message="needed parameters cannot be empty", - ) - self._validate_filesystem_does_not_exist(request.filesystem_name) if context.namespace == 'FileSystem.OnboardEFSFileSystem': - config_entries = self.shared_filesystem_service.build_config_for_vpc_efs(context.get_request_payload_as(OnboardEFSFileSystemRequest)) - self.context.config().db.sync_cluster_settings_in_db( - config_entries=config_entries, - overwrite=True - ) + self.context.shared_filesystem.onboard_efs_filesystem(context.get_request_payload_as(OnboardEFSFileSystemRequest)) + elif context.namespace == 'FileSystem.OnboardONTAPFileSystem': - config_entries = self.shared_filesystem_service.build_config_for_vpc_ontap(context.get_request_payload_as(OnboardONTAPFileSystemRequest)) - 
self.context.config().db.sync_cluster_settings_in_db( - config_entries=config_entries, - overwrite=True - ) + self.context.shared_filesystem.onboard_ontap_filesystem(context.get_request_payload_as(OnboardONTAPFileSystemRequest)) + context.success(OnboardFileSystemResult()) - @staticmethod - def _check_required_parameters( - request: Union[ - AddFileSystemToProjectRequest, RemoveFileSystemFromProjectRequest - ] - ): - if Utils.is_empty(request.filesystem_name): - raise exceptions.invalid_params("filesystem_name is required") - if Utils.is_empty(request.project_name): - raise exceptions.invalid_params("project_name is required") - - def _list_unonboarded_efs_file_systems(self, onboarded_filesystem_ids: Set[str]) -> List[EFSFileSystem]: - try: - onboarded_internal_filesystem_ids = [ - self.config.db.get_config_entry("shared-storage.home.efs.file_system_id")['value'], - self.config.db.get_config_entry("shared-storage.internal.efs.file_system_id")['value'] - ] - env_vpc_id = self.config.db.get_config_entry("cluster.network.vpc_id")['value'] - efs_client = self.context.aws().efs() - efs_response = efs_client.describe_file_systems()["FileSystems"] - - filesystems: List[EFSFileSystem] = [] - for efs in efs_response: - if efs['LifeCycleState'] != 'available': - continue - fs_id = efs['FileSystemId'] - efs_mt_response = efs_client.describe_mount_targets(FileSystemId=fs_id) - if len(efs_mt_response['MountTargets']) == 0: - continue - if env_vpc_id == efs_mt_response['MountTargets'][0]['VpcId'] and \ - fs_id not in onboarded_filesystem_ids and \ - fs_id not in onboarded_internal_filesystem_ids: - filesystems.append(EFSFileSystem(efs = efs)) - return filesystems - except botocore.exceptions.ClientError as e: - error_message = e.response["Error"]["Message"] - raise exceptions.general_exception(error_message) - - def _list_unonboarded_ontap_file_systems(self) -> List[FSxONTAPFileSystem]: - try: - env_vpc_id = self.config.db.get_config_entry("cluster.network.vpc_id")['value'] - - 
fsx_client = self.context.aws().fsx() - fsx_response = fsx_client.describe_file_systems()["FileSystems"] - ec2_client = self.context.aws().ec2() - - filesystems: List[FSxONTAPFileSystem] = [] - for fsx in fsx_response: - if fsx['Lifecycle'] != 'AVAILABLE': - continue - fs_id = fsx['FileSystemId'] - subnet_response = ec2_client.describe_subnets(SubnetIds=fsx['SubnetIds']) - if env_vpc_id == subnet_response['Subnets'][0]['VpcId']: - volume_response = fsx_client.describe_volumes( - Filters=[{ - 'Name': 'file-system-id', - 'Values': [fs_id] - }] - ) - svm_response = fsx_client.describe_storage_virtual_machines( - Filters=[{ - 'Name': 'file-system-id', - 'Values': [fs_id] - }] - ) - if len(svm_response['StorageVirtualMachines']) == 0 or len(volume_response['Volumes']) == 0: - continue - list_created_svms = list(filter( - lambda svm_obj: svm_obj['Lifecycle'] == "CREATED", - svm_response['StorageVirtualMachines'] - )) - list_created_volumes = list(filter( - lambda volume_obj: volume_obj['Lifecycle'] == "CREATED", - volume_response['Volumes'] - )) - svm_list = [FSxONTAPSVM(storage_virtual_machine = svm) for svm in list_created_svms] - volume_list = [FSxONTAPVolume(volume = volume) for volume in list_created_volumes] - filesystems.append(FSxONTAPFileSystem(filesystem = fsx, svm = svm_list, volume = volume_list)) - return filesystems - except botocore.exceptions.ClientError as e: - error_message = e.response["Error"]["Message"] - raise exceptions.general_exception(error_message) - - def _list_shared_filesystems(self) -> List[FileSystem]: - shared_storage_config = self.config.get_config(constants.MODULE_SHARED_STORAGE) - shared_storage_config_dict = shared_storage_config.as_plain_ordered_dict() - - filesystem: List[FileSystem] = [] - for fs_name, config in shared_storage_config_dict.items(): - if Utils.is_not_empty(Utils.get_as_dict(config)) and "projects" in list( - config.keys() - ): - filesystem.append(FileSystem(name=fs_name, storage=config)) - - return filesystem - - def 
_get_filesystem(self, filesystem_name: str): - filesystems = self._list_shared_filesystems() - - if Utils.is_empty(filesystems): - raise exceptions.soca_exception( - error_code=errorcodes.NO_SHARED_FILESYSTEM_FOUND, - message="did not find any shared filesystem", - ) - - for fs in filesystems: - if fs.get_name() == filesystem_name: - return fs - - raise exceptions.soca_exception( - error_code=errorcodes.FILESYSTEM_NOT_FOUND, - message=f"could not find filesystem {filesystem_name}", - ) - - def _update_projects_for_filesystem( - self, filesystem: FileSystem, projects: List[str] - ): - self.config.db.set_config_entry( - f"{constants.MODULE_SHARED_STORAGE}.{filesystem.get_name()}.projects", # update entry on cluster settings dynamodb table - projects, - ) - self.config.put( - f"{constants.MODULE_SHARED_STORAGE}.{filesystem.get_name()}.projects", # update local config tree - projects, - ) - - def _validate_filesystem_does_not_exist(self, filesystem_name: str): - try: - if Utils.is_not_empty(self._get_filesystem(filesystem_name)): - raise exceptions.soca_exception( - error_code=errorcodes.INVALID_PARAMS, - message=f"{filesystem_name} already exists", - ) - except SocaException as e: - if e.error_code == errorcodes.FILESYSTEM_NOT_FOUND or e.error_code == errorcodes.NO_SHARED_FILESYSTEM_FOUND: - pass - else: - raise e - - def _validate_correct_subnet_selection(self, subnet_ids: Set[str]): - res_private_subnets = self.config.db.get_config_entry("cluster.network.private_subnets")['value'] - for subnet_id in subnet_ids: - if subnet_id is not None and subnet_id not in res_private_subnets: - self.logger.error(f"{subnet_id} is not a RES private subnet") - raise exceptions.soca_exception( - error_code=errorcodes.INVALID_PARAMS, - message=f"{subnet_id} is not a RES private subnet", - ) - def invoke(self, context: ApiInvocationContext): namespace = context.namespace diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/projects_api.py 
b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/projects_api.py index efbe6e3..7379bf2 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/projects_api.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/projects_api.py @@ -12,11 +12,11 @@ import ideaclustermanager from ideadatamodel.shared_filesystem import FileSystem -from ideaclustermanager.app.api.filesystem_api import FileSystemAPI from ideasdk.api import ApiInvocationContext, BaseAPI from ideadatamodel.projects import ( CreateProjectRequest, + DeleteProjectRequest, GetProjectRequest, UpdateProjectRequest, ListProjectsRequest, @@ -27,14 +27,13 @@ ListFileSystemsForProjectResult ) from ideadatamodel import exceptions, constants, errorcodes -from ideasdk.utils import Utils +from ideasdk.utils import Utils, ApiUtils class ProjectsAPI(BaseAPI): def __init__(self, context: ideaclustermanager.AppContext): self.context = context - self.file_system_api = FileSystemAPI(context=self.context) self.logger = context.logger('projects') self.SCOPE_WRITE = f'{self.context.module_id()}/write' self.SCOPE_READ = f'{self.context.module_id()}/read' @@ -44,6 +43,10 @@ def __init__(self, context: ideaclustermanager.AppContext): 'scope': self.SCOPE_WRITE, 'method': self.create_project }, + 'Projects.DeleteProject': { + 'scope': self.SCOPE_WRITE, + 'method': self.delete_project + }, 'Projects.GetProject': { 'scope': self.SCOPE_READ, 'method': self.get_project @@ -76,10 +79,18 @@ def __init__(self, context: ideaclustermanager.AppContext): def create_project(self, context: ApiInvocationContext): request = context.get_request_payload_as(CreateProjectRequest) + ApiUtils.validate_input(request.project.name, + constants.PROJECT_ID_REGEX, + constants.PROJECT_ID_ERROR_MESSAGE) result = self.context.projects.create_project(request) if not Utils.is_empty(request.filesystem_names): for fs_name in request.filesystem_names: - 
self.file_system_api.update_filesystem_to_project_mapping(fs_name, request.project.name) + self.context.shared_filesystem.update_filesystem_to_project_mapping(fs_name, request.project.name) + context.success(result) + + def delete_project(self, context: ApiInvocationContext): + request = context.get_request_payload_as(DeleteProjectRequest) + result = self.context.projects.delete_project(request) context.success(result) def get_project(self, context: ApiInvocationContext): @@ -153,17 +164,17 @@ def invoke(self, context: ApiInvocationContext): acl_entry = Utils.get_value_as_dict(namespace, self.acl) if acl_entry is None: raise exceptions.unauthorized_access() - + acl_entry_scope = Utils.get_value_as_string('scope', acl_entry) is_authorized = context.is_authorized(elevated_access=True, scopes=[acl_entry_scope]) is_authenticated_user = context.is_authenticated_user() - + if is_authorized: acl_entry['method'](context) return - + if is_authenticated_user and namespace in ('Projects.GetUserProjects', 'Projects.GetProject'): acl_entry['method'](context) return - + raise exceptions.unauthorized_access() diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/snapshots_api.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/snapshots_api.py index 2bc20a3..52938a1 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/snapshots_api.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/api/snapshots_api.py @@ -13,7 +13,10 @@ from ideadatamodel.snapshots import ( CreateSnapshotRequest, CreateSnapshotResult, - ListSnapshotsRequest + ListSnapshotsRequest, + ApplySnapshotRequest, + ApplySnapshotResult, + ListApplySnapshotRecordsRequest ) from ideadatamodel import exceptions from ideasdk.utils import Utils @@ -37,6 +40,14 @@ def __init__(self, context: ideaclustermanager.AppContext): 'Snapshots.ListSnapshots': { 'scope': self.SCOPE_READ, 'method': self.list_snapshots + }, + 'Snapshots.ApplySnapshot': { + 
'scope': self.SCOPE_WRITE, + 'method': self.apply_snapshot + }, + 'Snapshots.ListAppliedSnapshots': { + 'scope': self.SCOPE_READ, + 'method': self.list_applied_snapshots } } @@ -47,7 +58,7 @@ def create_snapshot(self, context: ApiInvocationContext): self.context.snapshots.create_snapshot(request.snapshot) context.success(CreateSnapshotResult( - result='Successfully created Snapshot.' + message='Successfully created Snapshot.' )) def list_snapshots(self, context: ApiInvocationContext): @@ -55,6 +66,21 @@ def list_snapshots(self, context: ApiInvocationContext): result = self.context.snapshots.list_snapshots(request) context.success(result) + def apply_snapshot(self, context: ApiInvocationContext): + request = context.get_request_payload_as(ApplySnapshotRequest) + if not request.snapshot: + raise exceptions.invalid_params('Snapshot details are empty.') + + self.context.snapshots.apply_snapshot(request.snapshot) + context.success(ApplySnapshotResult( + message='Successfully submitted Apply Snapshot request.' + )) + + def list_applied_snapshots(self, context: ApiInvocationContext): + request = context.get_request_payload_as(ListApplySnapshotRecordsRequest) + result = self.context.snapshots.list_applied_snapshots(request) + context.success(result) + def invoke(self, context: ApiInvocationContext): namespace = context.namespace diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_context.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_context.py index 1a45e95..1cfc282 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_context.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_context.py @@ -9,8 +9,9 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. 
from ideasdk.context import SocaContext, SocaContextOptions -from ideasdk.auth import TokenService +from ideasdk.auth import TokenService, ApiAuthorizationServiceBase from ideasdk.utils import GroupNameHelper +from ideasdk.client.vdc_client import AbstractVirtualDesktopControllerClient from ideaclustermanager.app.projects.projects_service import ProjectsService from ideaclustermanager.app.accounts.accounts_service import AccountsService @@ -21,6 +22,7 @@ from ideaclustermanager.app.accounts.ad_automation_agent import ADAutomationAgent from ideaclustermanager.app.email_templates.email_templates_service import EmailTemplatesService from ideaclustermanager.app.notifications.notifications_service import NotificationsService +from ideaclustermanager.app.shared_filesystem.shared_filesystem_service import SharedFilesystemService from ideaclustermanager.app.tasks.task_manager import TaskManager from typing import Optional, Union @@ -34,6 +36,7 @@ def __init__(self, options: SocaContextOptions): ) self.token_service: Optional[TokenService] = None + self.api_authorization_service: Optional[ApiAuthorizationServiceBase] = None self.projects: Optional[ProjectsService] = None self.user_pool: Optional[CognitoUserPool] = None self.ldap_client: Optional[Union[OpenLDAPClient, ActiveDirectoryClient]] = None @@ -45,3 +48,5 @@ def __init__(self, options: SocaContextOptions): self.group_name_helper: Optional[GroupNameHelper] = None self.snapshots: Optional[SnapshotsService] = None self.ad_sync: Optional[ADSyncService] = None + self.vdc_client: Optional[AbstractVirtualDesktopControllerClient] = None + self.shared_filesystem: Optional[SharedFilesystemService] diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_main.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_main.py index 71d920a..6d295ef 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_main.py +++ 
b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/app_main.py @@ -52,7 +52,6 @@ def main(**kwargs): enable_distributed_lock=True, enable_leader_election=True, enable_metrics=True, - enable_analytics=True ) ), **kwargs diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/__init__.py new file mode 100644 index 0000000..59d9e03 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/api_authorization_service.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/api_authorization_service.py new file mode 100644 index 0000000..4c9d851 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/auth/api_authorization_service.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +from ideasdk.auth.api_authorization_service_base import ApiAuthorizationServiceBase +from ideaclustermanager.app.accounts.accounts_service import AccountsService +from ideadatamodel.auth import User +from typing import Optional + +class ClusterManagerApiAuthorizationService(ApiAuthorizationServiceBase): + def __init__(self, accounts: AccountsService): + self.accounts = accounts + + def get_user_from_token_username(self, token_username: str) -> Optional[User]: + return self.accounts.get_user_from_token_username(token_username=token_username) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/cluster_manager_app.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/cluster_manager_app.py index 13358b7..19ae4c1 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/cluster_manager_app.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/cluster_manager_app.py @@ -15,6 +15,7 @@ from ideasdk.client.evdi_client import EvdiClient from ideasdk.server import SocaServerOptions from ideasdk.utils import GroupNameHelper +from ideasdk.client.vdc_client import SocaClientOptions, VirtualDesktopControllerClient import ideaclustermanager from ideaclustermanager.app.api.api_invoker import ClusterManagerApiInvoker @@ -29,7 +30,7 @@ from ideaclustermanager.app.accounts.cognito_user_pool import CognitoUserPool, CognitoUserPoolOptions from ideaclustermanager.app.accounts.ldapclient.ldap_client_factory import build_ldap_client - +from ideaclustermanager.app.auth.api_authorization_service import ClusterManagerApiAuthorizationService from ideaclustermanager.app.accounts.ad_automation_agent import ADAutomationAgent from ideaclustermanager.app.accounts.account_tasks import ( SyncUserInDirectoryServiceTask, @@ -47,6 +48,7 @@ from ideaclustermanager.app.email_templates.email_templates_service import EmailTemplatesService 
from ideaclustermanager.app.notifications.notifications_service import NotificationsService from ideaclustermanager.app.snapshots.snapshots_service import SnapshotsService +from ideaclustermanager.app.shared_filesystem.shared_filesystem_service import SharedFilesystemService from typing import Optional @@ -95,6 +97,7 @@ def app_initialize(self): provider_url = self.context.config().get_string('identity-provider.cognito.provider_url', required=True) client_id = self.context.config().get_secret('cluster-manager.client_id', required=True) client_secret = self.context.config().get_secret('cluster-manager.client_secret', required=True) + vdc_module_id = self.context.config().get_module_id(constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER) self.context.token_service = TokenService( context=self.context, options=TokenServiceOptions( @@ -102,6 +105,10 @@ def app_initialize(self): cognito_user_pool_domain_url=domain_url, client_id=client_id, client_secret=client_secret, + client_credentials_scope=[ + f'{vdc_module_id}/read', + f'{vdc_module_id}/write', + ], administrators_group_name=administrators_group_name, managers_group_name=managers_group_name ) @@ -156,6 +163,9 @@ def app_initialize(self): evdi_client=evdi_client, token_service=self.context.token_service ) + + #api authorization service + self.context.api_authorization_service = ClusterManagerApiAuthorizationService(accounts=self.context.accounts) # adsync service self.context.ad_sync = ADSyncService( @@ -163,6 +173,16 @@ def app_initialize(self): task_manager=self.context.task_manager, ) + internal_endpoint = self.context.config().get_cluster_internal_endpoint() + self.context.vdc_client = VirtualDesktopControllerClient( + context=self.context, + options=SocaClientOptions( + endpoint=f'{internal_endpoint}/{vdc_module_id}/api/v1', + enable_logging=False, + verify_ssl=False), + token_service=self.context.token_service + ) + self.context.snapshots = SnapshotsService( context=self.context ) @@ -171,7 +191,8 @@ def 
app_initialize(self): self.context.projects = ProjectsService( context=self.context, accounts_service=self.context.accounts, - task_manager=self.context.task_manager + task_manager=self.context.task_manager, + vdc_client=self.context.vdc_client ) # email templates @@ -185,6 +206,10 @@ def app_initialize(self): accounts=self.context.accounts, email_templates=self.context.email_templates ) + + self.context.shared_filesystem = SharedFilesystemService( + context=self.context + ) # web portal self.web_portal = WebPortal( diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/projects_dao.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/projects_dao.py index 6a1028d..b70f274 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/projects_dao.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/projects_dao.py @@ -81,6 +81,7 @@ def convert_from_db(project: Dict) -> Project: title = Utils.get_value_as_string('title', project) description = Utils.get_value_as_string('description', project) ldap_groups = Utils.get_value_as_list('ldap_groups', project, []) + users = project.get('users', []) enabled = Utils.get_value_as_bool('enabled', project, False) enable_budgets = Utils.get_value_as_bool('enable_budgets', project, False) budget_name = Utils.get_value_as_string('budget_name', project) @@ -96,8 +97,8 @@ def convert_from_db(project: Dict) -> Project: for key, value in db_tags.items(): tags.append(SocaKeyValue(key=key, value=value)) - created_on = Utils.get_value_as_int('created_on', project) - updated_on = Utils.get_value_as_int('updated_on', project) + created_on = Utils.get_value_as_int('created_on', project, 0) + updated_on = Utils.get_value_as_int('updated_on', project, 0) return Project( project_id=project_id, @@ -105,6 +106,7 @@ def convert_from_db(project: Dict) -> Project: name=name, description=description, ldap_groups=ldap_groups, + users=users, 
tags=tags, enabled=enabled, enable_budgets=enable_budgets, @@ -137,6 +139,9 @@ def convert_to_db(project: Project) -> Dict: if project.ldap_groups is not None: db_project['ldap_groups'] = project.ldap_groups + if project.users is not None: + db_project['users'] = project.users + if project.tags is not None: tags = {} for tag in project.tags: @@ -168,7 +173,6 @@ def get_project_by_id(self, project_id: str) -> Optional[Dict]: if Utils.is_empty(project_id): raise exceptions.invalid_params('project_id is required') - result = self.table.get_item( Key={ 'project_id': project_id diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/user_projects_dao.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/user_projects_dao.py index d192693..d3a0b62 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/user_projects_dao.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/db/user_projects_dao.py @@ -17,7 +17,7 @@ from ideaclustermanager.app.projects.db.projects_dao import ProjectsDAO from typing import List -from boto3.dynamodb.conditions import Key +from boto3.dynamodb.conditions import Attr, Key class UserProjectsDAO: @@ -98,21 +98,53 @@ def initialize(self): self.user_projects_table = self.context.aws().dynamodb_table().Table(self.get_user_projects_table_name()) self.project_groups_table = self.context.aws().dynamodb_table().Table(self.get_project_groups_table_name()) - def create_user_project(self, project_id: str, username: str): + def create_user_project(self, project_id: str, username: str, by_ldaps: list[str]): if Utils.is_empty(project_id): raise exceptions.invalid_params('project_id is required') if Utils.is_empty(username): raise exceptions.invalid_params('username is required') self.logger.info(f'added user project: {project_id}, username: {username}') - self.user_projects_table.put_item( - Item={ - 'project_id': project_id, - 'username': username - } - ) 
+ # if user is already in this project, update by_ldap attribute + if self.is_user_in_project(username, project_id): + result = self.user_projects_table.get_item( + Key={ + 'username': username, + 'project_id': project_id + } + ) + existing_user = Utils.get_value_as_dict('Item', result) + existing_user_ldaps = set(existing_user['by_ldaps'] if existing_user['by_ldaps'] else []) + + by_ldaps = list(existing_user_ldaps.union(set(by_ldaps))) + + # Update Expression + update_expression = 'SET #attr = :val' + + # Expression Attribute Names and Values + expression_attribute_names = {'#attr': 'by_ldaps'} + expression_attribute_values = {':val': by_ldaps} + self.user_projects_table.update_item( + Key={ + 'project_id': project_id, + 'username': username + }, + UpdateExpression=update_expression, + ExpressionAttributeNames=expression_attribute_names, + ExpressionAttributeValues=expression_attribute_values + ) + + else: + self.user_projects_table.put_item( + Item={ + 'project_id': project_id, + 'username': username, + 'by_ldaps': by_ldaps + } + ) - def delete_user_project(self, project_id: str, username: str): + + def delete_user_project(self, project_id: str, username: str, by_ldaps: list[str], force: bool = False): if Utils.is_empty(project_id): raise exceptions.invalid_params('project_id is required') if Utils.is_empty(username): @@ -125,38 +157,53 @@ def delete_user_project(self, project_id: str, username: str): message=f'project: {project_id} not found.' ) - project_groups = set(project['ldap_groups'] if 'ldap_groups' in project else []) - user = self.accounts_service.user_dao.get_user( username=username) if user is None: raise exceptions.soca_exception( error_code=errorcodes.AUTH_USER_NOT_FOUND, message=f'user: {username} not found.' 
- ) - - user_groups = set(user['additional_groups'] if 'additional_groups' in user else []) - - common_groups = project_groups.intersection(user_groups) - # delete user-project relation only user's additional_groups and project's ldap_groups has no common groups. - if len(common_groups) == 0: - self.logger.info( - f'deleted user project: {project_id}, username: {username}') - self.user_projects_table.delete_item( + ) + if self.is_user_in_project(username, project_id): + result = self.user_projects_table.get_item( Key={ - 'project_id': project_id, - 'username': username + 'username': username, + 'project_id': project_id } ) - else: - self.logger.info( - f'Not deleting user project: {project_id}, username: {username} as user has permissions to access project through group/s: {common_groups}') + existing_user = Utils.get_value_as_dict('Item', result) + existing_user_ldaps = set(existing_user['by_ldaps'] if existing_user['by_ldaps'] else []) + by_ldaps = list(existing_user_ldaps - set(by_ldaps)) + if by_ldaps: + # Update Expression + update_expression = 'SET #attr = :val' + + # Expression Attribute Names and Values + expression_attribute_names = {'#attr': 'by_ldaps'} + expression_attribute_values = {':val': by_ldaps} + self.user_projects_table.update_item( + Key={ + 'project_id': project_id, + 'username': username + }, + UpdateExpression=update_expression, + ExpressionAttributeNames=expression_attribute_names, + ExpressionAttributeValues=expression_attribute_values + ) + else: + # remove previous record + self.user_projects_table.delete_item( + Key={ + 'project_id': project_id, + 'username': username + } + ) def ldap_group_added(self, project_id: str, group_name: str): if Utils.is_empty(project_id): raise exceptions.invalid_params('project_id is required') if Utils.is_empty(group_name): - raise exceptions.invalid_params('username is required') + raise exceptions.invalid_params('group_name is required') self.project_groups_table.put_item( Item={ @@ -169,14 +216,15 @@ 
def ldap_group_added(self, project_id: str, group_name: str): for username in usernames: self.create_user_project( project_id=project_id, - username=username + username=username, + by_ldaps=[group_name] ) - def ldap_group_removed(self, project_id: str, group_name: str): + def ldap_group_removed(self, project_id: str, group_name: str, force: bool = False): if Utils.is_empty(project_id): raise exceptions.invalid_params('project_id is required') if Utils.is_empty(group_name): - raise exceptions.invalid_params('username is required') + raise exceptions.invalid_params('group_name is required') self.project_groups_table.delete_item( Key={ @@ -189,9 +237,21 @@ def ldap_group_removed(self, project_id: str, group_name: str): for username in usernames: self.delete_user_project( project_id=project_id, - username=username + username=username, + by_ldaps=[group_name], + force=force, ) + def delete_project(self, project_id: str): + if Utils.is_empty(project_id): + raise exceptions.invalid_params('project_id is required') + + result = self.project_groups_table.scan( + FilterExpression=Attr("project_id").eq(project_id) + ) + for item in Utils.get_value_as_list('Items', result, []): + self.ldap_group_removed(project_id, item['group_name'], True) + def get_projects_by_username(self, username: str) -> List[str]: if Utils.is_empty(username): raise exceptions.invalid_params('username is required') @@ -233,7 +293,8 @@ def group_member_added(self, group_name: str, username: str): for project_id in project_ids: self.create_user_project( project_id=project_id, - username=username + username=username, + by_ldaps=[group_name] ) def group_member_removed(self, group_name: str, username: str): @@ -245,7 +306,8 @@ def group_member_removed(self, group_name: str, username: str): for project_id in project_ids: self.delete_user_project( project_id=project_id, - username=username + username=username, + by_ldaps=[group_name] ) def project_disabled(self, project_id: str): @@ -254,9 +316,16 @@ def 
project_disabled(self, project_id: str):
         project = self.projects_dao.get_project_by_id(project_id)
         ldap_groups = project['ldap_groups']
+        usernames = project.get('users', [])
         for ldap_group in ldap_groups:
             self.ldap_group_removed(project_id=project_id, group_name=ldap_group)
+        for username in usernames:
+            self.delete_user_project(
+                project_id=project_id,
+                username=username,
+                by_ldaps=["SELF"]
+            )

     def project_enabled(self, project_id: str):
         if Utils.are_empty(project_id):
@@ -264,10 +333,18 @@ def project_enabled(self, project_id: str):
         project = self.projects_dao.get_project_by_id(project_id)
         ldap_groups = project['ldap_groups']
+        users = project.get('users', [])
         for ldap_group in ldap_groups:
             self.ldap_group_added(project_id=project_id, group_name=ldap_group)
+        for user in users:
+            self.create_user_project(
+                project_id=project_id,
+                username=user,
+                by_ldaps=['SELF']
+            )
+
     def is_user_in_project(self, username: str, project_id: str):
         if Utils.is_empty(username):
             raise exceptions.invalid_params('username is required')
@@ -281,7 +358,7 @@ def is_user_in_project(self, username: str, project_id: str):
             }
         )
         return Utils.get_value_as_dict('Item', result) is not None
-    
+
     def has_projects_in_group(self, group_name: str) -> bool:
         query_result = self.project_groups_table.query(
             Limit=1,
@@ -294,3 +371,17 @@ def has_projects_in_group(self, group_name: str) -> bool:
         )
         memberships = query_result['Items'] if 'Items' in query_result else []
         return len(memberships) > 0
+
+    def is_group_in_project(self, group_name: str, project_id: str):
+        if Utils.is_empty(group_name):
+            raise exceptions.invalid_params('group_name is required')
+        if Utils.is_empty(project_id):
+            raise exceptions.invalid_params('project_id is required')
+
+        result = self.project_groups_table.get_item(
+            Key={
+                'group_name': group_name,
+                'project_id': project_id
+            }
+        )
+        return Utils.get_value_as_dict('Item', result) is not None
diff --git
a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/project_tasks.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/project_tasks.py index 98ec574..ddf01f2 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/project_tasks.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/project_tasks.py @@ -68,8 +68,17 @@ def invoke(self, payload: Dict): project_id = payload['project_id'] project = self.context.projects.projects_dao.get_project_by_id(project_id) if project['enabled']: - groups_added = Utils.get_value_as_list('groups_added', payload, []) - groups_removed = Utils.get_value_as_list('groups_removed', payload, []) + groups_added = payload.get('groups_added', []) + groups_removed = payload.get('groups_removed', []) + users_added = payload.get('users_added', []) + users_removed = payload.get('users_removed', []) + + for username in users_removed: + self.context.projects.user_projects_dao.delete_user_project( + project_id=project_id, + username=username, + by_ldaps=["SELF"] + ) for ldap_group_name in groups_removed: self.context.projects.user_projects_dao.ldap_group_removed( @@ -82,6 +91,13 @@ def invoke(self, payload: Dict): project_id=project_id ) + for username in users_added: + self.context.projects.user_projects_dao.create_user_project( + project_id=project_id, + username=username, + by_ldaps=["SELF"] + ) + for ldap_group_name in groups_added: self.context.projects.user_projects_dao.ldap_group_added( project_id=project_id, diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/projects_service.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/projects_service.py index bd99b5b..1b6a0cc 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/projects_service.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/projects/projects_service.py @@ -14,6 +14,8 @@ from 
ideadatamodel.projects import ( CreateProjectRequest, CreateProjectResult, + DeleteProjectRequest, + DeleteProjectResult, GetProjectRequest, GetProjectResult, UpdateProjectRequest, @@ -28,23 +30,26 @@ GetUserProjectsResult, Project ) +from ideadatamodel.api.api_model import ApiAuthorization from ideasdk.utils import Utils, GroupNameHelper from ideasdk.context import SocaContext +from ideasdk.client.vdc_client import AbstractVirtualDesktopControllerClient from ideaclustermanager.app.projects.db.projects_dao import ProjectsDAO from ideaclustermanager.app.projects.db.user_projects_dao import UserProjectsDAO from ideaclustermanager.app.accounts.accounts_service import AccountsService from ideaclustermanager.app.tasks.task_manager import TaskManager -from typing import List +from typing import List, Optional class ProjectsService: - def __init__(self, context: SocaContext, accounts_service: AccountsService, task_manager: TaskManager): + def __init__(self, context: SocaContext, accounts_service: AccountsService, task_manager: TaskManager, vdc_client: AbstractVirtualDesktopControllerClient): self.context = context self.accounts_service = accounts_service self.task_manager = task_manager + self.vdc_client = vdc_client self.logger = context.logger('projects') self.projects_dao = ProjectsDAO(context) @@ -81,8 +86,6 @@ def create_project(self, request: CreateProjectRequest) -> CreateProjectResult: if existing is not None: raise exceptions.invalid_params(f'project with name: {project.name} already exists') - if Utils.is_empty(project.ldap_groups): - raise exceptions.invalid_params('ldap_groups[] is required') for ldap_group_name in project.ldap_groups: # check if group exists # Active Directory mode checks the back-end LDAP @@ -119,6 +122,43 @@ def create_project(self, request: CreateProjectRequest) -> CreateProjectResult: project=enabled_project.project ) + def delete_project(self, request: DeleteProjectRequest) -> DeleteProjectResult: + """ + Delete a Project + validate 
required fields, remove the project from DynamoDB and Cache.
+        :param request: DeleteProjectRequest
+        :return: DeleteProjectResult
+        """
+        if Utils.is_empty(request):
+            raise exceptions.invalid_params('request is required')
+
+        project_id = request.project_id
+        project_name = request.project_name
+        if Utils.is_empty(project_id) and Utils.is_empty(project_name):
+            raise exceptions.invalid_params('either project id or project name is required')
+
+        project = self.projects_dao.get_project_by_id(project_id) if project_id else self.projects_dao.get_project_by_name(project_name)
+        if project is not None:
+            project_id = self.projects_dao.convert_from_db(project).project_id
+            sessions_by_project_id = self.vdc_client.list_sessions_by_project_id(project_id)
+            if sessions_by_project_id:
+                session_ids_by_project_id = [session.dcv_session_id for session in sessions_by_project_id]
+                raise exceptions.general_exception(f'project is still used by virtual desktop sessions. '
+                                                   f'Project ID: {project_id}, Session IDs: {session_ids_by_project_id}')
+
+            software_stacks_by_project_id = self.vdc_client.list_software_stacks_by_project_id(project_id)
+            if software_stacks_by_project_id:
+                stack_ids_by_project_id = [software_stack.stack_id for software_stack in software_stacks_by_project_id]
+                raise exceptions.general_exception(f'project is still used by software stacks. 
'
+                                                   f'Project ID: {project_id}, Stack IDs: {stack_ids_by_project_id}')
+
+            self.user_projects_dao.delete_project(project_id)
+            self.projects_dao.delete_project(project_id)
+
+        return DeleteProjectResult()
+
     def get_project(self, request: GetProjectRequest) -> GetProjectResult:
         """
         Retrieve the Project from the cache
@@ -203,7 +243,14 @@ def update_project(self, request: UpdateProjectRequest) -> UpdateProjectResult:
         for ldap_group_name in groups_added:
             # check if group exists
             self.accounts_service.get_group(ldap_group_name)
+
+        users_added = set()
+        users_removed = set()
+        # a request with users omitted means "no user changes"; guard against set(None)
+        if project.users is not None:
+            existing_users = set(existing.get('users', []))
+            updated_users = set(project.users)
+            users_added = updated_users - existing_users
+            users_removed = existing_users - updated_users

         # none values will be skipped by db update. ensure enabled/disabled cannot be called via update project.
         project.enabled = None
@@ -211,13 +258,15 @@ def update_project(self, request: UpdateProjectRequest) -> UpdateProjectResult:
         updated_project = self.projects_dao.convert_from_db(db_updated)

         if updated_project.enabled:
-            if groups_added is not None or groups_removed is not None:
+            if groups_added or groups_removed or users_added or users_removed:
                 self.task_manager.send(
                     task_name='projects.project-groups-updated',
                     payload={
                         'project_id': updated_project.project_id,
                         'groups_added': list(groups_added),
-                        'groups_removed': list(groups_removed)
+                        'groups_removed': list(groups_removed),
+                        'users_added': list(users_added),
+                        'users_removed': list(users_removed)
                     },
                     message_group_id=updated_project.project_id
                 )
@@ -381,4 +430,4 @@ def get_user_projects(self, request: GetUserProjectsRequest) -> GetUserProjectsR
         return GetUserProjectsResult(
             projects=result
-        )
\ No newline at end of file
+        )
diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/shared_filesystem/shared_filesystem_service.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/shared_filesystem/shared_filesystem_service.py
index
45aa708..9f336fe 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/shared_filesystem/shared_filesystem_service.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/shared_filesystem/shared_filesystem_service.py @@ -1,7 +1,8 @@ import time -from typing import List, Dict +from typing import List, Dict, Set, Union import ideaclustermanager +from ideasdk.context import SocaContext from ideadatamodel import ( constants, CommonCreateFileSystemRequest, @@ -11,19 +12,29 @@ FSxONTAPVolume, CreateONTAPFileSystemRequest, exceptions, + errorcodes ) from ideadatamodel.shared_filesystem import ( + AddFileSystemToProjectRequest, OnboardEFSFileSystemRequest, OnboardONTAPFileSystemRequest, + RemoveFileSystemFromProjectRequest, + CommonOnboardFileSystemRequest, + OffboardFileSystemRequest, + FileSystem, + FSxONTAPFileSystem, + ListFileSystemInVPCResult ) -from ideasdk.utils import Utils +from ideadatamodel.shared_filesystem.shared_filesystem_api import FSxONTAPDeploymentType +from ideasdk.utils import Utils, ApiUtils import botocore.exceptions class SharedFilesystemService: - def __init__(self, context: ideaclustermanager.AppContext): + def __init__(self, context: SocaContext): self.context = context self.config = self.context.config() + self.logger = self.context.logger("shared-filesystem") def create_tags(self, filesystem_name: str): backup_plan_tags = self.config.get_list( @@ -58,6 +69,18 @@ def common_filesystem_config(request: CommonCreateFileSystemRequest): "title": request.filesystem_title, } + def onboard_efs_filesystem(self, request: OnboardEFSFileSystemRequest): + self._validate_onboard_filesystem_request(request) + self._validate_filesystem_does_not_exist(request.filesystem_name) + self._validate_filesystem_present_in_vpc_and_not_onboarded(request.filesystem_id) + + config_entries = self.build_config_for_vpc_efs(request) + + self.context.config().db.sync_cluster_settings_in_db( + config_entries=config_entries, + overwrite=True + ) + 
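The `onboard_efs_filesystem` hunk above chains three checks (required fields, unique name, present-in-VPC and not yet onboarded) before syncing config entries to the cluster-settings table. As a standalone sketch of that gating order — `can_onboard` is a hypothetical helper, with plain sets standing in for the service's DynamoDB/EFS lookups — the logic is:

```python
def can_onboard(filesystem_id, filesystem_name,
                onboarded_names, onboarded_ids, vpc_filesystem_ids):
    """Mirror the onboard-time validation order: required fields first,
    then name uniqueness, then onboarded-state and VPC-membership checks."""
    if not filesystem_id or not filesystem_name:
        return False, 'needed parameters cannot be empty'
    if filesystem_name in onboarded_names:
        return False, f'{filesystem_name} already exists'
    if filesystem_id in onboarded_ids:
        return False, f'{filesystem_id} has already been onboarded'
    if filesystem_id not in vpc_filesystem_ids:
        return False, f'{filesystem_id} is not part of the VPC'
    return True, ''
```

Running the checks in this order means a duplicate name is reported before the more expensive VPC listing would ever be consulted.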
def build_config_for_vpc_efs( self, request: OnboardEFSFileSystemRequest ): @@ -90,6 +113,79 @@ def build_config_for_vpc_efs( ) return config_entries + def create_efs(self, request: CreateEFSFileSystemRequest): + self._validate_create_filesystem_request(request) + self._validate_correct_subnet_selection({request.subnet_id_1, request.subnet_id_2}) + self._validate_filesystem_does_not_exist(request.filesystem_name) + ApiUtils.validate_input(request.mount_directory, + constants.MOUNT_DIRECTORY_REGEX, + constants.MOUNT_DIRECTORY_ERROR_MESSAGE) + try: + efs_client = self.context.aws().efs() + security_group_id = self.config.get_string( + f"{constants.MODULE_SHARED_STORAGE}.security_group_id" + ) + + tags_for_fs = self.create_tags( + request.filesystem_name + ) + + efs_create_response = efs_client.create_file_system( + ThroughputMode="elastic", Encrypted=True, Tags=tags_for_fs + ) + efs_filesystem = EFSFileSystem(efs=efs_create_response) + efs_filesystem_id = efs_filesystem.get_filesystem_id() + + self.efs_check_filesystem_exists( + filesystem_id=efs_filesystem_id, wait=True + ) + filesystem_policy = { + "Version": "2012-10-17", + "Id": "efs-prevent-anonymous-access-policy", + "Statement": [ + { + "Sid": "efs-statement", + "Effect": "Allow", + "Principal": {"AWS": "*"}, + "Action": [ + "elasticfilesystem:ClientRootAccess", + "elasticfilesystem:ClientWrite", + "elasticfilesystem:ClientMount", + ], + "Resource": f"arn:{self.context.aws().aws_partition()}:elasticfilesystem:{self.context.aws().aws_region()}:{self.context.aws().aws_account_id()}:file-system/{efs_filesystem_id}", + "Condition": { + "Bool": {"elasticfilesystem:AccessedViaMountTarget": "true"} + }, + } + ], + } + efs_client.put_file_system_policy( + FileSystemId=efs_filesystem_id, Policy=Utils.to_json(filesystem_policy) + ) + + # Create mount targets + efs_client.create_mount_target( + FileSystemId=efs_filesystem_id, + SubnetId=request.subnet_id_1, + SecurityGroups=[security_group_id], + ) + 
efs_client.create_mount_target( + FileSystemId=efs_filesystem_id, + SubnetId=request.subnet_id_2, + SecurityGroups=[security_group_id], + ) + + # Sync cluster-settings ddb + config_entries = self.build_config_for_new_efs( + efs=efs_filesystem, request=request + ) + self.config.db.sync_cluster_settings_in_db( + config_entries=config_entries, overwrite=True + ) + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + raise exceptions.general_exception(error_message) + def build_config_for_new_efs( self, efs: EFSFileSystem, request: CreateEFSFileSystemRequest ): @@ -119,6 +215,18 @@ def build_config_for_new_efs( ) return config_entries + def onboard_ontap_filesystem(self, request: OnboardONTAPFileSystemRequest): + self._validate_onboard_filesystem_request(request) + self._validate_filesystem_does_not_exist(request.filesystem_name) + self._validate_filesystem_present_in_vpc_and_not_onboarded(request.filesystem_id) + + config_entries = self.build_config_for_vpc_ontap(request) + + self.context.config().db.sync_cluster_settings_in_db( + config_entries=config_entries, + overwrite=True + ) + def build_config_for_vpc_ontap( self, request: OnboardONTAPFileSystemRequest, ): @@ -168,6 +276,101 @@ def build_config_for_vpc_ontap( ) return config_entries + def create_fsx_ontap(self, request: CreateONTAPFileSystemRequest): + self._validate_create_filesystem_request(request) + self._validate_correct_subnet_selection({request.primary_subnet, request.standby_subnet}) + self._validate_filesystem_does_not_exist(request.filesystem_name) + if not request.mount_directory and not request.mount_drive: + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message="One of mount drive or mount directory is required", + ) + + if request.mount_directory: + ApiUtils.validate_input(request.mount_directory, + constants.MOUNT_DIRECTORY_REGEX, + constants.MOUNT_DIRECTORY_ERROR_MESSAGE) + + if request.mount_drive: + 
ApiUtils.validate_input(request.mount_drive, + constants.MOUNT_DRIVE_REGEX, + constants.MOUNT_DRIVE_ERROR_MESSAGE) + + ApiUtils.validate_input(request.file_share_name, + constants.FILE_SYSTEM_NAME_REGEX, + constants.FILE_SYSTEM_NAME_ERROR_MESSAGE) + ApiUtils.validate_input_range(request.storage_capacity, constants.ONTAP_STORAGE_CAPACITY_RANGE) + try: + fsx_client = self.context.aws().fsx() + security_group_id = self.config.get_string( + f"{constants.MODULE_SHARED_STORAGE}.security_group_id" + ) + + tags_for_fs = self.create_tags( + request.filesystem_name + ) + + _subnet_ids = [] + if request.deployment_type == FSxONTAPDeploymentType.SINGLE_AZ: + _subnet_ids = [request.primary_subnet] + else: + _subnet_ids = [request.primary_subnet, request.standby_subnet] + fs_create_response = fsx_client.create_file_system( + FileSystemType="ONTAP", + SecurityGroupIds=[ + security_group_id, + ], + Tags=tags_for_fs, + StorageCapacity=request.storage_capacity, + SubnetIds=_subnet_ids, + OntapConfiguration={ + "PreferredSubnetId": request.primary_subnet, + "DeploymentType": request.deployment_type, + "ThroughputCapacity": 128, # parameter required + }, + ) + + fsx_ontap_filesystem = FSxONTAPFileSystem( + filesystem=fs_create_response["FileSystem"] + ) + fs_id = fsx_ontap_filesystem.get_filesystem_id() + + svm = fsx_client.create_storage_virtual_machine( + FileSystemId=fs_id, + Name="fsx", + ) + + svm_volume = fsx_client.create_volume( + VolumeType="ONTAP", + Name="vol1", + OntapConfiguration={ + "JunctionPath": "/vol1", + "SecurityStyle": request.volume_security_style, + "SizeInMegabytes": 1024, + "StorageVirtualMachineId": svm["StorageVirtualMachine"][ + "StorageVirtualMachineId" + ], + "StorageEfficiencyEnabled": True, + }, + ) + + fsx_ontap_svm = FSxONTAPSVM( + storage_virtual_machine=svm["StorageVirtualMachine"] + ) + fsx_ontap_volume = FSxONTAPVolume(volume=svm_volume["Volume"]) + + # Update cluster-settings dynamodb + config_entries = self.build_config_for_new_ontap( + 
svm=fsx_ontap_svm, volume=fsx_ontap_volume, request=request + ) + self.config.db.sync_cluster_settings_in_db( + config_entries=config_entries, overwrite=True + ) + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + self.logger.error(error_message) + raise exceptions.general_exception(error_message) + def build_config_for_new_ontap( self, svm: FSxONTAPSVM, @@ -208,6 +411,45 @@ def build_config_for_new_ontap( ) return config_entries + def offboard_filesystem(self, request: OffboardFileSystemRequest): + filesystem_name = request.filesystem_name + self.config.db.delete_config_entries(f'{constants.MODULE_SHARED_STORAGE}.{filesystem_name}') + + def add_filesystem_to_project(self, request: AddFileSystemToProjectRequest): + filesystem_name = request.filesystem_name + project_name = request.project_name + self._check_required_parameters(request=request) + self.update_filesystem_to_project_mapping(filesystem_name, project_name) + + def remove_filesystem_from_project(self, request: RemoveFileSystemFromProjectRequest): + filesystem_name = request.filesystem_name + project_name = request.project_name + self._check_required_parameters(request=request) + + fs = self.get_filesystem(filesystem_name) + projects = fs.get_projects() + if project_name in projects: + projects.remove(project_name) + self._update_projects_for_filesystem(fs, projects) + + def list_file_systems_in_vpc(self) -> ListFileSystemInVPCResult: + onboarded_filesystems = self._list_shared_filesystems() + onboarded_filesystem_ids = set([fs.get_filesystem_id() for fs in onboarded_filesystems]) + + efs_filesystems = self._list_unonboarded_efs_file_systems(onboarded_filesystem_ids) + fsx_filesystems = self._list_unonboarded_ontap_file_systems(onboarded_filesystem_ids) + + return ListFileSystemInVPCResult(efs=efs_filesystems, fsx=fsx_filesystems) + + def update_filesystem_to_project_mapping(self, filesystem_name: str, project_name: str): + fs = 
self.get_filesystem(filesystem_name) + projects = fs.get_projects() + if Utils.is_empty(projects): + projects = [] + if project_name not in projects: + projects.append(project_name) + self._update_projects_for_filesystem(fs, projects) + def traverse_config(self, config_entries: List[Dict], prefix: str, config: Dict): for key in config: if "." in key or ":" in key: @@ -254,3 +496,195 @@ def efs_check_filesystem_exists( return False else: raise e + + def _validate_onboard_filesystem_request(self, request: CommonOnboardFileSystemRequest): + if not request.filesystem_id or not request.filesystem_name or not request.filesystem_title: + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message="needed parameters cannot be empty", + ) + + def _validate_create_filesystem_request(self, request: CommonCreateFileSystemRequest): + if not request.filesystem_name or not request.filesystem_title: + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message="needed parameters cannot be empty", + ) + ApiUtils.validate_input(request.filesystem_name, + constants.FILE_SYSTEM_NAME_REGEX, + constants.FILE_SYSTEM_NAME_ERROR_MESSAGE) + + def _validate_filesystem_does_not_exist(self, filesystem_name: str): + try: + if Utils.is_not_empty(self.get_filesystem(filesystem_name)): + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message=f"{filesystem_name} already exists", + ) + except exceptions.SocaException as e: + if e.error_code == errorcodes.FILESYSTEM_NOT_FOUND or e.error_code == errorcodes.NO_SHARED_FILESYSTEM_FOUND: + pass + else: + raise e + + def _validate_correct_subnet_selection(self, subnet_ids: Set[str]): + res_private_subnets = self.config.db.get_config_entry("cluster.network.private_subnets")['value'] + for subnet_id in subnet_ids: + if subnet_id is not None and subnet_id not in res_private_subnets: + self.logger.error(f"{subnet_id} is not a RES private subnet") + raise exceptions.soca_exception( + 
error_code=errorcodes.INVALID_PARAMS, + message=f"{subnet_id} is not a RES private subnet", + ) + + def _validate_filesystem_present_in_vpc_and_not_onboarded(self, filesystem_id): + onboarded_filesystems = self._list_shared_filesystems() + onboarded_filesystem_ids = set([fs.get_filesystem_id() for fs in onboarded_filesystems]) + + if filesystem_id in onboarded_filesystem_ids: + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_ALREADY_ONBOARDED, + message=f"{filesystem_id} has already been onboarded" + ) + + efs_filesystems = self._list_unonboarded_efs_file_systems(onboarded_filesystem_ids) + fsx_filesystems = self._list_unonboarded_ontap_file_systems(onboarded_filesystem_ids) + + filesystems_in_vpc = [*efs_filesystems, *fsx_filesystems] + + for filesystem in filesystems_in_vpc: + if filesystem_id == filesystem.get_filesystem_id(): + return True + + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_IN_VPC, + message=f"{filesystem_id} not part of the env's VPC thus not accessible" + ) + + def _update_projects_for_filesystem( + self, filesystem: FileSystem, projects: List[str] + ): + self.config.db.set_config_entry( + f"{constants.MODULE_SHARED_STORAGE}.{filesystem.get_name()}.projects", # update entry on cluster settings dynamodb table + projects, + ) + self.config.put( + f"{constants.MODULE_SHARED_STORAGE}.{filesystem.get_name()}.projects", # update local config tree + projects, + ) + + def _list_shared_filesystems(self) -> List[FileSystem]: + shared_storage_config = self.config.get_config(constants.MODULE_SHARED_STORAGE) + shared_storage_config_dict = shared_storage_config.as_plain_ordered_dict() + + filesystem: List[FileSystem] = [] + for fs_name, config in shared_storage_config_dict.items(): + if Utils.is_not_empty(Utils.get_as_dict(config)) and "projects" in list( + config.keys() + ): + filesystem.append(FileSystem(name=fs_name, storage=config)) + + return filesystem + + def get_filesystem(self, filesystem_name: str): + 
filesystems = self._list_shared_filesystems() + + if Utils.is_empty(filesystems): + raise exceptions.soca_exception( + error_code=errorcodes.NO_SHARED_FILESYSTEM_FOUND, + message="did not find any shared filesystem", + ) + + for fs in filesystems: + if fs.get_name() == filesystem_name: + return fs + + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_FOUND, + message=f"could not find filesystem {filesystem_name}", + ) + + def _list_unonboarded_efs_file_systems(self, onboarded_filesystem_ids: Set[str]) -> List[EFSFileSystem]: + try: + onboarded_internal_filesystem_ids = [ + self.config.db.get_config_entry("shared-storage.home.efs.file_system_id")['value'], + self.config.db.get_config_entry("shared-storage.internal.efs.file_system_id")['value'] + ] + env_vpc_id = self.config.db.get_config_entry("cluster.network.vpc_id")['value'] + efs_client = self.context.aws().efs() + efs_response = efs_client.describe_file_systems()["FileSystems"] + + filesystems: List[EFSFileSystem] = [] + for efs in efs_response: + if efs['LifeCycleState'] != 'available': + continue + fs_id = efs['FileSystemId'] + efs_mt_response = efs_client.describe_mount_targets(FileSystemId=fs_id) + if len(efs_mt_response['MountTargets']) == 0: + continue + if env_vpc_id == efs_mt_response['MountTargets'][0]['VpcId'] and \ + fs_id not in onboarded_filesystem_ids and \ + fs_id not in onboarded_internal_filesystem_ids: + filesystems.append(EFSFileSystem(efs = efs)) + return filesystems + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + raise exceptions.general_exception(error_message) + + def _list_unonboarded_ontap_file_systems(self, onboarded_filesystem_ids: Set[str]) -> List[FSxONTAPFileSystem]: + try: + env_vpc_id = self.config.db.get_config_entry("cluster.network.vpc_id")['value'] + + fsx_client = self.context.aws().fsx() + fsx_response = fsx_client.describe_file_systems()["FileSystems"] + ec2_client = self.context.aws().ec2() + + 
filesystems: List[FSxONTAPFileSystem] = [] + for fsx in fsx_response: + if fsx['Lifecycle'] != 'AVAILABLE': + continue + fs_id = fsx['FileSystemId'] + subnet_response = ec2_client.describe_subnets(SubnetIds=fsx['SubnetIds']) + if env_vpc_id == subnet_response['Subnets'][0]['VpcId'] and \ + fs_id not in onboarded_filesystem_ids: + volume_response = fsx_client.describe_volumes( + Filters=[{ + 'Name': 'file-system-id', + 'Values': [fs_id] + }] + ) + svm_response = fsx_client.describe_storage_virtual_machines( + Filters=[{ + 'Name': 'file-system-id', + 'Values': [fs_id] + }] + ) + if len(svm_response['StorageVirtualMachines']) == 0 or len(volume_response['Volumes']) == 0: + continue + list_created_svms = list(filter( + lambda svm_obj: svm_obj['Lifecycle'] == "CREATED", + svm_response['StorageVirtualMachines'] + )) + list_created_volumes = list(filter( + lambda volume_obj: volume_obj['Lifecycle'] == "CREATED", + volume_response['Volumes'] + )) + svm_list = [FSxONTAPSVM(storage_virtual_machine = svm) for svm in list_created_svms] + volume_list = [FSxONTAPVolume(volume = volume) for volume in list_created_volumes] + filesystems.append(FSxONTAPFileSystem(filesystem = fsx, svm = svm_list, volume = volume_list)) + return filesystems + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + raise exceptions.general_exception(error_message) + + + @staticmethod + def _check_required_parameters( + request: Union[ + AddFileSystemToProjectRequest, RemoveFileSystemFromProjectRequest + ] + ): + if Utils.is_empty(request.filesystem_name): + raise exceptions.invalid_params("filesystem_name is required") + if Utils.is_empty(request.project_name): + raise exceptions.invalid_params("project_name is required") diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/__init__.py index e69de29..6d8d18a 100644 --- 
a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/__init__.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot.py new file mode 100644 index 0000000..75daf0d --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot.py @@ -0,0 +1,364 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
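The `apply_snapshot.py` module introduced below guards its entry point with regex checks on the snapshot's S3 coordinates before spawning the background apply thread. A minimal self-contained sketch of that validation — the patterns here are assumptions for illustration, not the module's actual `snapshot_constants` values:

```python
import re

# Assumed patterns; the real module reads SNAPSHOT_S3_BUCKET_NAME_REGEX and
# SNAPSHOT_PATH_REGEX from snapshot_constants.
S3_BUCKET_NAME_REGEX = r"^[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]$"
SNAPSHOT_PATH_REGEX = r"^[A-Za-z0-9!_.*'()/-]{1,1024}$"

def validate_snapshot_input(s3_bucket_name: str, snapshot_path: str) -> None:
    """Raise ValueError when either snapshot coordinate is missing or malformed."""
    if not s3_bucket_name or not s3_bucket_name.strip():
        raise ValueError('s3_bucket_name is required')
    if not re.match(S3_BUCKET_NAME_REGEX, s3_bucket_name):
        raise ValueError(f's3_bucket_name must match regex: {S3_BUCKET_NAME_REGEX}')
    if not snapshot_path or not snapshot_path.strip():
        raise ValueError('snapshot_path is required')
    if not re.match(SNAPSHOT_PATH_REGEX, snapshot_path):
        raise ValueError(f'snapshot_path must match regex: {SNAPSHOT_PATH_REGEX}')
```

Failing fast here keeps a malformed request from ever being recorded as an IN_PROGRESS apply attempt.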
+ +from ideasdk.context import SocaContext +from ideadatamodel.snapshots import Snapshot, ApplySnapshotStatus, TableName +from ideadatamodel import exceptions, SocaListingPayload +from ideasdk.utils import Utils, scan_db_records +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper import get_table_keys_by_res_version +from ideaclustermanager.app.snapshots import snapshot_constants +from ideaclustermanager.app.snapshots.db.apply_snapshot_dao import ApplySnapshotDAO +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_temp_tables_helper import ApplySnapshotTempTablesHelper +from ideaclustermanager.app.snapshots.helpers.apply_snapshots_config import RES_VERSION_IN_TOPOLOGICAL_ORDER, RES_VERSION_TO_DATA_TRANSFORMATION_CLASS, TABLES_IN_MERGE_DEPENDENCY_ORDER, TABLE_TO_MERGE_LOGIC_CLASS +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ApplySnapshotObservabilityHelper +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import MergedRecordDelta + +import botocore.exceptions +import concurrent.futures +import json +import os +import re +import threading +import time +from typing import List, Dict, Optional + + +class ApplySnapshot: + def __init__(self, snapshot: Snapshot, apply_snapshot_dao: ApplySnapshotDAO, context: SocaContext): + self.snapshot = snapshot + self.context = context + self.logger = context.logger('apply-snapshot') + self.apply_snapshot_observability_helper = ApplySnapshotObservabilityHelper(self.logger) + + self.apply_snapshot_dao = apply_snapshot_dao + self.apply_snapshot_temp_tables_helper = ApplySnapshotTempTablesHelper(context) + self._ddb_client = self.context.aws().dynamodb_table() + + self.created_on = Utils.current_time_ms() + + self.validate_input() + + def initialize(self): + self.logger.info( + f"Snapshot Apply request received for s3_bucket_name: {self.snapshot.s3_bucket_name}, snapshot_path: {self.snapshot.snapshot_path}" + ) + + try: + 
self.metadata = self.fetch_snapshot_metadata() + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + self.apply_snapshot_dao.create( + { + 's3_bucket_name': self.snapshot.s3_bucket_name, + 'snapshot_path': self.snapshot.snapshot_path, + 'status': ApplySnapshotStatus.FAILED, + 'failure_reason': error_message, + }, + created_on=self.created_on + ) + raise e + + try: + # Extract the RES Version of the snapshot + self.snapshot_res_version = self.metadata[snapshot_constants.VERSION_KEY] + + # Extract the table_export_description object from metadata that stores details of all the tables present in the snapshot + self.table_export_descriptions = self.metadata[snapshot_constants.TABLE_EXPORT_DESCRIPTION_KEY] + + self.tables_to_be_imported = self.get_list_of_tables_to_be_imported() + self.logger.info(f"Tables that will be imported are: {(',').join(self.tables_to_be_imported)}") + self.tables_to_be_imported_with_table_key_info = get_table_keys_by_res_version( + self.tables_to_be_imported, self.snapshot_res_version + ) + + self.apply_snapshot_record = self.apply_snapshot_dao.convert_from_db( + self.apply_snapshot_dao.create( + { + 's3_bucket_name': self.snapshot.s3_bucket_name, + 'snapshot_path': self.snapshot.snapshot_path, + 'status': ApplySnapshotStatus.IN_PROGRESS, + }, + created_on=self.created_on + ) + ) + + self.logger.info("Starting thread to perform ApplySnapshot main process") + threading.Thread(target=self.apply_snapshot_main).start() + except Exception as e: + self.logger.error(f"Applying Snapshot {self.snapshot} failed with error {repr(e)}") + self.apply_snapshot_dao.create( + { + 's3_bucket_name': self.snapshot.s3_bucket_name, + 'snapshot_path': self.snapshot.snapshot_path, + 'status': ApplySnapshotStatus.FAILED, + 'failure_reason': f'{repr(e)}', + }, + created_on=self.created_on + ) + raise e + + def get_temp_table_name(self, table_name) -> str: + return
f'{self.context.cluster_name()}.temp-{table_name}-{self.created_on}' + + def validate_input(self): + """Validates snapshot input object to check if the required parameters are present. + Required parameters + - s3_bucket_name + - snapshot_path + + Raises: + exceptions.invalid_params: if required parameters `s3_bucket_name` or `snapshot_path` are not passed. + """ + if not self.snapshot.s3_bucket_name or not self.snapshot.s3_bucket_name.strip(): + raise exceptions.invalid_params('s3_bucket_name is required') + if not re.match(snapshot_constants.SNAPSHOT_S3_BUCKET_NAME_REGEX, self.snapshot.s3_bucket_name): + raise exceptions.invalid_params( + f's3_bucket_name must match regex: {snapshot_constants.SNAPSHOT_S3_BUCKET_NAME_REGEX}' + ) + + if not self.snapshot.snapshot_path or not self.snapshot.snapshot_path.strip(): + raise exceptions.invalid_params('snapshot_path is required') + if not re.match(snapshot_constants.SNAPSHOT_PATH_REGEX, self.snapshot.snapshot_path): + raise exceptions.invalid_params(f'snapshot_path must match regex: {snapshot_constants.SNAPSHOT_PATH_REGEX}') + + def fetch_snapshot_metadata(self): + """Fetches the metadata.json file from the S3 bucket passed as input. Loads the data into a Python dict and returns it.
+ + Raises: + botocore.exceptions.ClientError: A ClientError can occur in the following cases: + - The S3 bucket does not have proper permissions set that allow the RES application to get an object from the bucket + - The S3 bucket or the path do not exist + + Returns: + Dict: returns the metadata.json file contents as a Python Dict + """ + try: + metadata = json.loads( + self.context.aws() + .s3() + .get_object( + Bucket=self.snapshot.s3_bucket_name, + Key=f"{self.snapshot.snapshot_path}/{snapshot_constants.METADATA_FILE_NAME_AND_EXTENSION}", + )['Body'] + .read() + ) + self.logger.debug(f"Metadata file contents: {metadata}") + + return metadata + except botocore.exceptions.ClientError as e: + self.logger.error( + f"An error occurred while trying to fetch {snapshot_constants.METADATA_FILE_NAME_AND_EXTENSION} file from S3 bucket {self.snapshot.s3_bucket_name} at path {self.snapshot.snapshot_path} during the apply snapshot process: {e}" + ) + raise e + + def get_list_of_tables_to_be_imported(self) -> List[TableName]: + """Returns a list of table names that should be imported from the snapshot. + + Returns: + List: List of table names that represents the intersection of the tables present in the snapshot + and the list of tables the ApplySnapshot process supports applying.
+ """ + return [e.value for e in TableName if e.value in self.table_export_descriptions] + + def apply_snapshot_main(self): + + try: + # Step 1.1: Import all applicable tables from snapshot + self.import_all_tables() + self.logger.info(f"All tables imported successfully") + + # Step 1.2: Scans all table data and stores it in a python dict + data_by_table = self.scan_ddb_tables() + self.logger.debug(f"All table data fetched from snapshot {data_by_table}") + + # Step 2: Apply applicable data transformation logic on table data + transformed_data = self.apply_data_transformations(data_by_table) + self.logger.debug(f"Data transformations applied {transformed_data}") + + # Step 3: Merge data to env's actual DDB tables + self.merge_transformed_data_for_all_tables(transformed_data) + + self.logger.info(f"Apply snapshot operation completed") + + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.FAILED, error_message) + except exceptions.SocaException as e: + error_message = e.message + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.FAILED, error_message) + except Exception as e: + error_message = repr(e) + self.logger.error(f"Applying Snapshot {self.apply_snapshot_record.apply_snapshot_identifier} failed with error {error_message}") + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.FAILED, error_message) + finally: + imported_tables = [self.get_temp_table_name(table_name) for table_name in self.tables_to_be_imported] + self.apply_snapshot_temp_tables_helper.delete_imported_tables(table_names=imported_tables) + + def import_all_tables(self) -> None: + """ Creates one thread per table to import all tables in the self.tables_to_be_imported_with_table_key_info list. 
+ + Raises: + exceptions.table_import_failed: Raised when one or more tables fail to import + """ + start_time = time.time() + tables_failed_import: List[str] = [] + + with concurrent.futures.ThreadPoolExecutor() as executor: + tasks = { + executor.submit( + self.import_table, + table_name, + self.tables_to_be_imported_with_table_key_info[table_name].partition_key, + self.tables_to_be_imported_with_table_key_info[table_name].sort_key, + ): table_name for table_name in self.tables_to_be_imported_with_table_key_info.keys() + } + for future in concurrent.futures.as_completed(tasks): + table_name = tasks[future] + exception = future.exception() + if exception: + tables_failed_import.append(table_name) + + self.logger.debug(f"Table import completed for all tables in {time.time()-start_time} seconds") + + if tables_failed_import: + error_message = f"Tables ({(',').join(tables_failed_import)}) failed to import. For more details see CloudWatch logs or DynamoDB > Imports from S3 in AWS Console for imports suffixed with '{self.created_on}'" + self.logger.error(error_message) + raise exceptions.table_import_failed(error_message) + + def import_table( + self, table_name: str, partition_key: str, sort_key: Optional[str] + ): + """Initiates an import_table operation for a table from the snapshot. Waits for the import operation to complete (either successfully or with failure).
+ + Args: + table_name (str): Name of the table that needs to be imported + partition_key (str): partition_key that should be used for the temp table created + sort_key (Optional[str]): sort_key (if any) that should be used for the temp DynamoDB table created + """ + table_export_details = self.table_export_descriptions.get(table_name) + + table_export_manifest_file = table_export_details['ExportManifest'] + table_export_directory = os.path.dirname(table_export_manifest_file) + + temp_table_name = self.get_temp_table_name(table_name) + + try: + res = self.apply_snapshot_temp_tables_helper.initiate_import_table( + s3_bucket_name=self.snapshot.s3_bucket_name, + s3_key_prefix=f'{table_export_directory}/data/', + table_name=temp_table_name, + partition_key=partition_key, + sort_key=sort_key, + ) + + import_arn = res['ImportTableDescription']['ImportArn'] + + imported = self.context.aws_util().dynamodb_check_import_completed_successfully(import_arn) + + if not imported: + self.logger.error(f'Table "{table_name}" failed import') + raise RuntimeError(f'Table "{table_name}" failed import') + except botocore.exceptions.ClientError as e: + self.logger.error(f'Table "{table_name}" failed import with error: {e}') + raise e + + def scan_ddb_tables(self) -> Dict[TableName, List]: + data_by_table = {} + for table_name in self.tables_to_be_imported: + data_by_table[table_name] = {} + + temp_table_name = self.get_temp_table_name(table_name) + + table_obj = self._ddb_client.Table(temp_table_name) + + result = scan_db_records(SocaListingPayload(), table_obj) + + data_by_table[table_name] = result["Items"] + return data_by_table + + def apply_data_transformations(self, data: Dict[TableName, List]) -> Dict: + res_index = RES_VERSION_IN_TOPOLOGICAL_ORDER.index(self.snapshot_res_version) + + for i in range(res_index, len(RES_VERSION_IN_TOPOLOGICAL_ORDER)): + version = RES_VERSION_IN_TOPOLOGICAL_ORDER[i] + + if version in RES_VERSION_TO_DATA_TRANSFORMATION_CLASS and
RES_VERSION_TO_DATA_TRANSFORMATION_CLASS[version]: + data_transformer_obj = RES_VERSION_TO_DATA_TRANSFORMATION_CLASS[version]() + + self.logger.info(f"Initiating {type(data_transformer_obj).__name__}'s transform_data()") + data = data_transformer_obj.transform_data(data, self.logger) + self.logger.info(f"Completed executing {type(data_transformer_obj).__name__}'s transform_data()") + + return data + + def merge_transformed_data_for_all_tables(self, transformed_data: Dict[TableName, List]) -> None: + merged_table_to_delta_mappings: Dict[TableName, List[MergedRecordDelta]] = {} + + def _rollback_merged_tables(): + try: + for table_name in list(reversed(TABLES_IN_MERGE_DEPENDENCY_ORDER)): + if table_name in merged_table_to_delta_mappings and table_name in TABLE_TO_MERGE_LOGIC_CLASS and TABLE_TO_MERGE_LOGIC_CLASS[table_name]: + merger = TABLE_TO_MERGE_LOGIC_CLASS[table_name]() + + self.logger.info(f"Initiating {type(merger).__name__}'s rollback()") + merger.rollback(self.context, merged_table_to_delta_mappings[table_name], self.apply_snapshot_observability_helper) + del merged_table_to_delta_mappings[table_name] + self.logger.info(f"Completed executing {type(merger).__name__}'s rollback()") + except botocore.exceptions.ClientError as e: + error_message = e.response["Error"]["Message"] + self.logger.error(f"Apply Snapshot {self.apply_snapshot_record.apply_snapshot_identifier} failed to rollback with error {error_message}. Tables {list(merged_table_to_delta_mappings.keys())} failed to rollback.") + raise exceptions.table_rollback_failed(error_message) + except exceptions.SocaException as e: + error_message = e.message + self.logger.error(f"Apply Snapshot {self.apply_snapshot_record.apply_snapshot_identifier} failed to rollback with error {error_message}. 
Tables {list(merged_table_to_delta_mappings.keys())} failed to rollback.") + raise exceptions.table_rollback_failed(error_message) + except Exception as e: + error_message = repr(e) + self.logger.error(f"Apply Snapshot {self.apply_snapshot_record.apply_snapshot_identifier} failed to rollback with error {error_message}. Tables {list(merged_table_to_delta_mappings.keys())} failed to rollback.") + raise exceptions.table_rollback_failed(error_message) + + try: + for table_name in TABLES_IN_MERGE_DEPENDENCY_ORDER: + if table_name in TABLE_TO_MERGE_LOGIC_CLASS and TABLE_TO_MERGE_LOGIC_CLASS[table_name]: + merger = TABLE_TO_MERGE_LOGIC_CLASS[table_name]() + + self.logger.info(f"Initiating {type(merger).__name__}'s merge()") + merge_delta, success = merger.merge( + self.context, transformed_data.get(table_name, []), + self._get_snapshot_record_dedup_id(), + merged_table_to_delta_mappings, + self.apply_snapshot_observability_helper + ) + + merged_table_to_delta_mappings[table_name] = merge_delta + + if success: + self.logger.info(f"Completed executing {type(merger).__name__}'s merge()") + else: + error_message = f"Apply Snapshot {self.apply_snapshot_record.apply_snapshot_identifier} failed to apply {table_name}. Initiating rollback." 
+ raise exceptions.table_merge_failed(error_message) + + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.COMPLETED) + + except exceptions.SocaException as e: + error_message = e.message + self.logger.error(error_message) + + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.ROLLBACK_IN_PROGRESS, error_message) + + try: + _rollback_merged_tables() + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.ROLLBACK_COMPLETE, error_message) + except exceptions.SocaException as e: + self.apply_snapshot_dao.update_status(self.apply_snapshot_record, ApplySnapshotStatus.ROLLBACK_FAILED, e.message) + + def _get_snapshot_record_dedup_id(self) -> str: + return f"{self.snapshot_res_version}_{self.created_on}" diff --git a/source/idea/idea-administrator/resources/lambda_functions/idea_custom_resource_opensearch_private_ips/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/__init__.py similarity index 100% rename from source/idea/idea-administrator/resources/lambda_functions/idea_custom_resource_opensearch_private_ips/__init__.py rename to source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/__init__.py diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/abstract_transformation_from_res_version.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/abstract_transformation_from_res_version.py new file mode 100644 index 0000000..d721725 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/abstract_transformation_from_res_version.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideadatamodel.snapshots import TableName + +from abc import abstractmethod +from logging import Logger +from typing import Dict, List + + +class TransformationFromRESVersion: + """Every data transformation class must inherit this class. """ + + @abstractmethod + def transform_data(self, env_data_by_table: Dict[TableName, List], logger: Logger) -> Dict[TableName, List]: + ... diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/data_transformation_from_2023_11.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/data_transformation_from_2023_11.py new file mode 100644 index 0000000..10a536f --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_data_transformation_from_version/data_transformation_from_2023_11.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +from ideaclustermanager.app.snapshots.apply_snapshot_data_transformation_from_version.abstract_transformation_from_res_version import ( + TransformationFromRESVersion, +) +from ideadatamodel.snapshots import TableName + +from logging import Logger +from typing import Dict, List + + +class TransformationFromVersion2023_11(TransformationFromRESVersion): + + def transform_data(self, env_data_by_table: Dict[TableName, List], logger: Logger) -> Dict[TableName, List]: + return env_data_by_table diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/__init__.py new file mode 100644 index 0000000..59d9e03 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License.
\ No newline at end of file diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/filesystems_cluster_settings_table_merger.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/filesystems_cluster_settings_table_merger.py new file mode 100644 index 0000000..5facaef --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/filesystems_cluster_settings_table_merger.py @@ -0,0 +1,277 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +from typing import Dict, List, Tuple +import time + +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import ( + MergeTable, + +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplyResourceStatus, + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) +from ideasdk.context import SocaContext +from pyhocon import ConfigFactory + +from ideadatamodel import errorcodes, exceptions +from ideadatamodel.constants import ( + MODULE_SHARED_STORAGE, + STORAGE_PROVIDER_EFS, + STORAGE_PROVIDER_FSX_NETAPP_ONTAP, +) +from ideadatamodel import ( + OnboardEFSFileSystemRequest, + AddFileSystemToProjectRequest, + OnboardONTAPFileSystemRequest, + OffboardFileSystemRequest, + ListProjectsRequest + +) +from ideadatamodel.snapshots.snapshot_model import TableName + +TABLE_NAME = TableName.CLUSTER_SETTINGS_TABLE_NAME + +RETRY_COUNT = 15 + +class FileSystemsClusterSettingTableMerger(MergeTable): + def merge(self, context: SocaContext, table_data_to_merge: List[Dict], dedup_id: str, + _merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + record_deltas: List[MergedRecordDelta] = [] + + onboarded_filesystem_ids = self.get_list_of_onboarded_filesystem_ids(context) + accessible_filesystem_ids = self.get_list_of_accessible_filesystem_ids(context) + + env_project_names = set(self.get_names_of_projects_in_env(context)) + + details = self.extract_filesystem_details_to_dict(table_data_to_merge) + + for filesystem_name in details: + + if not isinstance(details[filesystem_name], dict): + continue + + # If 'provider' empty, skip applying filesystem + provider = details[filesystem_name].get("provider") + if not provider: + logger.warning(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED, f"filesystem provider not 
mentioned for filesystem {filesystem_name}") + continue + + # If the filesystem has a scope other than 'project', skip applying it + scope = details[filesystem_name].get('scope') + if not scope or 'project' not in scope: + logger.warning(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED, f"filesystem '{filesystem_name}' not of 'project' scope") + continue + + # If filesystem does not belong to the same VPC or is already onboarded, skip applying filesystem + try: + filesystem_id = details[filesystem_name][provider]["file_system_id"] + if filesystem_id in onboarded_filesystem_ids: + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_ALREADY_ONBOARDED, + message=f"{filesystem_id} has already been onboarded" + ) + if filesystem_id not in accessible_filesystem_ids: + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_IN_VPC, + message=f"{filesystem_id} not part of the env's VPC and thus not accessible" + ) + except exceptions.SocaException as e: + # Gracefully handling cases when filesystem is already onboarded or not accessible + if e.error_code == errorcodes.FILESYSTEM_ALREADY_ONBOARDED or e.error_code == errorcodes.FILESYSTEM_NOT_IN_VPC: + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED, f"{e.message}") + continue + else: + raise e + except Exception as e: + logger.error(TABLE_NAME, filesystem_name, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + # Check to see if the env already has a filesystem with the same name. If so, apply the filesystem with the dedup_id attached.
+ filesystem_with_name_already_present = False + try: + context.shared_filesystem.get_filesystem(filesystem_name) + + filesystem_with_name_already_present = True + except exceptions.SocaException as e: + if e.error_code != errorcodes.NO_SHARED_FILESYSTEM_FOUND and e.error_code != errorcodes.FILESYSTEM_NOT_FOUND: + raise e + except Exception as e: + logger.error( + TABLE_NAME, filesystem_name, ApplyResourceStatus.FAILED_APPLY, str(e) + ) + return record_deltas, False + + # FileSystem with same name exists. + # This merge will add a new filesystem by appending the RES version and dedup ID to its name instead of overriding the existing one. + filesystem_name_to_use = ( + MergeTable.unique_resource_id_generator(filesystem_name, dedup_id) + if filesystem_with_name_already_present + else filesystem_name + ) + + try: + if provider == STORAGE_PROVIDER_EFS: + self._onboard_efs(filesystem_name_to_use, details[filesystem_name], context) + elif provider == STORAGE_PROVIDER_FSX_NETAPP_ONTAP: + self._onboard_ontap(filesystem_name_to_use, details[filesystem_name], context) + + # Wait for some time for the filesystem changes to be picked up by the config listener and added to the local config tree + self._wait_for_onboarded_filesystem_to_sync_to_config_tree(filesystem_name, context) + accessible_filesystem_ids.remove(filesystem_id) + + # Add the onboarded filesystem to the corresponding projects. This is a soft dependency. If a project does not exist, it will be ignored. A rollback will not be triggered in this case.
+ self._add_filesystem_to_projects(filesystem_name_to_use, details[filesystem_name], env_project_names, context, dedup_id, logger) + except exceptions.SocaException as e: + # Gracefully handling cases when filesystem is already onboarded or not accessible + if e.error_code == errorcodes.FILESYSTEM_ALREADY_ONBOARDED or e.error_code == errorcodes.FILESYSTEM_NOT_IN_VPC: + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED, f"{e.message}") + continue + else: + raise e + except Exception as e: + logger.error(TABLE_NAME, filesystem_name_to_use, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + if filesystem_with_name_already_present: + logger.debug(TABLE_NAME, filesystem_name_to_use, ApplyResourceStatus.APPLIED, f"fileSystem with same name already exists. Onboarded the filesystem successfully with name {filesystem_name_to_use}") + else: + logger.debug(TABLE_NAME, filesystem_name_to_use, ApplyResourceStatus.APPLIED, f"onboarded the filesystem successfully") + + record_deltas.append( + MergedRecordDelta( + original_record={}, + snapshot_record={filesystem_name: details[filesystem_name]}, + resolved_record={filesystem_name_to_use: details[filesystem_name]}, + action_performed=MergedRecordActionType.CREATE + ) + ) + + return record_deltas, True + + def get_list_of_onboarded_filesystem_ids(self, context): + onboarded_filesystems = context.shared_filesystem._list_shared_filesystems() + return set([fs.get_filesystem_id() for fs in onboarded_filesystems]) + + def get_list_of_accessible_filesystem_ids(self, context): + efs_filesystems = context.shared_filesystem._list_unonboarded_efs_file_systems([]) + fsx_filesystems = context.shared_filesystem._list_unonboarded_ontap_file_systems([]) + + filesystems_in_vpc = [*efs_filesystems, *fsx_filesystems] + + return set([fs.get_filesystem_id() for fs in filesystems_in_vpc]) + + def get_names_of_projects_in_env(self, context): + env_projects = 
context.projects.list_projects(ListProjectsRequest()).listing + return [project.name for project in env_projects] + + def extract_filesystem_details_to_dict( + self, table_data_to_merge: List[Dict] + ) -> Dict: + filesystem_details = {} + for setting in table_data_to_merge: + key = setting['key'] + value = setting['value'] + + if str(key).startswith(MODULE_SHARED_STORAGE): + filesystem_details[key] = value + + config = ConfigFactory.from_dict(filesystem_details) + + return config.get(MODULE_SHARED_STORAGE, {}).as_plain_ordered_dict() + + def _onboard_efs(self, filesystem_name: str, details: Dict, context: SocaContext): + onboard_efs_request = OnboardEFSFileSystemRequest( + filesystem_name=filesystem_name, + filesystem_title=details["title"], + filesystem_id=details[STORAGE_PROVIDER_EFS]["file_system_id"], + mount_directory=details["mount_dir"], + ) + context.shared_filesystem.onboard_efs_filesystem( + onboard_efs_request + ) + + def _onboard_ontap(self, filesystem_name: str, details: Dict, context: SocaContext): + onboard_ontap_request = OnboardONTAPFileSystemRequest( + filesystem_name=filesystem_name, + filesystem_title=details["title"], + filesystem_id=details[STORAGE_PROVIDER_FSX_NETAPP_ONTAP]["file_system_id"], + mount_directory=details["mount_dir"], + mount_drive=details["mount_drive"], + svm_id=details[STORAGE_PROVIDER_FSX_NETAPP_ONTAP]["svm"]["svm_id"], + volume_id=details[STORAGE_PROVIDER_FSX_NETAPP_ONTAP]["volume"]["volume_id"], + file_share_name=details[STORAGE_PROVIDER_FSX_NETAPP_ONTAP]["volume"]["cifs_share_name"] + ) + context.shared_filesystem.onboard_ontap_filesystem(onboard_ontap_request) + + def _wait_for_onboarded_filesystem_to_sync_to_config_tree(self, filesystem_name, context): + """Wait for some time for the filesystem changes to be picked up by the config listener and added to the local config tree + + Args: + filesystem_name (str): + context (SocaContext): + """ + retry_count = RETRY_COUNT + while retry_count: + retry_count -= 1 +
time.sleep(5) + + try: + context.shared_filesystem.get_filesystem(filesystem_name) + return + except exceptions.SocaException as e: + pass + + def _add_filesystem_to_projects(self, filesystem_name: str, details: Dict, env_project_names: List[str], context: SocaContext, dedup_id: str, logger: ApplySnapshotObservabilityHelper): + fs_projects = details.get("projects") + if not fs_projects: + return + + for project_name in fs_projects: + # If project not present in env, skip attaching filesystem to project + if project_name not in env_project_names: + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED_SOFT_DEP ,f"project '{project_name}' not present in env.") + continue + + # If a new project with dedup_id was created while applying the projects table due to name conflict, the filesystem must be attached to the project with dedup_id + deduped_project_name = MergeTable.unique_resource_id_generator(project_name, dedup_id) + project_name_to_use = deduped_project_name if deduped_project_name in env_project_names else project_name + + try: + context.shared_filesystem.add_filesystem_to_project( + AddFileSystemToProjectRequest( + filesystem_name=filesystem_name, project_name=project_name_to_use + ) + ) + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.APPLIED_SOFT_DEP ,f"added {project_name_to_use} to filesystem successfully") + except Exception as e: + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.SKIPPED_SOFT_DEP ,f"adding project '{project_name_to_use}' to filesystem failed with error {e}") + + def rollback(self, context: SocaContext, record_deltas: List[MergedRecordDelta], logger: ApplySnapshotObservabilityHelper) -> None: + while record_deltas: + record_delta = record_deltas[0] + filesystem_name = list(record_delta.resolved_record.keys())[0] + + if record_delta.action_performed == MergedRecordActionType.CREATE: + try: + 
context.shared_filesystem.offboard_filesystem(OffboardFileSystemRequest(filesystem_name=filesystem_name)) + except Exception as e: + logger.error(TABLE_NAME, filesystem_name, ApplyResourceStatus.FAILED_ROLLBACK, str(e)) + raise e + + logger.debug(TABLE_NAME, filesystem_name, ApplyResourceStatus.ROLLBACKED, "offboarded filesystem succeeded") + + record_deltas.pop(0) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/merge_table.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/merge_table.py new file mode 100644 index 0000000..f6a765c --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/merge_table.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +from ideasdk.context import SocaContext +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ApplySnapshotObservabilityHelper +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import MergedRecordDelta +from ideadatamodel.snapshots.snapshot_model import TableName + +from abc import abstractmethod +from typing import Dict, List, Tuple + + +class MergeTable: + @abstractmethod + def merge(self, context: SocaContext, table_data_to_merge: List, + dedup_id: str, merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + """ Merges the records represented by table_data_to_merge List to the env by invoking the corresponding + Create APIs or the create functions in the Service. If an exception arises during the merge operation for a record, the exception should be handled + and a tuple containing the delta for merged records along with a boolean False indicating merge failure must be returned, e.g. return merge_delta_dict, False. + + Args: + context (SocaContext): + table_data_to_merge (List): List of Python dicts representing records to be merged. + dedup_id (str): Dedup ID for resolving conflict records. + merged_record_deltas (Dict[TableName, List[MergedRecordDelta]]): Deltas for the merged records. + logger (ApplySnapshotObservabilityHelper): Helper for logging the operations. + + Returns: + Tuple[Dict, bool]: Dict represents the delta/changes. This will be passed to the rollback function, in case a rollback must be initiated. + bool represents if the merge was successful for all records. If merging a record fails, the delta will be returned with False indicating the merge was not successful. + Returned bool indicates if a rollback should be initiated. + """ + ...
+ + @abstractmethod + def rollback(self, context: SocaContext, merge_delta: List[MergedRecordDelta], logger: ApplySnapshotObservabilityHelper): + """ Rolls back a merge applied on a table. + + Args: + context (SocaContext): + merge_delta (List[MergedRecordDelta]): Describes the delta introduced by the merge operation. + logger (ApplySnapshotObservabilityHelper): Helper for logging the operations. + """ + ... + + @staticmethod + def unique_resource_id_generator(key: str, dedup_id: str) -> str: + return f'{key}_{dedup_id}' \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/permission_profiles_table_merger.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/permission_profiles_table_merger.py new file mode 100644 index 0000000..fb6f70b --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/permission_profiles_table_merger.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License.
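As an aside, the merge/rollback contract that `MergeTable` defines — `merge()` returning `(record_deltas, success)` and `rollback()` consuming those deltas front-to-back — can be sketched with a toy, dictionary-backed merger. This is a hypothetical illustration only (none of these names exist in RES); it mirrors the skip-empty, skip-unchanged, rename-on-conflict pattern used by the concrete mergers below.

```python
from typing import Dict, List, Tuple


def unique_resource_id_generator(key: str, dedup_id: str) -> str:
    # Mirrors MergeTable.unique_resource_id_generator: conflicting records are
    # renamed by appending the dedup ID.
    return f'{key}_{dedup_id}'


def merge(records: List[Dict], existing: Dict[str, Dict], dedup_id: str) -> Tuple[List[Dict], bool]:
    """Toy merge: apply snapshot records against an in-memory 'environment'."""
    deltas: List[Dict] = []
    for record in records:
        name = record.get('name')
        if not name:
            continue  # skip records with an empty key, as the real mergers do
        if existing.get(name) == record:
            continue  # unchanged record: nothing to apply
        if name in existing:
            # conflict: create a renamed copy instead of overwriting
            name = unique_resource_id_generator(name, dedup_id)
            record = {**record, 'name': name}
        existing[name] = record
        deltas.append({'resolved_record': record, 'action': 'CREATE'})
    return deltas, True


def rollback(deltas: List[Dict], existing: Dict[str, Dict]) -> None:
    # Undo CREATE actions, popping each delta once it has been undone,
    # matching the while/pop(0) pattern used by the concrete mergers.
    while deltas:
        delta = deltas[0]
        if delta['action'] == 'CREATE':
            existing.pop(delta['resolved_record']['name'], None)
        deltas.pop(0)
```

On a merge failure the real implementations return the partial delta list with `False`, and the caller hands exactly that list to `rollback()`.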
+ +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import MergeTable +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import MergedRecordDelta, MergedRecordActionType +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ApplySnapshotObservabilityHelper, ApplyResourceStatus +import ideaclustermanager.app.snapshots.helpers.db_utils as db_utils + +from ideadatamodel import ( + VirtualDesktopPermission, + errorcodes, + exceptions, +) +from ideadatamodel.snapshots.snapshot_model import TableName + +from ideasdk.context import SocaContext + +import copy +from typing import Dict, List, Optional, Tuple + +TABLE_NAME = TableName.PERMISSION_PROFILES_TABLE_NAME + + +class PermissionProfilesTableMerger(MergeTable): + def merge(self, context: SocaContext, table_data_to_merge: List[Dict], + dedup_id: str, _merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + record_deltas: List[MergedRecordDelta] = [] + + try: + permission_types = context.vdc_client.get_base_permissions() + except Exception as e: + logger.error(TABLE_NAME, "", ApplyResourceStatus.FAILED_APPLY, str(e)) + return [], False + + for permission_profile_db_record in table_data_to_merge: + if not permission_profile_db_record or not permission_profile_db_record.get(db_utils.PERMISSION_PROFILE_DB_HASH_KEY): + logger.warning(TABLE_NAME, "", ApplyResourceStatus.SKIPPED, "permission profile ID is empty") + continue + + profile_id = permission_profile_db_record[db_utils.PERMISSION_PROFILE_DB_HASH_KEY] + try: + resolved_record, action_type = self.resolve_record( + context, copy.deepcopy(permission_profile_db_record), + dedup_id, permission_types, logger) + if not action_type: + logger.debug(TABLE_NAME, profile_id, ApplyResourceStatus.SKIPPED, "permission profile is unchanged") + continue + + record_delta = MergedRecordDelta( + 
snapshot_record=permission_profile_db_record, + resolved_record=resolved_record, + action_performed=action_type, + ) + profile_id = resolved_record[db_utils.PERMISSION_PROFILE_DB_HASH_KEY] + + if record_delta.action_performed == MergedRecordActionType.CREATE: + permission_profile = db_utils.convert_db_dict_to_permission_profile_object(record_delta.resolved_record, permission_types) + permission_profile = context.vdc_client.create_permission_profile(permission_profile) + record_delta.resolved_record = db_utils.convert_permission_profile_object_to_db_dict(permission_profile, permission_types) + record_deltas.append(record_delta) + except Exception as e: + logger.error(TABLE_NAME, profile_id, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + logger.debug(TABLE_NAME, profile_id, ApplyResourceStatus.APPLIED, "adding permission profile succeeded") + + return record_deltas, True + + def rollback(self, context: SocaContext, record_deltas: List[MergedRecordDelta], logger: ApplySnapshotObservabilityHelper): + while record_deltas: + record_delta = record_deltas[0] + profile_id = record_delta.resolved_record.get(db_utils.PERMISSION_PROFILE_DB_HASH_KEY, "") + if record_delta.action_performed == MergedRecordActionType.CREATE: + # Currently we only add new permission profiles instead of updating existing ones when applying a snapshot. + # Add this checking here for handling updated records in the future. 
+ try: + context.vdc_client.delete_permission_profile(profile_id) + except Exception as e: + logger.error(TABLE_NAME, profile_id, ApplyResourceStatus.FAILED_ROLLBACK, str(e)) + raise e + + logger.debug(TABLE_NAME, profile_id, ApplyResourceStatus.ROLLBACKED, "removing permission profile succeeded") + + record_deltas.pop(0) + + @staticmethod + def resolve_record(context: SocaContext, db_entry: dict, dedup_id: str, + permission_types: List[VirtualDesktopPermission], + logger: ApplySnapshotObservabilityHelper) -> Tuple[Dict, Optional[MergedRecordActionType]]: + profile_id = db_entry[db_utils.PERMISSION_PROFILE_DB_HASH_KEY] + try: + existing_permission_profile = context.vdc_client.get_permission_profile(profile_id) + snapshot_permission_profile = db_utils.convert_db_dict_to_permission_profile_object(db_entry, permission_types) + if existing_permission_profile == snapshot_permission_profile: + return db_entry, None + + # Permission profile with the same profile ID exists. + # This merger will add a new permission profile by appending the dedup ID to its profile ID instead of overriding the existing ones, + # since the existing permission profiles might be used by active VDIs under the new environment. + profile_id = MergeTable.unique_resource_id_generator(profile_id, dedup_id) + db_entry[db_utils.PERMISSION_PROFILE_DB_HASH_KEY] = profile_id + + logger.debug(TABLE_NAME, profile_id, reason="Permission profile with the same ID exists under the current environment. 
" + "Created a new record with the dedup ID appended to the permission profile id") + except exceptions.SocaException as e: + if e.error_code != errorcodes.INVALID_PARAMS: + raise e + + return db_entry, MergedRecordActionType.CREATE diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/projects_table_merger.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/projects_table_merger.py new file mode 100644 index 0000000..bad335e --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/projects_table_merger.py @@ -0,0 +1,157 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +from ideaclustermanager.app.projects.db.projects_dao import ProjectsDAO +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ApplySnapshotObservabilityHelper, ApplyResourceStatus +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import MergedRecordDelta, MergedRecordActionType +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import MergeTable + +from ideadatamodel.snapshots.snapshot_model import TableName +from ideadatamodel import ( + GetProjectRequest, + CreateProjectRequest, + DeleteProjectRequest, + UpdateProjectRequest, + Project, + errorcodes, + exceptions, +) + +from ideasdk.context import SocaContext + +import copy +from typing import Dict, List, Optional, Tuple + +PROJECTS_TABLE_PROJECT_NAME_KEY = "name" +TABLE_NAME = TableName.PROJECTS_TABLE_NAME + + +class ProjectsTableMerger(MergeTable): + """ + Helper class for merging the projects table + """ + + def merge(self, context: SocaContext, table_data_to_merge: List[Dict], + dedup_id: str, _merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + record_deltas: List[MergedRecordDelta] = [] + for project_db_record in table_data_to_merge: + if not project_db_record or not project_db_record.get(PROJECTS_TABLE_PROJECT_NAME_KEY): + logger.debug(TABLE_NAME, "", ApplyResourceStatus.SKIPPED, "project name is empty") + continue + + project_name = project_db_record[PROJECTS_TABLE_PROJECT_NAME_KEY] + try: + resolved_record, action_type = self.resolve_record( + context, copy.deepcopy(project_db_record), + dedup_id, logger) + if not action_type: + logger.debug(TABLE_NAME, project_name, ApplyResourceStatus.SKIPPED, "project is unchanged") + continue + + record_delta = MergedRecordDelta( + snapshot_record=project_db_record, + resolved_record=resolved_record, + action_performed=action_type + ) + project_name = 
resolved_record[PROJECTS_TABLE_PROJECT_NAME_KEY] + + if record_delta.action_performed == MergedRecordActionType.CREATE: + project = ProjectsDAO.convert_from_db(record_delta.resolved_record) + project.ldap_groups = [] + project.users = [] + project.enable_budgets = self._budget_is_enabled_and_exists(context, project, logger) + project = context.projects.create_project(CreateProjectRequest(project=project)).project + record_delta.resolved_record = ProjectsDAO.convert_to_db(project) + record_deltas.append(record_delta) + except Exception as e: + logger.error(TABLE_NAME, project_name, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + self._add_groups_and_users_to_project(context, project_db_record, record_delta, logger) + + logger.debug(TABLE_NAME, project_name, ApplyResourceStatus.APPLIED, "adding project succeeded") + + return record_deltas, True + + def rollback(self, context: SocaContext, record_deltas: List[MergedRecordDelta], logger: ApplySnapshotObservabilityHelper) -> None: + while record_deltas: + record_delta = record_deltas[0] + project_name = record_delta.resolved_record.get(PROJECTS_TABLE_PROJECT_NAME_KEY, "") + if record_delta.action_performed == MergedRecordActionType.CREATE: + # Currently we only add new projects instead of updating existing ones when applying a snapshot. + # Add this checking here for handling updated records in the future. 
+ try: + context.projects.delete_project(DeleteProjectRequest(project_name=project_name)) + except Exception as e: + logger.error(TABLE_NAME, project_name, ApplyResourceStatus.FAILED_ROLLBACK, str(e)) + raise e + + logger.debug(TABLE_NAME, project_name, ApplyResourceStatus.ROLLBACKED, "removing project succeeded") + + record_deltas.pop(0) + + @staticmethod + def _add_groups_and_users_to_project(context: SocaContext, project_db_record: Dict, + record_delta: MergedRecordDelta, logger: ApplySnapshotObservabilityHelper): + project = copy.deepcopy(record_delta.resolved_record) + project["ldap_groups"] = project_db_record.get("ldap_groups", []) + project["users"] = project_db_record.get("users", []) + project_name = project["name"] + try: + project = ProjectsDAO.convert_from_db(project) + project = context.projects.update_project( + UpdateProjectRequest(project=project) + ).project + record_delta.resolved_record = ProjectsDAO.convert_to_db(project) + + logger.info(TABLE_NAME, project_name, reason="adding ldap groups and users to project succeeded") + except Exception as e: + logger.warning(TABLE_NAME, project_name, reason=f"failed to add ldap groups and users to project: {str(e)}") + + @staticmethod + def _budget_is_enabled_and_exists(context: SocaContext, project: Project, logger: ApplySnapshotObservabilityHelper) -> bool: + if not project.enable_budgets: + return False + + if not project.budget or not project.budget.budget_name: + logger.warning(TABLE_NAME, project.name, reason="budget name is empty and will be ignored") + return False + + try: + _budget = context.aws_util().budgets_get_budget(project.budget.budget_name) + except Exception as e: + logger.warning(TABLE_NAME, project.name, reason=f"cannot get the budget and it will be ignored: {str(e)}") + return False + + return True + + @staticmethod + def resolve_record(context: SocaContext, db_entry: dict, dedup_id: str, + logger: ApplySnapshotObservabilityHelper) -> Tuple[Dict, Optional[MergedRecordActionType]]: + 
project_name = db_entry[PROJECTS_TABLE_PROJECT_NAME_KEY] + try: + existing_project = context.projects.get_project(GetProjectRequest(project_name=project_name)).project + snapshot_project = ProjectsDAO.convert_from_db(db_entry) + if existing_project == snapshot_project: + return db_entry, None + + # If the project already exists, rename the project ID by appending the dedup ID. + # This merger will add new projects instead of overriding the existing ones to resolve conflicts. + project_name = MergeTable.unique_resource_id_generator(project_name, dedup_id) + db_entry[PROJECTS_TABLE_PROJECT_NAME_KEY] = project_name + logger.debug(TABLE_NAME, project_name, reason="Project with the same name exists under the current environment. " + "Created a new record with the dedup ID appended to the project name") + except exceptions.SocaException as e: + if e.error_code != errorcodes.PROJECT_NOT_FOUND: + raise e + + return db_entry, MergedRecordActionType.CREATE diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/software_stacks_table_merger.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/software_stacks_table_merger.py new file mode 100644 index 0000000..3429325 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/software_stacks_table_merger.py @@ -0,0 +1,155 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, + ApplyResourceStatus, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordDelta, + MergedRecordActionType, +) +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import ( + MergeTable, +) +import ideaclustermanager.app.snapshots.helpers.db_utils as db_utils + +from ideadatamodel.snapshots.snapshot_model import TableName +from ideadatamodel import errorcodes, exceptions + +from ideasdk.context import SocaContext + +from typing import Dict, List, Optional, Tuple +import copy + +TABLE_NAME = TableName.SOFTWARE_STACKS_TABLE_NAME + + +class SoftwareStacksTableMerger(MergeTable): + """ + Helper class for merging the software stacks table + """ + + def merge(self, context: SocaContext, table_data_to_merge: List[Dict], + dedup_id: str, merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + project_table_deltas = merged_record_deltas.get(TableName.PROJECTS_TABLE_NAME, []) + project_id_mappings = { + project_delta.snapshot_record["project_id"]: project_delta.resolved_record["project_id"] + for project_delta in project_table_deltas + } + + record_deltas: List[MergedRecordDelta] = [] + for software_stack_db_record in table_data_to_merge: + if not software_stack_db_record or not software_stack_db_record.get(db_utils.SOFTWARE_STACK_DB_NAME_KEY): + logger.debug(TABLE_NAME, "", ApplyResourceStatus.SKIPPED, + f"{db_utils.SOFTWARE_STACK_DB_NAME_KEY} is empty") + continue + + stack_name = software_stack_db_record[db_utils.SOFTWARE_STACK_DB_NAME_KEY] + try: + resolved_record, action_type = self.resolve_record( + context, copy.deepcopy(software_stack_db_record), + dedup_id, project_id_mappings, 
logger) + if not action_type: + logger.debug(TABLE_NAME, stack_name, ApplyResourceStatus.SKIPPED, "software stack is unchanged") + continue + + record_delta = MergedRecordDelta( + snapshot_record=software_stack_db_record, + resolved_record=resolved_record, + action_performed=action_type, + ) + stack_name = resolved_record[db_utils.SOFTWARE_STACK_DB_NAME_KEY] + + if record_delta.action_performed == MergedRecordActionType.CREATE: + software_stack = db_utils.convert_db_dict_to_software_stack_object(record_delta.resolved_record) + software_stack = context.vdc_client.create_software_stack(software_stack) + record_delta.resolved_record = db_utils.convert_software_stack_object_to_db_dict(software_stack) + record_deltas.append(record_delta) + except exceptions.SocaException as e: + if e.error_code == errorcodes.INVALID_PARAMS and e.message.startswith("Invalid software_stack.ami_id"): + logger.debug(TABLE_NAME, stack_name, ApplyResourceStatus.SKIPPED, f"AMI ID of the software stack is not available in the current region: {str(e)}") + continue + raise e + except Exception as e: + logger.error(TABLE_NAME, stack_name, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + logger.debug(TABLE_NAME, stack_name, ApplyResourceStatus.APPLIED, "adding software stack succeeded") + + return record_deltas, True + + def rollback( + self, + context: SocaContext, + record_deltas: List[MergedRecordDelta], + logger: ApplySnapshotObservabilityHelper, + ) -> None: + while record_deltas: + record_delta = record_deltas[0] + if record_delta.action_performed == MergedRecordActionType.CREATE: + # Currently we only add new software stacks instead of updating existing ones when applying a snapshot. + # Add this checking here for handling updated records in the future. 
+ try: + context.vdc_client.delete_software_stack( + db_utils.convert_db_dict_to_software_stack_object(record_delta.resolved_record), + ) + except Exception as e: + logger.error(TABLE_NAME, record_delta.resolved_record[db_utils.SOFTWARE_STACK_DB_NAME_KEY], + ApplyResourceStatus.FAILED_ROLLBACK, str(e)) + raise e + + logger.debug(TABLE_NAME, record_delta.resolved_record[db_utils.SOFTWARE_STACK_DB_NAME_KEY], + ApplyResourceStatus.ROLLBACKED, "removing software stack succeeded") + + record_deltas.pop(0) + + @staticmethod + def resolve_record( + context: SocaContext, + db_entry: dict, + dedup_id: str, + project_id_mappings: Dict[str, str], + logger: ApplySnapshotObservabilityHelper, + ) -> Tuple[Dict, Optional[MergedRecordActionType]]: + SoftwareStacksTableMerger._resolve_project_ids(db_entry, project_id_mappings) + + stack_name = db_entry[db_utils.SOFTWARE_STACK_DB_NAME_KEY] + software_stacks_by_name = context.vdc_client.get_software_stacks_by_name(stack_name) + if software_stacks_by_name: + snapshot_software_stack = db_utils.convert_db_dict_to_software_stack_object(db_entry) + if any(existing_software_stack == snapshot_software_stack for existing_software_stack in software_stacks_by_name): + return db_entry, None + + # If a software stack with the exact same name already exists, rename the stack by appending the dedup ID. + # This merger will add new software stacks instead of overriding the existing ones to resolve conflicts. + stack_name = MergeTable.unique_resource_id_generator(stack_name, dedup_id) + db_entry[db_utils.SOFTWARE_STACK_DB_NAME_KEY] = stack_name + + logger.debug(TABLE_NAME, stack_name, reason="Software stack with the same name exists under the current environment. 
" + "Created a new record with the dedup ID appended to the software stack name") + + return db_entry, MergedRecordActionType.CREATE + + @staticmethod + def _resolve_project_ids( + db_entry: Dict, + project_id_mappings: Dict[str, str], + ): + projects = db_entry.get(db_utils.SOFTWARE_STACK_DB_PROJECTS_KEY, []) + for index in range(len(projects)): + project_id = projects[index] + if not project_id_mappings.get(project_id): + # project is unchanged, so it's not in the merged record delta. + continue + projects[index] = project_id_mappings[project_id] + db_entry[db_utils.SOFTWARE_STACK_DB_PROJECTS_KEY] = projects diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/users_table_merger.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/users_table_merger.py new file mode 100644 index 0000000..b1d2449 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/apply_snapshot_merge_table/users_table_merger.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
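The project-ID remapping that `SoftwareStacksTableMerger._resolve_project_ids` performs — rewriting each stack's project list through the mappings built from earlier project-table deltas — can be condensed into a standalone sketch (the free function name here is hypothetical). IDs absent from the mapping belong to unchanged projects and pass through untouched:

```python
from typing import Dict, List


def resolve_project_ids(projects: List[str], project_id_mappings: Dict[str, str]) -> List[str]:
    # Remapped IDs point at the renamed projects created during the merge;
    # anything not in the mapping was unchanged and keeps its original ID.
    return [project_id_mappings.get(project_id, project_id) for project_id in projects]
```

This ordering is why the software-stacks table must be merged after the projects table: the mapping only exists once the project deltas have been produced.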
+ +from ideasdk.context import SocaContext +from ideaclustermanager.app.accounts.db.user_dao import UserDAO +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ApplySnapshotObservabilityHelper, ApplyResourceStatus +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import MergedRecordDelta, MergedRecordActionType +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import MergeTable +from ideadatamodel.snapshots.snapshot_model import TableName +from ideadatamodel import errorcodes, exceptions, constants, User + +from typing import Dict, List, Tuple + +TABLE_NAME = TableName.USERS_TABLE_NAME + + +class UsersTableMerger(MergeTable): + """ + Helper class for merging the accounts.users table + """ + def merge(self, context: SocaContext, table_data_to_merge: List[Dict], + _dedup_id: str, _merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger: ApplySnapshotObservabilityHelper) -> Tuple[List[MergedRecordDelta], bool]: + record_deltas: List[MergedRecordDelta] = [] + for user_db_record in table_data_to_merge: + user_to_merge = UserDAO.convert_from_db(user_db_record) + try: + existing_user = context.accounts.get_user(user_to_merge.username) + except exceptions.SocaException as e: + if e.error_code == errorcodes.AUTH_USER_NOT_FOUND: + logger.debug(TABLE_NAME, user_to_merge.username, ApplyResourceStatus.SKIPPED, "the user doesn't exist in the current environment") + continue + raise e + except Exception as e: + logger.error(TABLE_NAME, user_to_merge.username, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + if user_to_merge.role == existing_user.role: + logger.debug(TABLE_NAME, user_to_merge.username, ApplyResourceStatus.SKIPPED, "the user role is unchanged") + continue + + try: + self.apply(context, user_to_merge, logger) + record_deltas.append( + MergedRecordDelta( + original_record=UserDAO.convert_to_db(existing_user), + snapshot_record=user_db_record, 
+ action_performed=MergedRecordActionType.UPDATE + ) + ) + except Exception as e: + logger.error(TABLE_NAME, user_to_merge.username, ApplyResourceStatus.FAILED_APPLY, str(e)) + return record_deltas, False + + return record_deltas, True + + def rollback(self, context: SocaContext, record_deltas: List[MergedRecordDelta], logger: ApplySnapshotObservabilityHelper) -> None: + while record_deltas: + record_delta = record_deltas[0] + user = UserDAO.convert_from_db(record_delta.original_record) + if record_delta.action_performed == MergedRecordActionType.UPDATE: + # Currently we only update admin permissions for existing users. + # Add this checking here for handling new users in the future. + try: + self.apply(context, user, logger, True) + except Exception as e: + logger.error(TABLE_NAME, user.username, ApplyResourceStatus.FAILED_ROLLBACK, str(e)) + raise e + + record_deltas.pop(0) + + @staticmethod + def apply(context: SocaContext, user: User, logger: ApplySnapshotObservabilityHelper, is_rolling_back: bool = False) -> None: + expected_resource_status = ApplyResourceStatus.ROLLBACKED if is_rolling_back else ApplyResourceStatus.APPLIED + if user.role == constants.ADMIN_ROLE: + context.accounts.add_admin_user(user.username) + logger.debug(TABLE_NAME, user.username, expected_resource_status, "setting admin privileges succeeded") + else: + context.accounts.remove_admin_user(user.username) + logger.debug(TABLE_NAME, user.username, expected_resource_status, "removing admin privileges succeeded") diff --git a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/__init__.py similarity index 100% rename from source/idea/idea-sdk/src/ideasdk/aws/opensearch/__init__.py rename to source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/__init__.py diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/apply_snapshot_dao.py 
b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/apply_snapshot_dao.py new file mode 100644 index 0000000..f5ae334 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/apply_snapshot_dao.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideasdk.context import SocaContext +from ideasdk.utils import Utils, scan_db_records +from ideadatamodel import ( + exceptions, + ApplySnapshot, + ApplySnapshotStatus, + ListApplySnapshotRecordsRequest, + ListApplySnapshotRecordsResult, + SocaListingPayload, + SocaPaginator, +) + +from typing import Dict, Optional +from boto3.dynamodb.conditions import Attr + + +class ApplySnapshotDAO: + def __init__(self, context: SocaContext, logger=None): + self.context = context + if logger is not None: + self.logger = logger + else: + self.logger = context.logger('apply-snapshot-dao') + self.table = None + + def get_table_name(self) -> str: + return f'{self.context.cluster_name()}.apply-snapshot' + + def initialize(self): + self.context.aws_util().dynamodb_create_table( + create_table_request={ + 'TableName': self.get_table_name(), + 'AttributeDefinitions': [{'AttributeName': 'apply_snapshot_identifier', 'AttributeType': 'S'}], + 'KeySchema': [{'AttributeName': 'apply_snapshot_identifier', 'KeyType': 'HASH'}], + 'BillingMode': 'PAY_PER_REQUEST', + }, + wait=True, + ) + self.table = self.context.aws().dynamodb_table().Table(self.get_table_name()) 
+ + @staticmethod + def convert_from_db(apply_snapshot: Dict) -> ApplySnapshot: + keys = [ + 'apply_snapshot_identifier', + 's3_bucket_name', + 'snapshot_path', + 'status', + 'created_on', + 'failure_reason', + ] + return ApplySnapshot(**{k: apply_snapshot.get(k) for k in keys}) + + @staticmethod + def convert_to_db(apply_snapshot: ApplySnapshot) -> Dict: + keys = [ + 'apply_snapshot_identifier', + 's3_bucket_name', + 'snapshot_path', + 'status', + 'created_on', + 'failure_reason', + ] + return {k: getattr(apply_snapshot, k) for k in keys} + + def create(self, apply_snapshot: Dict, created_on: int) -> Dict: + s3_bucket_name = apply_snapshot.get('s3_bucket_name') + if not s3_bucket_name or len(s3_bucket_name.strip()) == 0: + raise exceptions.invalid_params('s3_bucket_name is required') + snapshot_path = apply_snapshot.get('snapshot_path') + if not snapshot_path or len(snapshot_path.strip()) == 0: + raise exceptions.invalid_params('snapshot_path is required') + + apply_snapshot_record = { + **apply_snapshot, + 'apply_snapshot_identifier': f'{s3_bucket_name}-{snapshot_path}-{created_on}', + 'created_on': created_on, + } + + self.table.put_item(Item=apply_snapshot_record) + + return apply_snapshot_record + + def update_status(self, apply_snapshot: ApplySnapshot, status: ApplySnapshotStatus, failure_message: Optional[str] = None) -> Dict: + update_expression_tokens = ['#status_key = :status_value'] + expression_attr_names = {"#status_key": "status"} + expression_attr_values = {':status_value': status} + + if failure_message and status in [ApplySnapshotStatus.FAILED, ApplySnapshotStatus.ROLLBACK_IN_PROGRESS, ApplySnapshotStatus.ROLLBACK_COMPLETE, ApplySnapshotStatus.ROLLBACK_FAILED]: + update_expression_tokens.append('#failure_reason_key = :failure_reason_value') + expression_attr_names['#failure_reason_key'] = "failure_reason" + expression_attr_values[':failure_reason_value'] = failure_message + + result = self.table.update_item( + Key={'apply_snapshot_identifier': 
apply_snapshot.apply_snapshot_identifier}, + ConditionExpression=Attr('status').ne(status), + UpdateExpression='SET ' + ', '.join(update_expression_tokens), + ExpressionAttributeNames=expression_attr_names, + ExpressionAttributeValues=expression_attr_values, + ReturnValues='ALL_NEW', + ) + + updated_apply_snapshot = result['Attributes'] + return updated_apply_snapshot + + def list(self, request: ListApplySnapshotRecordsRequest) -> ListApplySnapshotRecordsResult: + list_result = scan_db_records(request, self.table) + entries = list_result.get('Items', []) + result = [self.convert_from_db(entry) for entry in entries] + + exclusive_start_key = list_result.get("LastEvaluatedKey") + response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) if exclusive_start_key else None + + return SocaListingPayload( + listing=result, paginator=SocaPaginator(page_size=request.page_size, cursor=response_cursor) + ) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_dao.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/snapshot_dao.py similarity index 99% rename from source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_dao.py rename to source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/snapshot_dao.py index 94a2aaa..4ad55ed 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_dao.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/db/snapshot_dao.py @@ -8,6 +8,7 @@ # or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. 
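`ApplySnapshotDAO.update_status` above assembles its DynamoDB `UpdateExpression` from placeholder tokens — aliasing `status` because it is a DynamoDB reserved word — and appends `failure_reason` only for failure/rollback statuses. The assembly step can be sketched in isolation (the function name and the plain-string status set are assumptions for illustration; no AWS call is made):

```python
from typing import Dict, Optional, Tuple

# Statuses for which a failure_reason is recorded (names taken from the DAO above).
FAILURE_STATUSES = {'FAILED', 'ROLLBACK_IN_PROGRESS', 'ROLLBACK_COMPLETE', 'ROLLBACK_FAILED'}


def build_status_update(status: str, failure_message: Optional[str] = None) -> Tuple[str, Dict, Dict]:
    update_expression_tokens = ['#status_key = :status_value']
    # 'status' is a DynamoDB reserved word, hence the ExpressionAttributeNames alias.
    names = {'#status_key': 'status'}
    values = {':status_value': status}
    if failure_message and status in FAILURE_STATUSES:
        update_expression_tokens.append('#failure_reason_key = :failure_reason_value')
        names['#failure_reason_key'] = 'failure_reason'
        values[':failure_reason_value'] = failure_message
    return 'SET ' + ', '.join(update_expression_tokens), names, values
```

The three returned pieces map directly onto `update_item`'s `UpdateExpression`, `ExpressionAttributeNames`, and `ExpressionAttributeValues` arguments; the DAO additionally guards the write with `ConditionExpression=Attr('status').ne(status)` so repeated updates to the same status are rejected.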
+ import logging from ideasdk.utils import Utils from ideadatamodel import exceptions, Snapshot, SnapshotStatus, ListSnapshotsRequest, ListSnapshotsResult, SocaPaginator diff --git a/source/idea/idea-sdk/src/ideasdk/analytics/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/__init__.py similarity index 99% rename from source/idea/idea-sdk/src/ideasdk/analytics/__init__.py rename to source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/__init__.py index 4a799b6..6d8d18a 100644 --- a/source/idea/idea-sdk/src/ideasdk/analytics/__init__.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/__init__.py @@ -8,4 +8,3 @@ # or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. - diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_observability_helper.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_observability_helper.py new file mode 100644 index 0000000..4c62cc5 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_observability_helper.py @@ -0,0 +1,94 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +from enum import Enum +from logging import Logger +from typing import Optional + + +class ApplyResourceStatus(Enum): + APPLIED = "applied" + APPLIED_SOFT_DEP = "applied_soft_dep" + SKIPPED_SOFT_DEP = "skipped_soft_dep" + SKIPPED = "skipped" + FAILED_APPLY = "failed_apply" + FAILED_ROLLBACK = "failed_rollback" + ROLLBACKED = "rollbacked" + + +class ApplySnapshotObservabilityHelper: + """ + Helper class for formatting the ApplySnapshot logs. + TODO: Log messages in JSON format for easy querying. + """ + def __init__(self, logger: Logger): + self.logger = logger + + def info(self, table_name: str, resource_id: str, status: Optional[ApplyResourceStatus] = None, reason: Optional[str] = None) -> None: + """ + Log message at INFO level + :param table_name: Name of the table to merge + :param resource_id: ID for the resource to merge + :param status: Status of the merge operation + :param reason: Reason for the status + :return: None + """ + self.logger.info(self.message(table_name, resource_id, status, reason)) + + def warning(self, table_name: str, resource_id: str, status: Optional[ApplyResourceStatus] = None, reason: Optional[str] = None) -> None: + """ + Log message at WARN level + :param table_name: Name of the table to merge + :param resource_id: ID for the resource to merge + :param status: Status of the merge operation + :param reason: Reason for the status + :return: None + """ + self.logger.warning(self.message(table_name, resource_id, status, reason)) + + def error(self, table_name: str, resource_id: str, status: Optional[ApplyResourceStatus] = None, reason: Optional[str] = None) -> None: + """ + Log message at ERROR level + :param table_name: Name of the table to merge + :param resource_id: ID for the resource to merge + :param status: Status of the merge operation + :param reason: Reason for the failure + :return: None + """ + 
self.logger.error(self.message(table_name, resource_id, status, reason)) + + def debug(self, table_name: str, resource_id: str, status: Optional[ApplyResourceStatus] = None, reason: Optional[str] = None) -> None: + """ + Log message at DEBUG level + :param table_name: Name of the table to merge + :param resource_id: ID for the resource to merge + :param status: Status of the merge operation + :param reason: Reason for the status + :return: None + """ + self.logger.debug(self.message(table_name, resource_id, status, reason)) + + @staticmethod + def message(table_name: str, resource_id: str, status: Optional[ApplyResourceStatus] = None, reason: Optional[str] = None) -> str: + """ + Construct the log message + :param table_name: Name of the table to merge + :param resource_id: ID for the resource to merge + :param status: Status of the merge operation + :param reason: Reason for the status + :return: the constructed message + """ + message = f"{table_name}/{resource_id}" + if status: + message = f"{message} {status.name}" + if reason: + message = f"{message} because: {reason}" + return message diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_temp_tables_helper.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_temp_tables_helper.py new file mode 100644 index 0000000..7506ab3 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_temp_tables_helper.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. 
This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideasdk.context import SocaContext + +from typing import List, Optional +import botocore.exceptions + + +class ApplySnapshotTempTablesHelper: + def __init__(self, context: SocaContext): + self.context = context + self.logger = context.logger('apply-snapshot-temp-table-helper') + + def initiate_import_table(self, s3_bucket_name: str, s3_key_prefix: str, table_name: str, partition_key: str, sort_key: Optional[str]): + """Initiate dynamodb.import_table operation. + + Args: + s3_bucket_name (str): Bucket name where the snapshot is present + s3_key_prefix (str): Path at which data for the particular table is present. + table_name (str): The name of the table that will be created with the data from the snapshot + partition_key (str): partition_key that will be used for the DynamoDB table created from the snapshot + sort_key (Optional[str]): sort_key (if applicable) that will be used for the DynamoDB table created from the snapshot + + Returns: + Dict: Returns response received for the aws().dynamodb().import_table() operation + """ + import_table_request = { + 'S3BucketSource': {
+ 'S3Bucket': s3_bucket_name, + 'S3KeyPrefix': s3_key_prefix, + }, + 'InputFormat': 'DYNAMODB_JSON', + 'InputCompressionType': 'GZIP', + 'TableCreationParameters': { + 'TableName': table_name, + 'AttributeDefinitions': [{'AttributeName': partition_key, 'AttributeType': 'S'}], + 'KeySchema': [{'AttributeName': partition_key, 'KeyType': 'HASH'}], + 'BillingMode': 'PAY_PER_REQUEST', + }, + } + + if sort_key: + import_table_request['TableCreationParameters']['AttributeDefinitions'].append( + {'AttributeName': sort_key, 'AttributeType': 'S'} + ) + import_table_request['TableCreationParameters']['KeySchema'].append( + {'AttributeName': sort_key, 'KeyType': 'RANGE'} + ) + + return self.context.aws_util().dynamodb_import_table(import_table_request) + + def delete_imported_tables(self, table_names: List[str]): + """Initiates table deletion for all tables in the list + + Args: + table_names (List[str]): table_names of DynamoDB tables to be deleted + """ + for table_name in table_names: + try: + self.context.aws_util().dynamodb_delete_table(table_name) + except botocore.exceptions.ClientError as e: + self.logger.error(f"DynamoDB table {table_name} failed deletion with exception {e}") diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_version_control_helper.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_version_control_helper.py new file mode 100644 index 0000000..f3f222c --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshot_version_control_helper.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License.
A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideadatamodel.snapshots.snapshot_model import RESVersion, TableName, TableKeys +from ideaclustermanager.app.snapshots.helpers.apply_snapshots_config import RES_VERSION_IN_TOPOLOGICAL_ORDER, TABLE_TO_TABLE_KEYS_BY_VERSION + +from typing import Dict, List + + +def get_table_keys_by_res_version(table_names: List[TableName], res_version: RESVersion) -> Dict[TableName, TableKeys]: + """Returns the partition_key and optional sort_key for all tables represented by the table_names list. + + It returns the partition_key and sort_key of a table based on the res version of the snapshot (as these keys can be updated for tables between versions). + + The TABLE_TO_TABLE_KEYS_BY_VERSION constant keeps track of these keys for each table by version. 
If it does not have an entry for a table + for a particular version, it uses a fallback mechanism to get the keys corresponding to the most recent res_version that was released before + the res_version of the snapshot. + + Args: + table_names (List[TableName]): List of table_names for which the partition_key and optional sort_key should be returned + res_version (RESVersion): The res version of the Snapshot being applied + + Returns: + Dict[TableName, TableKeys]: returns a dict that includes the partition_key and optional sort_key for the requested tables + """ + snapshot_res_version_index = RES_VERSION_IN_TOPOLOGICAL_ORDER.index(res_version) + + response: Dict[TableName, TableKeys] = {} + + for table_name in table_names: + res_version_index = snapshot_res_version_index + table_key_details = TABLE_TO_TABLE_KEYS_BY_VERSION[table_name] + + while res_version_index >= 0 and table_key_details.get(RES_VERSION_IN_TOPOLOGICAL_ORDER[res_version_index]) is None: + res_version_index -= 1 + + if res_version_index < 0: + error_message = f"Could not fetch partition_key and sort_key for {table_name} in RES_VERSION_IN_TOPOLOGICAL_ORDER dict" + raise RuntimeError(error_message) + response[table_name] = table_key_details.get(RES_VERSION_IN_TOPOLOGICAL_ORDER[res_version_index]) + + return response diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshots_config.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshots_config.py new file mode 100644 index 0000000..9632c5b --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/apply_snapshots_config.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License.
A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideadatamodel.snapshots.snapshot_model import RESVersion, TableName, TableKeys +from ideaclustermanager.app.snapshots.apply_snapshot_data_transformation_from_version import data_transformation_from_2023_11 +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table import ( + merge_table, + users_table_merger, + permission_profiles_table_merger, + projects_table_merger, + filesystems_cluster_settings_table_merger, + software_stacks_table_merger +) +from typing import Dict, Type + +# This array should be updated each release to include the new RES version number +RES_VERSION_IN_TOPOLOGICAL_ORDER = [RESVersion.v_2023_11, RESVersion.v_2024_01] + +TABLE_TO_TABLE_KEYS_BY_VERSION: Dict[TableName, Dict[RESVersion, TableKeys]] = { + TableName.CLUSTER_SETTINGS_TABLE_NAME: { + RESVersion.v_2023_11: TableKeys(partition_key='key') + }, + TableName.USERS_TABLE_NAME: { + RESVersion.v_2023_11: TableKeys(partition_key='username') + }, + TableName.PROJECTS_TABLE_NAME: { + RESVersion.v_2023_11: TableKeys(partition_key='project_id') + }, + TableName.PERMISSION_PROFILES_TABLE_NAME: { + RESVersion.v_2023_11: TableKeys(partition_key='profile_id') + }, + TableName.SOFTWARE_STACKS_TABLE_NAME: { + RESVersion.v_2023_11: TableKeys(partition_key='base_os', sort_key='stack_id') + } +} +""" +- This is a strictly additive list. +- New table addition -> Add an entry for the table, res_version the table is introduced, table's partition_key and sort_key. +- Existing table, keys change -> In the table dict, add an entry for the res_version in which the change is introduced, +table's updated partition_key and sort_key. 
+- Table deletion -> Do not remove the entry from this list. This is a strictly additive list. This is to maintain backward compatibility. +""" + +RES_VERSION_TO_DATA_TRANSFORMATION_CLASS = { + RESVersion.v_2023_11: data_transformation_from_2023_11.TransformationFromVersion2023_11 +} +""" +- An entry must be added to this map when data transformation logic must be added for a version. +- Data transformation class naming convention: TransformationFromVersion +- If a RES version does not have any schema changes, an entry does not need to be created in this map. +""" + +TABLES_IN_MERGE_DEPENDENCY_ORDER = [ + TableName.USERS_TABLE_NAME, + TableName.PERMISSION_PROFILES_TABLE_NAME, + TableName.PROJECTS_TABLE_NAME, + TableName.SOFTWARE_STACKS_TABLE_NAME, + TableName.CLUSTER_SETTINGS_TABLE_NAME, +] + +TABLE_TO_MERGE_LOGIC_CLASS: Dict[TableName, Type[merge_table.MergeTable]] = { + TableName.USERS_TABLE_NAME: users_table_merger.UsersTableMerger, + TableName.PERMISSION_PROFILES_TABLE_NAME: permission_profiles_table_merger.PermissionProfilesTableMerger, + TableName.PROJECTS_TABLE_NAME: projects_table_merger.ProjectsTableMerger, + TableName.CLUSTER_SETTINGS_TABLE_NAME: filesystems_cluster_settings_table_merger.FileSystemsClusterSettingTableMerger, + TableName.SOFTWARE_STACKS_TABLE_NAME: software_stacks_table_merger.SoftwareStacksTableMerger, +} +""" +- An entry must be added to this map for all tables that must be applied from snapshot. +- Merge logic class naming convention: TableMerger +""" diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/db_utils.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/db_utils.py new file mode 100644 index 0000000..ea64d66 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/db_utils.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +# Note: The idea-virtual-desktop-controller module is not available in the cluster manager, so we have to +# duplicate the code for converting between Python dictionary and virtual desktop specific object here. + +from ideadatamodel import ( + VirtualDesktopSoftwareStack, + VirtualDesktopBaseOS, + VirtualDesktopArchitecture, + VirtualDesktopGPU, + VirtualDesktopPermission, + VirtualDesktopPermissionProfile, + Project, + SocaMemory, + SocaMemoryUnit, +) + +from ideasdk.utils.utils import Utils + +from typing import Dict, Optional + +SOFTWARE_STACK_DB_BASE_OS_KEY = "base_os" +SOFTWARE_STACK_DB_STACK_ID_KEY = "stack_id" +SOFTWARE_STACK_DB_NAME_KEY = "name" +SOFTWARE_STACK_DB_DESCRIPTION_KEY = "description" +SOFTWARE_STACK_DB_CREATED_ON_KEY = "created_on" +SOFTWARE_STACK_DB_UPDATED_ON_KEY = "updated_on" +SOFTWARE_STACK_DB_AMI_ID_KEY = "ami_id" +SOFTWARE_STACK_DB_ENABLED_KEY = "enabled" +SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY = "min_storage_value" +SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY = "min_storage_unit" +SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY = "min_ram_value" +SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY = "min_ram_unit" +SOFTWARE_STACK_DB_ARCHITECTURE_KEY = "architecture" +SOFTWARE_STACK_DB_GPU_KEY = "gpu" +SOFTWARE_STACK_DB_PROJECTS_KEY = "projects" +SOFTWARE_STACK_DB_PROJECT_ID_KEY = "project_id" +SOFTWARE_STACK_DB_PROJECT_NAME_KEY = "name" +SOFTWARE_STACK_DB_PROJECT_TITLE_KEY = "title" + +PERMISSION_PROFILE_DB_HASH_KEY = "profile_id" +PERMISSION_PROFILE_DB_TITLE_KEY = "title" 
+PERMISSION_PROFILE_DB_DESCRIPTION_KEY = "description" +PERMISSION_PROFILE_DB_CREATED_ON_KEY = "created_on" +PERMISSION_PROFILE_DB_UPDATED_ON_KEY = "updated_on" + + +def convert_db_dict_to_software_stack_object( + db_entry: dict, +) -> Optional[VirtualDesktopSoftwareStack]: + if Utils.is_empty(db_entry): + return None + + software_stack = VirtualDesktopSoftwareStack( + base_os=VirtualDesktopBaseOS(db_entry.get(SOFTWARE_STACK_DB_BASE_OS_KEY)), + stack_id=db_entry.get(SOFTWARE_STACK_DB_STACK_ID_KEY), + name=db_entry.get(SOFTWARE_STACK_DB_NAME_KEY), + description=db_entry.get(SOFTWARE_STACK_DB_DESCRIPTION_KEY), + created_on=Utils.to_datetime(db_entry.get(SOFTWARE_STACK_DB_CREATED_ON_KEY)), + updated_on=Utils.to_datetime(db_entry.get(SOFTWARE_STACK_DB_UPDATED_ON_KEY)), + ami_id=db_entry.get(SOFTWARE_STACK_DB_AMI_ID_KEY), + enabled=db_entry.get(SOFTWARE_STACK_DB_ENABLED_KEY), + min_storage=SocaMemory( + value=db_entry.get(SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY), + unit=SocaMemoryUnit(db_entry.get(SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY)), + ), + min_ram=SocaMemory( + value=db_entry.get(SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY), + unit=SocaMemoryUnit(db_entry.get(SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY)), + ), + architecture=VirtualDesktopArchitecture( + db_entry.get(SOFTWARE_STACK_DB_ARCHITECTURE_KEY) + ), + gpu=VirtualDesktopGPU(db_entry.get(SOFTWARE_STACK_DB_GPU_KEY)), + projects=[], + ) + + for project_id in db_entry.get(SOFTWARE_STACK_DB_PROJECTS_KEY, []): + software_stack.projects.append(Project(project_id=project_id)) + + return software_stack + + +def convert_software_stack_object_to_db_dict( + software_stack: VirtualDesktopSoftwareStack, +) -> Dict: + if Utils.is_empty(software_stack): + return {} + + db_dict = { + SOFTWARE_STACK_DB_BASE_OS_KEY: software_stack.base_os, + SOFTWARE_STACK_DB_STACK_ID_KEY: software_stack.stack_id, + SOFTWARE_STACK_DB_NAME_KEY: software_stack.name, + SOFTWARE_STACK_DB_DESCRIPTION_KEY: software_stack.description, + 
SOFTWARE_STACK_DB_CREATED_ON_KEY: Utils.to_milliseconds( + software_stack.created_on + ), + SOFTWARE_STACK_DB_UPDATED_ON_KEY: Utils.to_milliseconds( + software_stack.updated_on + ), + SOFTWARE_STACK_DB_AMI_ID_KEY: software_stack.ami_id, + SOFTWARE_STACK_DB_ENABLED_KEY: software_stack.enabled, + SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: str(software_stack.min_storage.value), + SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: software_stack.min_storage.unit, + SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: str(software_stack.min_ram.value), + SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: software_stack.min_ram.unit, + SOFTWARE_STACK_DB_ARCHITECTURE_KEY: software_stack.architecture, + SOFTWARE_STACK_DB_GPU_KEY: software_stack.gpu, + } + + project_ids = [] + if software_stack.projects: + for project in software_stack.projects: + project_ids.append(project.project_id) + + db_dict[SOFTWARE_STACK_DB_PROJECTS_KEY] = project_ids + return db_dict + + +def convert_db_dict_to_permission_profile_object( + db_dict: Dict, permission_types: list[VirtualDesktopPermission] +) -> Optional[VirtualDesktopPermissionProfile]: + permission_profile = VirtualDesktopPermissionProfile( + profile_id=db_dict.get(PERMISSION_PROFILE_DB_HASH_KEY, ""), + title=db_dict.get(PERMISSION_PROFILE_DB_TITLE_KEY, ""), + description=db_dict.get(PERMISSION_PROFILE_DB_DESCRIPTION_KEY, ""), + permissions=[], + created_on=Utils.to_datetime( + db_dict.get(PERMISSION_PROFILE_DB_CREATED_ON_KEY, 0) + ), + updated_on=Utils.to_datetime( + db_dict.get(PERMISSION_PROFILE_DB_UPDATED_ON_KEY, 0) + ), + ) + + for permission_type in permission_types: + permission_profile.permissions.append( + VirtualDesktopPermission( + key=permission_type.key, + name=permission_type.name, + description=permission_type.description, + enabled=db_dict.get(permission_type.key, False), + ) + ) + + return permission_profile + + +def convert_permission_profile_object_to_db_dict( + permission_profile: VirtualDesktopPermissionProfile, + permission_types: 
list[VirtualDesktopPermission], +) -> Dict: + db_dict = { + PERMISSION_PROFILE_DB_HASH_KEY: permission_profile.profile_id, + PERMISSION_PROFILE_DB_TITLE_KEY: permission_profile.title, + PERMISSION_PROFILE_DB_DESCRIPTION_KEY: permission_profile.description, + PERMISSION_PROFILE_DB_CREATED_ON_KEY: Utils.to_milliseconds( + permission_profile.created_on + ), + PERMISSION_PROFILE_DB_UPDATED_ON_KEY: Utils.to_milliseconds( + permission_profile.updated_on + ), + } + + for permission_type in permission_types: + db_dict[permission_type.key] = False + permission_entry = permission_profile.get_permission(permission_type.key) + if Utils.is_not_empty(permission_entry): + db_dict[permission_type.key] = permission_entry.enabled + + return db_dict diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/merged_record_utils.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/merged_record_utils.py new file mode 100644 index 0000000..4c3a3d6 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/helpers/merged_record_utils.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
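The permission-profile converters above flatten each permission type into a top-level boolean attribute on the DynamoDB item, defaulting to `False` unless the profile enables it. A small illustrative sketch of that flattening (the permission keys below are hypothetical, not RES's actual permission set):

```python
from typing import Dict, List


def flatten_permissions(profile_id: str, enabled_keys: List[str], all_permission_keys: List[str]) -> Dict[str, object]:
    # Each known permission type becomes its own boolean attribute on the item,
    # mirroring the loop in convert_permission_profile_object_to_db_dict.
    db_dict: Dict[str, object] = {"profile_id": profile_id}
    for key in all_permission_keys:
        db_dict[key] = key in enabled_keys
    return db_dict


# Hypothetical permission keys for illustration only.
item = flatten_permissions("admin_profile", ["builtin"], ["builtin", "unsupervised_access"])
assert item == {"profile_id": "admin_profile", "builtin": True, "unsupervised_access": False}
```

Storing the flags as top-level attributes lets `convert_db_dict_to_permission_profile_object` rebuild the profile with a plain `db_dict.get(permission_type.key, False)` per permission.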
+ +from enum import Enum +from pydantic import BaseModel +from typing import Dict, Optional + + +class MergedRecordActionType(Enum): + CREATE = "create" + UPDATE = "update" + DELETE = "delete" + + +class MergedRecordDelta(BaseModel): + """ + Class for storing the delta after merging a table record: + original_record: Existing record under the RES environment before applying snapshot + snapshot_record: Record from the snapshot to apply + resolved_record: Record being merged after resolving the conflict between original_record and snapshot_record + action_performed: Action performed for applying snapshot + """ + original_record: Optional[Dict] = None + snapshot_record: Dict + resolved_record: Optional[Dict] = None + action_performed: Optional[MergedRecordActionType] = None diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_constants.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_constants.py index c574b8a..d5a0f3c 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_constants.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshot_constants.py @@ -11,3 +11,7 @@ SNAPSHOT_S3_BUCKET_NAME_REGEX = r'^[a-z0-9]+[\.\-\w]*[a-z0-9]+$' SNAPSHOT_PATH_REGEX = r'^([\w\.\-\!\*\'\(\)]+[\/]*)+$' + +METADATA_FILE_NAME_AND_EXTENSION = "metadata.json" +TABLE_EXPORT_DESCRIPTION_KEY = "table_export_descriptions" +VERSION_KEY = "version" diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshots_service.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshots_service.py index 5989221..04decee 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshots_service.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/snapshots/snapshots_service.py @@ -11,13 +11,15 @@ from ideasdk.context import SocaContext from ideadatamodel.snapshots 
import ( - Snapshot, SnapshotStatus, ListSnapshotsRequest, ListSnapshotsResult + Snapshot, SnapshotStatus, ListSnapshotsRequest, ListSnapshotsResult, ListApplySnapshotRecordsRequest, ListApplySnapshotRecordsResult ) from ideadatamodel import exceptions from ideasdk.context import ArnBuilder from ideasdk.utils import Utils +from ideaclustermanager.app.snapshots.apply_snapshot import ApplySnapshot from ideaclustermanager.app.snapshots import snapshot_constants -from ideaclustermanager.app.snapshots.snapshot_dao import SnapshotDAO +from ideaclustermanager.app.snapshots.db.apply_snapshot_dao import ApplySnapshotDAO +from ideaclustermanager.app.snapshots.db.snapshot_dao import SnapshotDAO import botocore.exceptions import re @@ -70,6 +72,7 @@ class SnapshotsService: 1. Creating Snapshots 2. Listing Snapshots 3. Deleting Snapshots + 4. Applying Snapshots The service is primarily invoked via AuthAPI and SnapshotsAPI """ @@ -81,6 +84,8 @@ def __init__(self, context: SocaContext): self.logger = context.logger('snapshots-service') self.snapshot_dao = SnapshotDAO(context) self.snapshot_dao.initialize() + self.apply_snapshot_dao = ApplySnapshotDAO(context) + self.apply_snapshot_dao.initialize() def create_snapshot(self, snapshot: Snapshot): """ @@ -119,13 +124,13 @@ def create_snapshot(self, snapshot: Snapshot): table_export_descriptions[table_name] = export_response['ExportDescription'] metadata = { - 'table_export_descriptions': table_export_descriptions, + snapshot_constants.TABLE_EXPORT_DESCRIPTION_KEY: table_export_descriptions, 'version': self.context.module_version() } self.context.aws().s3().put_object( Bucket=snapshot.s3_bucket_name, - Key=f'{snapshot.snapshot_path}/metadata.json', + Key=f'{snapshot.snapshot_path}/{snapshot_constants.METADATA_FILE_NAME_AND_EXTENSION}', Body=json.dumps(metadata, default=str) ) self.snapshot_dao.create_snapshot({ @@ -146,10 +151,13 @@ def create_snapshot(self, snapshot: Snapshot): def __update_snapshot_status(self, snapshot: Snapshot): 
try: - metadata_s3_object = self.context.aws().s3().get_object(Bucket=snapshot.s3_bucket_name, Key=f'{snapshot.snapshot_path}/metadata.json') + metadata_s3_object = self.context.aws().s3().get_object( + Bucket=snapshot.s3_bucket_name, + Key=f'{snapshot.snapshot_path}/{snapshot_constants.METADATA_FILE_NAME_AND_EXTENSION}', + ) metadata_file_content = metadata_s3_object['Body'].read().decode('utf-8') metadata = json.loads(metadata_file_content) - table_export_descriptions = metadata['table_export_descriptions'] + table_export_descriptions = metadata[snapshot_constants.TABLE_EXPORT_DESCRIPTION_KEY] is_export_status_updated = False export_completed_tables_count = 0 is_export_failed = False @@ -157,7 +165,9 @@ def __update_snapshot_status(self, snapshot: Snapshot): for table_name in DYNAMODB_TABLES_TO_EXPORT: table_export_description = table_export_descriptions[table_name] if table_export_description['ExportStatus'] == SnapshotStatus.IN_PROGRESS: - describe_export_response = self.context.aws().dynamodb().describe_export(ExportArn=table_export_description['ExportArn']) + describe_export_response = ( + self.context.aws().dynamodb().describe_export(ExportArn=table_export_description['ExportArn']) + ) latest_table_export_description = describe_export_response['ExportDescription'] if latest_table_export_description['ExportStatus'] == SnapshotStatus.COMPLETED: export_completed_tables_count += 1 @@ -178,10 +188,10 @@ def __update_snapshot_status(self, snapshot: Snapshot): break if is_export_status_updated: - metadata['table_export_descriptions'] = table_export_descriptions + metadata[snapshot_constants.TABLE_EXPORT_DESCRIPTION_KEY] = table_export_descriptions self.context.aws().s3().put_object( Bucket=snapshot.s3_bucket_name, - Key=f'{snapshot.snapshot_path}/metadata.json', + Key=f'{snapshot.snapshot_path}/{snapshot_constants.METADATA_FILE_NAME_AND_EXTENSION}', Body=json.dumps(metadata, default=str) ) if is_export_failed: @@ -203,3 +213,14 @@ def list_snapshots(self, request: 
ListSnapshotsRequest) -> ListSnapshotsResult: self.__update_snapshot_status(snapshot) return list_snapshots_response + + def apply_snapshot(self, snapshot: Snapshot): + """ + apply a snapshot + """ + apply_snapshot = ApplySnapshot(snapshot, self.apply_snapshot_dao, self.context) + apply_snapshot.initialize() + + def list_applied_snapshots(self, request: ListApplySnapshotRecordsRequest) -> ListApplySnapshotRecordsResult: + return self.apply_snapshot_dao.list(request) + diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/web_portal.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/web_portal.py index 2e5e510..ceac795 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/app/web_portal.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/app/web_portal.py @@ -241,46 +241,58 @@ async def sso_oauth2_callback_route(self, http_request): claims = self.context.token_service.decode_token(auth_result.access_token) self.logger.info(f'sso token claims: {Utils.to_json(claims)}') - username = Utils.get_value_as_string('username', claims) - self.logger.debug(f'Cognito SSO claims - Username: {username}') + cognito_username = Utils.get_value_as_string('username', claims) + self.logger.debug(f'Cognito SSO claims - Username: {cognito_username}') - if Utils.is_empty(username): - self.logger.exception(f'Error - Unable to read IdP username from claims: {claims}') + if not cognito_username: + self.logger.exception(f'Error - Unable to read cognito username from claims: {claims}') return sanic.response.redirect(f'{self.web_resources_context_path}?sso_auth_status=FAIL', headers=DEFAULT_HTTP_HEADERS) - self.logger.debug(f'SSO auth: Looking up user: {username}') - existing_user = self.context.accounts.user_dao.get_user(username=username) - - if existing_user is None: - self.logger.warning(f'SSO auth: {username} is not previously known. 
User must be synced before first SSO attempt') - # TODO auto-enrollment would go here - self.logger.info(f'SSO auth: Unable to process {username}') - self.logger.info(f'disabling federated user in user pool: {username}') - self.context.user_pool.admin_disable_user(username=username) - self.logger.info(f'deleting federated user in user pool: {username}') - self.context.user_pool.admin_delete_user(username=username) + email = self.context.token_service.get_email_from_token_username(token_username=cognito_username) + if not email: + self.logger.exception(f'Error: No email defined for cognito user {cognito_username}') + + existing_user = self.context.accounts.get_user_by_email(email=email) + if not existing_user: + self.logger.info(f'SSO auth: Unable to process user with email {email}') + self.logger.info(f'Disabling federated user in user pool: {cognito_username}') + self.context.user_pool.admin_disable_user(username=cognito_username) + self.logger.info(f'Deleting federated user in user pool: {cognito_username}') + self.context.user_pool.admin_delete_user(username=cognito_username) try: self.logger.info(f'Deleting state {state}') self.context.accounts.sso_state_dao.delete_sso_state(state) except: self.logger.info(f'Could not delete state {state}') return sanic.response.redirect(f'{self.web_resources_context_path}?sso_auth_status=FAIL&error_msg=UserNotFound', - headers=DEFAULT_HTTP_HEADERS) - - self.logger.debug(f'Updating SSO State for user {username}: {state}') + headers=DEFAULT_HTTP_HEADERS) + + if not existing_user.enabled: + self.logger.error(f'User {existing_user.username} is disabled. 
Login Denied.') + cognito_domain_url = self.context.config().get_string('identity-provider.cognito.domain_url') + _, logout_urls = self.context.accounts.single_sign_on_helper.get_callback_logout_urls() + client_id = claims.get('client_id') + if logout_urls and client_id and cognito_domain_url: + logout_url = logout_urls[-1] + self.context.accounts.sign_out(auth_result.refresh_token, sso_enabled) + return sanic.response.redirect(f'{cognito_domain_url}/logout?client_id={client_id}&logout_uri={logout_url}') + return sanic.response.redirect(f'{self.web_resources_context_path}?sso_auth_status=FAIL&error_msg=UserNotFound', + headers=DEFAULT_HTTP_HEADERS) + + self.logger.debug(f'Updating SSO State for user {cognito_username}: {state}') self.context.accounts.sso_state_dao.update_sso_state({ 'state': state, 'access_token': auth_result.access_token, 'refresh_token': auth_result.refresh_token, 'expires_in': auth_result.expires_in, 'id_token': auth_result.id_token, - 'token_type': auth_result.token_type + 'token_type': auth_result.token_type, }) - if not existing_user["is_active"]: - user = self.context.accounts.user_dao.convert_from_db(existing_user) - self.context.accounts.activate_user(user) + if not existing_user.is_active: + user = self.context.accounts.get_user(existing_user.username) + self.context.accounts.activate_user(existing_user=user) return sanic.response.redirect(f'{self.web_resources_context_path}?sso_auth_status=SUCCESS&sso_auth_code={state}', headers=DEFAULT_HTTP_HEADERS) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/cli_main.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/cli_main.py index 6b33662..eae9bd1 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/cli_main.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/cli_main.py @@ -17,6 +17,7 @@ from ideaclustermanager.cli.groups import groups from ideaclustermanager.cli.ldap_commands import ldap_commands from 
ideaclustermanager.cli.module import app_module_clean_up +from ideaclustermanager.cli.snapshots import snapshots import sys import click @@ -34,6 +35,7 @@ def main(): main.add_command(logs) main.add_command(accounts) main.add_command(groups) +main.add_command(snapshots) main.add_command(ldap_commands) main.add_command(app_module_clean_up) diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/snapshots.py b/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/snapshots.py new file mode 100644 index 0000000..97bed52 --- /dev/null +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager/cli/snapshots.py @@ -0,0 +1,32 @@ +from ideadatamodel import constants, ApplySnapshotRequest, ApplySnapshotResult + +from ideaclustermanager.cli import build_cli_context +import click + +@click.group() +def snapshots(): + """ + snapshot management options + """ + +@snapshots.command(context_settings=constants.CLICK_SETTINGS) +@click.option('--s3_bucket_name', required=True, help='S3 Bucket Name to retrieve the snapshot from') +@click.option('--snapshot_path', required=True, help='Path in the S3 bucket to retrieve the snapshot from') +def apply_snapshot(**kwargs): + """ + apply snapshot + """ + request = { + 'snapshot': { + 's3_bucket_name': kwargs.get('s3_bucket_name'), + 'snapshot_path': kwargs.get('snapshot_path') + } + } + + context = build_cli_context() + result = context.unix_socket_client.invoke_alt( + namespace='Snapshots.ApplySnapshot', + payload=ApplySnapshotRequest(**request), + result_as=ApplySnapshotResult + ) + print(result) \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/src/ideaclustermanager_meta/__init__.py b/source/idea/idea-cluster-manager/src/ideaclustermanager_meta/__init__.py index 206ddb5..27bbf2a 100644 --- a/source/idea/idea-cluster-manager/src/ideaclustermanager_meta/__init__.py +++ b/source/idea/idea-cluster-manager/src/ideaclustermanager_meta/__init__.py @@ -10,4 +10,4 @@ # and limitations under the 
License. __name__ = 'idea-cluster-manager' -__version__ = '2023.11' +__version__ = '2024.01' diff --git a/source/idea/idea-cluster-manager/webapp/.env b/source/idea/idea-cluster-manager/webapp/.env index 81f4d24..024a96d 100644 --- a/source/idea/idea-cluster-manager/webapp/.env +++ b/source/idea/idea-cluster-manager/webapp/.env @@ -1,4 +1,4 @@ REACT_APP_IDEA_HTTP_ENDPOINT="http://localhost:8080" REACT_APP_IDEA_ALB_ENDPOINT="http://localhost:8080" REACT_APP_IDEA_HTTP_API_SUFFIX="/api/v1" -REACT_APP_IDEA_RELEASE_VERSION="2023.11" +REACT_APP_IDEA_RELEASE_VERSION="2024.01" diff --git a/source/idea/idea-cluster-manager/webapp/package.json b/source/idea/idea-cluster-manager/webapp/package.json index 76054ce..d8011a4 100644 --- a/source/idea/idea-cluster-manager/webapp/package.json +++ b/source/idea/idea-cluster-manager/webapp/package.json @@ -1,6 +1,6 @@ { "name": "web-portal", - "version": "2023.11", + "version": "2024.01", "private": true, "dependencies": { "@cloudscape-design/components": "^3.0.82", diff --git a/source/idea/idea-cluster-manager/webapp/src/App.tsx b/source/idea/idea-cluster-manager/webapp/src/App.tsx index 4150c60..998c76d 100644 --- a/source/idea/idea-cluster-manager/webapp/src/App.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/App.tsx @@ -50,7 +50,7 @@ import { IdeaAppNavigationProps, withRouter } from "./navigation/navigation-util import { Routes, Route, Navigate } from "react-router-dom"; import IdeaLogTail from "./pages/home/log-tail"; import Utils from "./common/utils"; -import Snapshots from "./pages/snapshots/snapshots"; +import SnapshotManagement from "./pages/snapshots/snapshot-management" export interface IdeaWebPortalAppProps extends IdeaAppNavigationProps {} @@ -562,7 +562,7 @@ class IdeaWebPortalApp extends Component */} - { - queryOpenSearch(req: OpenSearchQueryRequest): Promise { - return this.apiInvoker.invoke_alt("Analytics.OpenSearchQuery", req); - } -} - -export default AnalyticsClient; diff --git 
a/source/idea/idea-cluster-manager/webapp/src/client/clients.ts b/source/idea/idea-cluster-manager/webapp/src/client/clients.ts index faa5f78..22d39ee 100644 --- a/source/idea/idea-cluster-manager/webapp/src/client/clients.ts +++ b/source/idea/idea-cluster-manager/webapp/src/client/clients.ts @@ -20,7 +20,6 @@ import FileBrowserClient from "./file-browser-client"; import VirtualDesktopClient from "./virtual-desktop-client"; import VirtualDesktopAdminClient from "./virtual-desktop-admin-client"; import ClusterSettingsClient from "./cluster-settings-client"; -import AnalyticsClient from "./analytics-client"; import ProjectsClient from "./projects-client"; import EmailTemplatesClient from "./email-templates-client"; import Utils from "../common/utils"; @@ -50,7 +49,6 @@ class IdeaClients { private readonly virtualDesktopUtilsClient: VirtualDesktopUtilsClient; private readonly virtualDesktopDCVClient: VirtualDesktopDCVClient; private readonly clusterSettingsClient: ClusterSettingsClient; - private readonly analyticsClient: AnalyticsClient; private readonly projectsClient: ProjectsClient; private readonly filesystemClient: FileSystemClient; private readonly emailTemplatesClient: EmailTemplatesClient; @@ -159,15 +157,6 @@ class IdeaClients { }); this.clients.push(this.clusterSettingsClient); - this.analyticsClient = new AnalyticsClient({ - name: "analytics-client", - baseUrl: props.baseUrl, - authContext: props.authContext, - apiContextPath: Utils.getApiContextPath(Constants.MODULE_CLUSTER_MANAGER), - serviceWorkerRegistration: props.serviceWorkerRegistration, - }); - this.clients.push(this.analyticsClient); - this.projectsClient = new ProjectsClient({ name: "projects-client", baseUrl: props.baseUrl, @@ -244,10 +233,6 @@ class IdeaClients { return this.clusterSettingsClient; } - analytics(): AnalyticsClient { - return this.analyticsClient; - } - projects(): ProjectsClient { return this.projectsClient; } diff --git 
a/source/idea/idea-cluster-manager/webapp/src/client/data-model.ts b/source/idea/idea-cluster-manager/webapp/src/client/data-model.ts index 9023df2..def9ae7 100644 --- a/source/idea/idea-cluster-manager/webapp/src/client/data-model.ts +++ b/source/idea/idea-cluster-manager/webapp/src/client/data-model.ts @@ -40,7 +40,8 @@ export type SocaUserInputParamType = | "datepicker" | "radio-group" | "file-upload" - | "tiles"; + | "tiles" + | "container"; export type VirtualDesktopBaseOS = "amazonlinux2" | "centos7" | "rhel7" | "rhel8" | "rhel9" | "windows"; export type SocaMemoryUnit = "bytes" | "kib" | "mib" | "gib" | "tib" | "kb" | "mb" | "gb" | "tb"; export type VirtualDesktopArchitecture = "x86_64" | "arm64"; @@ -62,6 +63,7 @@ export type SocaComputeNodeSharing = "default-excl" | "default-exlchost" | "defa export type SocaJobPlacementArrangement = "free" | "pack" | "scatter" | "vscatter"; export type SocaJobPlacementSharing = "excl" | "shared" | "exclhost" | "vscatter"; export type SnapshotStatus = "IN_PROGRESS" | "COMPLETED" | "FAILED"; +export type ApplySnapshotStatus = "IN_PROGRESS" | "COMPLETED" | "FAILED" | "ROLLBACK_IN_PROGRESS" | "ROLLBACK_COMPLETE" | "ROLLBACK_FAILED"; export interface FileList { cwd?: string; @@ -115,6 +117,7 @@ export interface Snapshot { snapshot_path?: string; status?: SnapshotStatus; created_on?: string; + failure_reason?: string } export interface ListSnapshotsRequest { @@ -133,6 +136,30 @@ export interface ListSnapshotsResult { filters?: SocaFilter[]; } +export interface ApplySnapshot { + s3_bucket_name?: string; + snapshot_path?: string; + status?: ApplySnapshotStatus; + created_on?: string; + failure_reason?: string +} + +export interface ListApplySnapshotRecordsRequest { + paginator?: SocaPaginator; + sort_by?: SocaSortBy; + date_range?: SocaDateRange; + listing?: (SocaBaseModel | unknown)[]; + filters?: SocaFilter[]; +} + +export interface ListApplySnapshotRecordsResult { + paginator?: SocaPaginator; + sort_by?: SocaSortBy; + 
date_range?: SocaDateRange; + listing?: Snapshot[]; + filters?: SocaFilter[]; +} + export interface CreateFileResult {} export interface DisableGroupResult {} export interface DeleteHpcLicenseResourceRequest { @@ -199,6 +226,8 @@ export interface SocaUserInputParamMetadata { custom?: { [k: string]: unknown; }; + container_items?: SocaUserInputParamMetadata[]; + custom_error_message?: string; } export interface SocaUserInputValidate { eq?: unknown; @@ -338,6 +367,7 @@ export interface Project { description?: string; enabled?: boolean; ldap_groups?: string[]; + users?: string[]; enable_budgets?: boolean; budget?: AwsProjectBudget; tags?: SocaKeyValue[]; @@ -645,7 +675,7 @@ export interface ListJobsRequest { export interface InitiateAuthRequest { client_id?: string; auth_flow?: string; - username?: string; + cognito_username?: string; password?: string; refresh_token?: string; authorization_code?: string; @@ -873,6 +903,7 @@ export interface VirtualDesktopSession { is_launched_by_admin?: boolean; locked?: boolean; failure_reason?: string; + tags?: Record[] } export interface VirtualDesktopServer { server_id?: string; @@ -1012,6 +1043,8 @@ export interface InitiateAuthResult { [k: string]: unknown; }; auth?: AuthResult; + db_username?: string; + role?: string; } export interface AuthenticateUserResult { status?: boolean; @@ -1238,6 +1271,15 @@ export interface CreateSnapshotRequest { export interface CreateSnapshotResult { result?: string; } + +export interface ApplySnapshotRequest { + snapshot?: Snapshot +} + +export interface ApplySnapshotResult { + message?: string +} + export interface SendNotificationRequest { notification?: Notification; } @@ -1568,7 +1610,6 @@ export interface CreateQueueProfileResult { queue_profile?: HpcQueueProfile; validation_errors?: JobValidationResult; } -export interface ReIndexUserSessionsRequest {} export interface CheckHpcLicenseResourceAvailabilityResult { available_count?: number; } @@ -1579,11 +1620,6 @@ export interface 
ListSessionsRequest { listing?: (SocaBaseModel | unknown)[]; filters?: SocaFilter[]; } -export interface OpenSearchQueryResult { - data?: { - [k: string]: unknown; - }; -} export interface UpdatePermissionProfileRequest { profile?: VirtualDesktopPermissionProfile; } @@ -1724,7 +1760,6 @@ export interface SocaUserInputTag { export interface EnableGroupRequest { group_name?: string; } -export interface ReIndexUserSessionsResponse {} export interface GetParamDefaultRequest { module?: string; param?: string; @@ -1810,11 +1845,6 @@ export interface DeleteSessionResponse { failed?: VirtualDesktopSession[]; success?: VirtualDesktopSession[]; } -export interface OpenSearchQueryRequest { - data?: { - [k: string]: unknown; - }; -} export interface UpdateHpcLicenseResourceRequest { license_resource?: HpcLicenseResource; dry_run?: boolean; @@ -1835,7 +1865,6 @@ export interface ModifyUserRequest { email_verified?: boolean; } export interface EnableGroupResult {} -export interface ReIndexSoftwareStacksRequest {} export interface GetParamDefaultResult { default?: unknown; } @@ -2006,7 +2035,6 @@ export interface TailFileRequest { line_count?: number; next_token?: string; } -export interface ReIndexSoftwareStacksResponse {} export interface BatchCreateSessionResponse { failed?: VirtualDesktopSession[]; success?: VirtualDesktopSession[]; @@ -2111,6 +2139,7 @@ export interface ConfigureSSORequest { export interface ConfigureSSOResponse {} export interface GetSoftwareStackInfoRequest { stack_id?: string; + base_os?: string } export interface AddAdminUserResult { user?: User; diff --git a/source/idea/idea-cluster-manager/webapp/src/client/snapshots-client.ts b/source/idea/idea-cluster-manager/webapp/src/client/snapshots-client.ts index 94ea542..e7d496d 100644 --- a/source/idea/idea-cluster-manager/webapp/src/client/snapshots-client.ts +++ b/source/idea/idea-cluster-manager/webapp/src/client/snapshots-client.ts @@ -12,12 +12,16 @@ */ import { + ApplySnapshotRequest, + 
ApplySnapshotResult, GetModuleInfoRequest, GetModuleInfoResult, CreateSnapshotRequest, CreateSnapshotResult, ListSnapshotsRequest, - ListSnapshotsResult + ListSnapshotsResult, + ListApplySnapshotRecordsRequest, + ListApplySnapshotRecordsResult } from "./data-model"; import IdeaBaseClient, { IdeaBaseClientProps } from "./base-client"; @@ -34,6 +38,12 @@ class SnapshotsClient extends IdeaBaseClient { listSnapshots(req?: ListSnapshotsRequest): Promise { return this.apiInvoker.invoke_alt("Snapshots.ListSnapshots", req); } + applySnapshot(req: ApplySnapshotRequest): Promise { + return this.apiInvoker.invoke_alt("Snapshots.ApplySnapshot", req) + } + listAppliedSnapshots(req?: ListApplySnapshotRecordsRequest): Promise{ + return this.apiInvoker.invoke_alt("Snapshots.ListAppliedSnapshots", req) + } } export default SnapshotsClient; diff --git a/source/idea/idea-cluster-manager/webapp/src/common/authentication-context.ts b/source/idea/idea-cluster-manager/webapp/src/common/authentication-context.ts index 291c4e8..9c2b901 100644 --- a/source/idea/idea-cluster-manager/webapp/src/common/authentication-context.ts +++ b/source/idea/idea-cluster-manager/webapp/src/common/authentication-context.ts @@ -32,6 +32,8 @@ const KEY_REFRESH_TOKEN = "refresh-token"; const KEY_ACCESS_TOKEN = "access-token"; const KEY_ID_TOKEN = "id-token"; const KEY_SSO_AUTH = "sso-auth"; +const KEY_DB_USERNAME = "db-username"; +const KEY_ROLE = "role"; const HEADER_CONTENT_TYPE_JSON = "application/json;charset=UTF-8"; const NETWORK_TIMEOUT = 30000; @@ -72,7 +74,8 @@ export class IdeaAuthenticationContext { private accessToken: string | null; private idToken: string | null; private claimsProvider: JwtTokenClaimsProvider | null; - + private dbUsername: string | null; + private role: string | null; private logger: AppLogger; private authContextInitialized: boolean; @@ -96,7 +99,8 @@ export class IdeaAuthenticationContext { this.accessToken = null; this.idToken = null; this.claimsProvider = null; - + 
this.dbUsername = null; + this.role = null; this.props = props; this.isServiceworkerInitialized = false; @@ -122,13 +126,15 @@ export class IdeaAuthenticationContext { this.accessToken = this.localStorage.getItem(KEY_ACCESS_TOKEN); this.idToken = this.localStorage.getItem(KEY_ID_TOKEN); this.refreshToken = this.localStorage.getItem(KEY_REFRESH_TOKEN); + this.dbUsername = this.localStorage.getItem(KEY_DB_USERNAME); + this.role = this.localStorage.getItem(KEY_ROLE); let ssoAuth = this.localStorage.getItem(KEY_SSO_AUTH); if (ssoAuth != null) { this.ssoAuth = Utils.asBoolean(ssoAuth); } - if (this.accessToken != null && this.idToken != null) { - this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken, this.idToken); + if (this.accessToken != null && this.idToken != null && this.dbUsername != null && this.role != null) { + this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken, this.idToken, this.dbUsername, this.role); } } @@ -174,22 +180,29 @@ export class IdeaAuthenticationContext { * @param ssoAuth * @private */ - private saveAuthResult(authResult: any, ssoAuth: boolean) { - if (authResult.refresh_token) { - this.refreshToken = authResult.refresh_token; + private saveAuthResult(initiateAuthResult: any, ssoAuth: boolean) { + if (initiateAuthResult.auth.refresh_token) { + this.refreshToken = initiateAuthResult.auth.refresh_token; } - this.accessToken = authResult.access_token; - this.idToken = authResult.id_token; - this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken!, this.idToken!); + this.accessToken = initiateAuthResult.auth.access_token; + this.idToken = initiateAuthResult.auth.id_token; + this.dbUsername = initiateAuthResult.db_username; + this.role = initiateAuthResult.role; + this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken!, + this.idToken!, + this.dbUsername!, + this.role!); this.ssoAuth = ssoAuth; if (this.localStorage != null) { - if (authResult.refresh_token) { - 
this.localStorage.setItem(KEY_REFRESH_TOKEN, authResult.refresh_token!); + if (initiateAuthResult.auth.refresh_token) { + this.localStorage.setItem(KEY_REFRESH_TOKEN, initiateAuthResult.auth.refresh_token!); } this.localStorage.setItem(KEY_SSO_AUTH, ssoAuth ? "true" : "false"); - this.localStorage.setItem(KEY_ACCESS_TOKEN, authResult.access_token!); - this.localStorage.setItem(KEY_ID_TOKEN, authResult.id_token!); + this.localStorage.setItem(KEY_ACCESS_TOKEN, initiateAuthResult.auth.access_token!); + this.localStorage.setItem(KEY_ID_TOKEN, initiateAuthResult.auth.id_token!); + this.localStorage.setItem(KEY_DB_USERNAME, initiateAuthResult.db_username!); + this.localStorage.setItem(KEY_ROLE, initiateAuthResult.role!); } } @@ -213,8 +226,9 @@ export class IdeaAuthenticationContext { if (this.localStorage != null) { // this is primarily to allow force token renewal in local storage mode for testing, by deleting the access token from local storage return this.renewAccessToken().then(() => { - if (this.accessToken != null && this.idToken != null) { - this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken, this.idToken); + if (this.accessToken != null && this.idToken != null && this.dbUsername != null && this.role != null) { + this.claimsProvider = new JwtTokenClaimsProvider(this.accessToken, + this.idToken, this.dbUsername, this.role); return true; } else { return false; @@ -256,11 +270,15 @@ export class IdeaAuthenticationContext { this.accessToken = null; this.idToken = null; this.claimsProvider = null; + this.dbUsername = null; + this.role = null; if (this.localStorage != null) { this.localStorage.removeItem(KEY_ACCESS_TOKEN); this.localStorage.removeItem(KEY_REFRESH_TOKEN); this.localStorage.removeItem(KEY_SSO_AUTH); this.localStorage.removeItem(KEY_ID_TOKEN); + this.localStorage.removeItem(KEY_DB_USERNAME); + this.localStorage.removeItem(KEY_ROLE); } return true; }); @@ -368,20 +386,20 @@ export class IdeaAuthenticationContext { this.logger.info("renewing 
access token ..."); - let username; + let cognito_username; if (this.claimsProvider == null) { if (this.accessToken != null) { let claims = JwtTokenUtils.parseJwtToken(this.accessToken); - username = claims.username; + cognito_username = claims.username; } else if (this.idToken != null) { let claims = JwtTokenUtils.parseJwtToken(this.idToken); - username = claims["cognito:username"]; + cognito_username = claims["cognito:username"]; } else { console.info("✗ failed to renew token."); return Promise.resolve(false); } } else { - username = this.claimsProvider.getUsername(); + cognito_username = this.claimsProvider.getCognitoUsername(); } if (this.renewalInProgress != null) { @@ -400,7 +418,7 @@ export class IdeaAuthenticationContext { }, payload: { auth_flow: authFlow, - username: username, + cognito_username: cognito_username, refresh_token: this.refreshToken, }, }; @@ -416,7 +434,7 @@ export class IdeaAuthenticationContext { .then((result) => { if (result.success && result.payload.auth) { this.logger.info("✓ access token renewed successfully"); - this.saveAuthResult(result.payload.auth, this.ssoAuth); + this.saveAuthResult(result.payload, this.ssoAuth); return true; } else { if (result.error_code === NETWORK_TIMEOUT || result.error_code === NETWORK_ERROR || result.error_code === SERVER_ERROR) { @@ -486,8 +504,7 @@ export class IdeaAuthenticationContext { // all subsequent API invocations will be attached with the Authorization header. 
this.logger.debug("✓ initiate auth successful"); const isSsoAuth = request.payload.auth_flow === "SSO_AUTH"; - this.saveAuthResult(result.payload.auth, isSsoAuth); - + this.saveAuthResult(result.payload, isSsoAuth); return { success: true, payload: {}, diff --git a/source/idea/idea-cluster-manager/webapp/src/common/constants.ts b/source/idea/idea-cluster-manager/webapp/src/common/constants.ts index 776008c..405238f 100644 --- a/source/idea/idea-cluster-manager/webapp/src/common/constants.ts +++ b/source/idea/idea-cluster-manager/webapp/src/common/constants.ts @@ -12,6 +12,8 @@ */ export const Constants = { + ADMIN_ROLE: 'admin', + USER_ROLE: 'user', MODULE_VIRTUAL_DESKTOP_CONTROLLER: "virtual-desktop-controller", MODULE_SCHEDULER: "scheduler", MODULE_DIRECTORY_SERVICE: "directoryservice", @@ -19,7 +21,6 @@ export const Constants = { MODULE_SHARED_STORAGE: "shared-storage", MODULE_METRICS: "metrics", MODULE_BASTION_HOST: "bastion-host", - MODULE_ANALYTICS: "analytics", MODULE_CLUSTER: "cluster", MODULE_CLUSTER_MANAGER: "cluster-manager", MODULE_GLOBAL_SETTINGS: "global-settings", diff --git a/source/idea/idea-cluster-manager/webapp/src/common/token-utils.ts b/source/idea/idea-cluster-manager/webapp/src/common/token-utils.ts index 8cf7ab2..f2e9f04 100644 --- a/source/idea/idea-cluster-manager/webapp/src/common/token-utils.ts +++ b/source/idea/idea-cluster-manager/webapp/src/common/token-utils.ts @@ -30,14 +30,22 @@ export class JwtTokenUtils { export class JwtTokenClaimsProvider { private readonly accessToken: any; private readonly idToken: any; + private readonly dbUsername: any; + private readonly role: any; - constructor(accessToken: string, idToken: string) { + constructor(accessToken: string, idToken: string, dbUsername: string, role: string) { this.accessToken = JwtTokenUtils.parseJwtToken(accessToken); this.idToken = JwtTokenUtils.parseJwtToken(idToken); + this.dbUsername = dbUsername + this.role = role } - getUsername(): string { - return 
this.accessToken.username; + getRole(): string { + return this.role; + } + + getDbUsername(): string { + return this.dbUsername; } getClientId(): string { @@ -45,15 +53,7 @@ export class JwtTokenClaimsProvider { } getCognitoUsername(): string { - return this.idToken["cognito:username"]; - } - - getGroups(): string[] { - let groups = this.accessToken["cognito:groups"]; - if (groups == null) { - return []; - } - return groups; + return this.accessToken.username; } getIssuedAt(): number { @@ -102,8 +102,9 @@ export class JwtTokenClaimsProvider { getClaims(): JwtTokenClaims { return { - username: this.getUsername(), - groups: this.getGroups(), + cognito_username: this.getCognitoUsername(), + db_username: this.getDbUsername(), + role: this.getRole(), issued_at: this.getIssuedAt(), expires_at: this.getExpiresAt(), auth_time: this.getAuthTime(), @@ -117,8 +118,9 @@ export class JwtTokenClaimsProvider { } export interface JwtTokenClaims { - username: string; - groups: string[]; + cognito_username: string; + db_username: string; + role: string; issued_at: number; expires_at: number; auth_time: number; diff --git a/source/idea/idea-cluster-manager/webapp/src/common/utils.ts b/source/idea/idea-cluster-manager/webapp/src/common/utils.ts index 1cdbdd4..891e432 100644 --- a/source/idea/idea-cluster-manager/webapp/src/common/utils.ts +++ b/source/idea/idea-cluster-manager/webapp/src/common/utils.ts @@ -138,6 +138,25 @@ class Utils { return def; } + static isListOfRecordStrings(value?: any): boolean { + let result = true; + if (value == null) { + result = false + } else { + for (let item of value) { + if (item == null || typeof item != "object") { + result = false; + } + Object.keys(item).map((key) => { + if (typeof item[key] != "string") { + result = false; + } + }); + } + } + return result; + } + static isEmpty(value?: any): boolean { if (value == null) { return true; @@ -836,13 +855,6 @@ class Utils { static getDefaultModuleSettings() { return [ - { - deployment_priority: 3, - 
module_id: "analytics", - name: "analytics", - title: "Analytics", - type: "stack", - }, { deployment_priority: 7, module_id: "bastion-host", diff --git a/source/idea/idea-cluster-manager/webapp/src/components/form-field/form-field.tsx b/source/idea/idea-cluster-manager/webapp/src/components/form-field/form-field.tsx index 0e3b626..5869858 100644 --- a/source/idea/idea-cluster-manager/webapp/src/components/form-field/form-field.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/components/form-field/form-field.tsx @@ -97,6 +97,8 @@ export interface IdeaFormFieldState { memoryVal(): string fileVal(): File[]; + + listOfRecordsVal(): Record[] } export interface IdeaFormFieldStateChangeEvent { @@ -261,6 +263,12 @@ class IdeaFormField extends Component { } return []; }, + listOfRecordsVal(): Record[] { + if (this.value != null) { + return this.value + } + return [] + } }; } @@ -315,6 +323,13 @@ class IdeaFormField extends Component { return this.state.value; } + getListOfStringRecords(): Record[] { + if (Utils.isListOfRecordStrings(this.state.listOfRecordsVal())) { + return this.state.listOfRecordsVal() + } + return [] + } + getTypedValue(): any { if (this.isMultiple()) { const dataType = this.getDataType(); @@ -324,6 +339,8 @@ class IdeaFormField extends Component { return this.state.numberArrayVal(); case "bool": return this.state.booleanArrayVal(); + case "record": + return this.state.listOfRecordsVal() default: return this.state.stringArrayVal(); } @@ -372,7 +389,6 @@ class IdeaFormField extends Component { reader.readAsDataURL(file); }) } - getErrorCode(): string | null { if (Utils.isEmpty(this.state.errorCode)) { @@ -716,6 +732,18 @@ class IdeaFormField extends Component { this.setState({ disabled: should_disable }, this.setStateCallback); } + validate_empty_record(record: Record, container_items: SocaUserInputParamMetadata[]): boolean { + if (Utils.isEmpty(record)) { + return true; + } + for (let column of container_items) { + if (column.name && 
Utils.isEmpty(record[column.name])) { + return true; + } + } + return false; + } + validate(): string { const validate = this.props.param.validate; if (validate == null) { @@ -752,6 +780,18 @@ class IdeaFormField extends Component { } if (this.isMultiple()) { + if (this.props.param.param_type === "container") { + if (this.getListOfStringRecords().length === 0) { + return "OK"; + } else { + for (let record of this.getListOfStringRecords()) { + if (this.props.param.container_items && this.validate_empty_record(record, this.props.param.container_items)) { + return "CUSTOM_FAILED" + } + } + return "OK"; + } + } if (this.state.stringArrayVal().length === 0) { return "REQUIRED"; } else { @@ -813,9 +853,9 @@ class IdeaFormField extends Component { case "REGEX": errorMessage = this.props.param.validate?.message ?? `${displayTitle} must satisfy regex: ${this.props.param.validate?.regex}`; break; - /*case 'CUSTOM_FAILED': - errorMessage = this.props.param.validate?.custom?.error_message - break*/ + case 'CUSTOM_FAILED': + errorMessage = this.props.param.custom_error_message + break default: errorMessage = `${displayTitle} validation failed.`; } @@ -1744,6 +1784,114 @@ class IdeaFormField extends Component { ) } + onContainerArrayStateChange(event: IdeaFormFieldStateChangeEvent, index: number) { + const values: Record[] = this.getListOfStringRecords(); + if (!event.param.name || !values[index]) { + this.setState({ + errorMessage: "Unable to map container state change." 
+ }) + } else { + values[index][event.param.name] = event.value + this.setState( + { + value: values, + }, + () => { + if (this.triggerValidate()) { + this.setState({}, this.setStateCallback); + } + } + ); + } + } + + buildContainerArray(props: FormFieldProps): React.ReactNode { + return this.buildFormField( + + {this.getListOfStringRecords().length === 0 && ( + + )} + { + this.getListOfStringRecords().length > 0 && this.getListOfStringRecords().map((value, index) => { + const numOfColumns = this.props.param.container_items?.length ? this.props.param.container_items.length : 0 + return ( + + { + (() => { + let container: JSX.Element[] = []; + this.props.param.container_items?.map((form) => ( + container.push( { + this.onContainerArrayStateChange(event, index) + }} + onFetchOptions={this.props.onFetchOptions} + stretch={props.stretch} + />) + )) + return container; + })() + } + + + ); + }) + } + { + this.getListOfStringRecords().length > 0 && ( + + + + + ) + } + , + props + ) + } + getRenderType(): string { const type = this.getNativeType(); const param_type = this.props.param.param_type; @@ -1760,6 +1908,8 @@ class IdeaFormField extends Component { return "multi-select"; } else if (param_type === "select_or_text") { return "auto-suggest"; + } else if (param_type === "container"){ + return "parent_parameter_array" } else { if (multiline) { return "textarea-array"; @@ -1851,6 +2001,8 @@ class IdeaFormField extends Component { formFields.push(this.buildFileUpload({ stretch: stretch })); } else if (type === "tiles") { formFields.push(this.buildTilesGroup({ stretch: stretch })); + } else if (type === "parent_parameter_array") { + formFields.push(this.buildContainerArray({ stretch: stretch })); } else { formFields.push(this.buildInput({ stretch: stretch })); } diff --git a/source/idea/idea-cluster-manager/webapp/src/docs/apply-snapshots.md b/source/idea/idea-cluster-manager/webapp/src/docs/apply-snapshots.md new file mode 100644 index 0000000..57a271e --- /dev/null +++ 
b/source/idea/idea-cluster-manager/webapp/src/docs/apply-snapshots.md @@ -0,0 +1 @@ +Apply Snapshots \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/webapp/src/docs/snapshot-management.md b/source/idea/idea-cluster-manager/webapp/src/docs/snapshot-management.md new file mode 100644 index 0000000..208ebfe --- /dev/null +++ b/source/idea/idea-cluster-manager/webapp/src/docs/snapshot-management.md @@ -0,0 +1,17 @@ +## Snapshot Management + +### Snapshots + +Create snapshots of the current environment and view their status. + +#### Create a snapshot + +Click the **Create snapshot** button to create a snapshot of the current environment. + +### Applied Snapshots + +Apply a snapshot from a previous environment to the current environment and view its status. + +#### Apply Snapshot + +Click the **Apply Snapshot** button to apply a snapshot stored in an S3 bucket. \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/webapp/src/navigation/side-nav-items.tsx b/source/idea/idea-cluster-manager/webapp/src/navigation/side-nav-items.tsx index b44ceab..c552f5a 100644 --- a/source/idea/idea-cluster-manager/webapp/src/navigation/side-nav-items.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/navigation/side-nav-items.tsx @@ -34,7 +34,7 @@ export const IdeaSideNavItems = (context: AppContext): SideNavigationProps.Item[ }; result.push(userNav); - if (context.getClusterSettingsService().isVirtualDesktopDeployed() && context.auth().hasModuleAccess(Constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER)) { + if (context.getClusterSettingsService().isVirtualDesktopDeployed()) { userNav.items.push({ type: "link", text: "My Virtual Desktops", @@ -45,30 +45,29 @@ text: "Shared Desktops", href: "#/home/shared-desktops", }); - } - - if (context.auth().hasModuleAccess(Constants.MODULE_CLUSTER_MANAGER)) { userNav.items.push({ type: "link", text: "File Browser", href:
"#/home/file-browser", }); - if (context.getClusterSettingsService().isBastionHostDeployed()) { - userNav.items.push({ - type: "link", - text: "SSH Access Instructions", - href: "#/home/ssh-access", - }); - } } + if (context.getClusterSettingsService().isBastionHostDeployed()) { + userNav.items.push({ + type: "link", + text: "SSH Access Instructions", + href: "#/home/ssh-access", + }); + } + + // start admin section adminNavItems.push({ type: "divider", }); - if (context.getClusterSettingsService().isVirtualDesktopDeployed() && context.auth().isModuleAdmin(Constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER)) { + if (context.getClusterSettingsService().isVirtualDesktopDeployed() && context.auth().isAdmin()) { adminNavItems.push({ type: "section", text: "Session Management", @@ -108,7 +107,7 @@ export const IdeaSideNavItems = (context: AppContext): SideNavigationProps.Item[ }); } - if (context.auth().isModuleAdmin(Constants.MODULE_CLUSTER_MANAGER)) { + if (context.auth().isAdmin()) { adminNavItems.push({ type: "section", text: "Environment Management", @@ -141,8 +140,8 @@ export const IdeaSideNavItems = (context: AppContext): SideNavigationProps.Item[ }, { type: "link", - text: "Environment Snapshots", - href: "#/cluster/snapshots", + text: "Snapshot Management", + href: "#/cluster/snapshot-management", }, { type: "link", diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/auth/auth-route.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/auth/auth-route.tsx index b12d4e1..637d307 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/auth/auth-route.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/auth/auth-route.tsx @@ -36,9 +36,9 @@ class IdeaAuthenticatedRoute extends Component { if (this.props.isLoggedIn) { if (isAuthRoute) { return ; - } else if (isVirtualDesktopAdminRoute && !(context.getClusterSettingsService().isVirtualDesktopDeployed() && context.auth().isModuleAdmin(Constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER))) { + } else if 
(isVirtualDesktopAdminRoute && (!context.getClusterSettingsService().isVirtualDesktopDeployed() || !context.auth().isAdmin())) { return ; - } else if (isClusterAdminRoute && !context.auth().isModuleAdmin(Constants.MODULE_CLUSTER_MANAGER)) { + } else if (isClusterAdminRoute && !context.auth().isAdmin()) { return ; } else { return this.props.children; diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/cluster-settings.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/cluster-settings.tsx index 1968c4c..a675286 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/cluster-settings.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/cluster-settings.tsx @@ -33,7 +33,6 @@ export interface ClusterSettingsState { identityProvider: any; directoryservice: any; clusterManager: any; - analytics: any; metrics: any; activeTabId: string; } @@ -52,7 +51,6 @@ class ClusterSettings extends Component { @@ -502,11 +497,6 @@ class ClusterSettings extends Component { - let externalAlbUrl = ConfigUtils.getExternalAlbUrl(this.state.cluster); - return `${externalAlbUrl}/_dashboards`; - }; - const isMetricsEnabled = () => { return AppContext.get().getClusterSettingsService().isMetricsEnabled(); }; @@ -731,8 +721,6 @@ class ClusterSettings extends Component - - @@ -807,31 +795,6 @@ class ClusterSettings extends Component ), }, - { - label: "Analytics", - id: "analytics", - content: ( - - OpenSearch Settings}> - - - - - - - - - Kinesis Settings}> - - - - - - - - - ), - }, { label: "Metrics", id: "metrics", diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/filesystem.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/filesystem.tsx index 3b9ce4c..a28ca8a 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/filesystem.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/filesystem.tsx @@ -387,13 +387,13 @@ class 
FileSystems extends Component { name: `${Constants.SHARED_STORAGE_PROVIDER_EFS}.mount_directory`, title: "Mount Directory", description: "Enter directory to mount the file system", - help_text: "Mount directory cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long. Eg. /efs-01", + help_text: "Mount directory cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long starting with '/'. Eg. /efs-01", data_type: "str", param_type: "text", validate: { required: true, regex: "^/([a-z0-9-]+){3,18}$", - message: "Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long." + message: "Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long starting with '/'." } } ]; @@ -443,12 +443,12 @@ class FileSystems extends Component { name: `${Constants.SHARED_STORAGE_PROVIDER_FSX_NETAPP_ONTAP}.mount_directory`, title: "Mount Directory", description: "Enter directory to mount the file system", - help_text: "Mount directory cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long. Eg. /efs-01", + help_text: "Mount directory cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long starting with '/'. Eg. /efs-01", data_type: "str", param_type: "text", validate: { regex: "(^.{0}$)|(^/([a-z0-9-]+){3,18}$)", - message: "Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long." + message: "Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long starting with '/'." 
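The mount-directory fields above validate input with the regex `^/([a-z0-9-]+){3,18}$`. A minimal standalone sketch of that check, assuming the pattern is used as-is (the helper name is hypothetical, not part of the diff):

```typescript
// Mirror of the mount-directory validation regex used by the EFS and
// FSx for NetApp ONTAP forms; isValidMountDirectory is an illustrative helper.
const MOUNT_DIR_RE: RegExp = /^\/([a-z0-9-]+){3,18}$/;

// Accepts values such as "/efs-01": a leading slash followed by
// lowercase letters, digits, and hyphens.
function isValidMountDirectory(value: string): boolean {
    return MOUNT_DIR_RE.test(value);
}
```

Note that `{3,18}` bounds how many times the one-or-more-character group repeats, not the total character count, so inputs longer than 18 characters can still pass; the 3-to-18 length stated in the help text is only loosely enforced by this pattern.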
}, when: { param: "onboard_filesystem", @@ -730,7 +730,7 @@ class FileSystems extends Component { items: [ { type: "error", - content: `EFS File System ${values.filesystem_name} create failed ${error.message}.`, + content: `EFS File System ${values.filesystem_name} create failed. Error: ${error.message}.`, dismissible: true, }, ], @@ -756,7 +756,7 @@ class FileSystems extends Component { items: [ { type: "error", - content: `FSx for NetApp ONTAP File System ${values.filesystem_name} create failed - ${error.message}`, + content: `FSx for NetApp ONTAP File System ${values.filesystem_name} create failed. Error: ${error.message}`, dismissible: true, }, ], @@ -945,7 +945,7 @@ class FileSystems extends Component { title: "Storage Capacity", description: "Enter storage capacity for your file system", help_text: "SSD storage capacity in GiB", - data_type: "str", + data_type: "int", param_type: "text", validate: { required: true, @@ -991,12 +991,12 @@ class FileSystems extends Component { name: "mount_directory", title: "Mount Directory", description: "Enter directory to mount the file system", - help_text: "Mount target cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long. Eg. /efs-01", + help_text: "Mount target cannot contain white spaces or special characters. Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long starting with '/'. Eg. /efs-01", data_type: "str", param_type: "text", validate: { regex: "(^.{0}$)|(^/([a-z0-9-]+){3,18}$)", - message: "Only use lowercase alphabets, numbers, and hyphens (-). Must be between 3 and 18 characters long.", + message: "Only use lowercase alphabets, numbers, and hyphens (-). 
Must be between 3 and 18 characters long starting with '/'.", }, }, { @@ -1008,7 +1008,7 @@ class FileSystems extends Component { param_type: "text", validate: { regex: "(^.{0}$)|(^[ABD-Z]$)", - message: "Mount drive should be in uppercase", + message: "Mount drive should be an uppercase alphabet except 'C'", }, when: { param: "filesystem_provider", eq: Constants.SHARED_STORAGE_PROVIDER_FSX_NETAPP_ONTAP diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/projects.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/projects.tsx index 2ad1892..fcefb7c 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/projects.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/cluster-admin/projects.tsx @@ -100,6 +100,23 @@ const PROJECT_TABLE_COLUMN_DEFINITIONS: TableProps.ColumnDefinition[] = } }, }, + { + id: "user", + header: "Users", + cell: (project) => { + if (project.users) { + return ( +
<div> + {project.users.map((user, index) => { + return <li key={index}>{user}</li>; + })} + </div>
    + ); + } else { + return "-"; + } + }, + }, { id: "updated_on", header: "Updated On", @@ -227,6 +244,24 @@ class Projects extends Component { }); return params; } + + buildUserParam(): SocaUserInputParamMetadata[] { + const params: SocaUserInputParamMetadata[] = []; + params.push({ + name: "users", + title: "Users", + description: "Select applicable users for the Project", + param_type: "select", + multiple: true, + data_type: "str", + dynamic_choices: true, + validate: { + required: false, + }, + }); + return params; + } + buildCreateProjectForm() { let values = undefined; const isUpdate = this.state.createProjectModalType === "update"; @@ -315,6 +350,30 @@ class Projects extends Component { }; } }); + } else if (request.param === "users") { + return this.accounts() + .listUsers() + .then((result) => { + const listing = result.listing!; + if (listing.length === 0) { + return { + listing: [], + }; + } else { + const choices: SocaUserInputChoice[] = []; + listing.forEach((value) => { + if (value.username != "clusteradmin") { + choices.push({ + title: `${value.username} (${value.uid})`, + value: value.username, + }); + } + }); + return { + listing: choices, + }; + } + }); } else if (request.param === "add_filesystems") { let promises: Promise[] = []; promises.push(this.clusterSettings().getModuleSettings({ module_id: Constants.MODULE_SHARED_STORAGE })); @@ -392,6 +451,7 @@ class Projects extends Component { }, dynamic_choices: true, }, + ...this.buildUserParam(), ...this.buildAddFileSystemParam(isUpdate), { name: "enable_budgets", diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/dashboard/job-submissions-widget.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/dashboard/job-submissions-widget.tsx index 26cc3f1..bd66994 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/dashboard/job-submissions-widget.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/dashboard/job-submissions-widget.tsx @@ -16,31 +16,6 @@ import { 
AppContext } from "../../common"; import { useEffect } from "react"; export function JobSubmissionsWidget() { - useEffect(() => { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: "idea-test1_jobs", - body: { - size: 0, - aggs: { - jobs: { - date_histogram: { - field: "queue_time", - calendar_interval: "hour", - }, - }, - }, - }, - }, - }) - .then((result) => { - console.log(result); - }); - }); - const start = new Date(); const end = new Date(); start.setDate(start.getDate() - 7); diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/apply-snapshot.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/apply-snapshot.tsx new file mode 100644 index 0000000..956a485 --- /dev/null +++ b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/apply-snapshot.tsx @@ -0,0 +1,248 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +import React, { Component, RefObject } from "react"; + +import { AppContext } from "../../common"; +import IdeaForm from "../../components/form"; +import IdeaListView from "../../components/list-view"; +import { ApplySnapshot } from "../../client/data-model"; +import SnapshotsClient from "../../client/snapshots-client"; +import { IdeaSideNavigationProps } from "../../components/side-navigation"; +import { IdeaAppLayoutProps } from "../../components/app-layout"; +import { withRouter } from "../../navigation/navigation-utils"; +import { StatusIndicator } from "@cloudscape-design/components"; +import { TableProps } from "@cloudscape-design/components/table/interfaces"; + +export interface ApplySnapshotsProps extends IdeaAppLayoutProps, IdeaSideNavigationProps {} + +export const APPLY_SNAPSHOT_TABLE_COLUMN_DEFINITIONS: TableProps.ColumnDefinition[] = [ + { + id: "s3_bucket_name", + header: "S3 Bucket Name", + cell: (e) => e.s3_bucket_name, + }, + { + id: "snapshot_path", + header: "Snapshot Path", + cell: (e) => e.snapshot_path, + }, + { + id: "status", + header: "Status", + cell: (e) => { + switch (e.status) { + case "COMPLETED": + return ( + + {e.status} + + ); + case "FAILED": + case "ROLLBACK_IN_PROGRESS": + case "ROLLBACK_COMPLETE": + case "ROLLBACK_FAILED": + return ( + + {e.status} + + ); + default: + return ( + + {e.status} + + ); + } + }, + }, + { + id: "created_on", + header: "Created On", + cell: (e) => new Date(e.created_on!).toLocaleString(), + } +]; + + +class ApplySnapshots extends Component { + applySnapshotForm: RefObject; + listing: RefObject; + + constructor(props: ApplySnapshotsProps) { + super(props); + this.applySnapshotForm = React.createRef(); + this.listing = React.createRef(); + } + + snapshotsClient(): SnapshotsClient { + return AppContext.get().client().snapshots(); + } + + getApplySnapshotForm(): IdeaForm { + return this.applySnapshotForm.current!; + } + + getListing(): IdeaListView { + return this.listing.current!; + } + + 
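The apply-snapshot form in this file validates its two inputs, the S3 bucket name and the snapshot path, with regular expressions. A standalone sketch of those checks, with the patterns copied from the form's `validate` blocks (the helper names are illustrative, not part of the component):

```typescript
// Patterns copied from the apply-snapshot form's validate blocks;
// the two helper functions are illustrative, not part of the component.
const S3_BUCKET_RE: RegExp = /^[a-z0-9]+[.\-\w]*[a-z0-9]+$/;
const SNAPSHOT_PATH_RE: RegExp = /^([\w.\-!*'()]+[\/]*)+$/;

// Bucket names must start and end with a lowercase letter or digit.
function isValidBucketName(name: string): boolean {
    return S3_BUCKET_RE.test(name);
}

// Paths are segments of word characters and a few punctuation marks,
// separated by forward slashes.
function isValidSnapshotPath(path: string): boolean {
    return SNAPSHOT_PATH_RE.test(path);
}
```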
buildApplySnapshotForm() { + return ( + { + if (!this.getApplySnapshotForm().validate()) { + return; + } + const values = this.getApplySnapshotForm().getValues(); + + this.snapshotsClient() + .applySnapshot({ + snapshot: { + s3_bucket_name: values.s3_bucket_name, + snapshot_path: values.snapshot_path, + }, + }) + .then((_) => { + this.props.onFlashbarChange({ + items: [ + { + type: "success", + content: "Apply Snapshot initiated. It takes about 5 minutes for the process to complete. Please refresh this page after some time to check the status.", + dismissible: true, + }, + ], + }); + this.getListing().fetchRecords(); + this.getApplySnapshotForm().hideModal(); + }) + .catch((error) => { + this.props.onFlashbarChange({ + items: [ + { + type: "error", + content: `Failed to apply Snapshot: ${error.message}`, + dismissible: true, + }, + ], + }); + this.getListing().fetchRecords(); + this.getApplySnapshotForm().hideModal(); + }); + }} + onCancel={() => { + this.getApplySnapshotForm().hideModal(); + }} + params={[ + { + name: "s3_bucket_name", + title: "S3 Bucket Name", + description: "Enter the name of the S3 bucket where the snapshot to be applied is stored.", + help_text: "S3 bucket name can only contain lowercase alphabets, numbers, dots (.), and hyphens (-).", + data_type: "str", + param_type: "text", + validate: { + required: true, + regex: "^[a-z0-9]+[\\.\\-\\w]*[a-z0-9]+$", + message: "S3 bucket name can only contain lowercase alphabets, numbers, dots (.), and hyphens (-).", + }, + }, + { + name: "snapshot_path", + title: "Snapshot Path", + description: "Enter the path at which the snapshot to be applied is stored in the provided S3 bucket.", + help_text: "Snapshot path can only contain forward slashes, dots (.), exclamations (!), asterisks (*), single quotes ('), parentheses (), and hyphens (-).", + data_type: "str", + param_type: "text", + validate: { + required: true, + regex: "^([\\w\\.\\-\\!\\*\\'\\(\\)]+[\\/]*)+$", + message: "Snapshot path can only contain 
forward slashes, dots (.), exclamations (!), asterisks (*), single quotes ('), parentheses (), and hyphens (-).", + }, + }, + ]} + /> + ); + } + + buildListing() { + return ( + { + this.getApplySnapshotForm().showModal(); + }, + }} + showPaginator={true} + showFilters={true} + filters={[ + { + key: "s3_bucket_name", + }, + ]} + onFilter={(filters) => { + const s3BucketToken = String(filters[0].value).toString().trim().toLowerCase(); + if (s3BucketToken == null) { + return []; + } else { + return [ + { + key: "s3_bucket_name", + like: s3BucketToken, + }, + ]; + } + }} + onRefresh={() => { + this.setState( + {}, + () => { + this.getListing().fetchRecords(); + } + ); + }} + onFetchRecords={async () => { + let result = await this.snapshotsClient().listAppliedSnapshots({ + filters: this.getListing().getFilters(), + paginator: this.getListing().getPaginator(), + }); + result.listing?.sort((a,b) => new Date(b.created_on!).getTime() - new Date(a.created_on!).getTime()) + return result + }} + columnDefinitions={APPLY_SNAPSHOT_TABLE_COLUMN_DEFINITIONS} + /> + ); + } + + render() { + return ( + + {this.buildApplySnapshotForm()} + {this.buildListing()} + + ); + } +} + +export default withRouter(ApplySnapshots); \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshot-management.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshot-management.tsx new file mode 100644 index 0000000..ad16430 --- /dev/null +++ b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshot-management.tsx @@ -0,0 +1,86 @@ +import React, { Component, RefObject } from "react"; + +import { IdeaSideNavigationProps } from "../../components/side-navigation"; +import IdeaAppLayout, { IdeaAppLayoutProps } from "../../components/app-layout"; +import { Button, Container, Header, SpaceBetween } from "@cloudscape-design/components"; +import { withRouter } from "../../navigation/navigation-utils"; + +import ApplySnapshots from 
"./apply-snapshot" +import Snapshots from "./snapshots" + +export interface SnapshotManagementProps extends IdeaAppLayoutProps, IdeaSideNavigationProps {} + +class SnapshotManagement extends Component { + + render() { + return ( + + Snapshot Management + + } + contentType={"default"} + content={ + + + + + + + } + /> + ) + } +} + + +export default withRouter(SnapshotManagement); \ No newline at end of file diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshots.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshots.tsx index 3f435ad..d667400 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshots.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/snapshots/snapshots.tsx @@ -19,7 +19,7 @@ import IdeaListView from "../../components/list-view"; import { Snapshot } from "../../client/data-model"; import Utils from "../../common/utils"; import { IdeaSideNavigationProps } from "../../components/side-navigation"; -import IdeaAppLayout, { IdeaAppLayoutProps } from "../../components/app-layout"; +import { IdeaAppLayoutProps } from "../../components/app-layout"; import { withRouter } from "../../navigation/navigation-utils"; import SnapshotsClient from "../../client/snapshots-client"; import { StatusIndicator } from "@cloudscape-design/components"; @@ -50,19 +50,19 @@ export const SNAPSHOT_TABLE_COLUMN_DEFINITIONS: TableProps.ColumnDefinition - Completed + {e.status} ); case "FAILED": return ( - Failed + {e.status} ); default: return ( - In Progress + {e.status} ); } @@ -154,7 +154,7 @@ class Snapshots extends Component { { name: "s3_bucket_name", title: "S3 Bucket Name", - description: "Enter the name of the S3 bucket where the snapshot will be stored in.", + description: "Enter the name of an existing S3 bucket where the snapshot should be stored.", help_text: "S3 bucket name can only contain lowercase alphabets, numbers, dots (.), and hyphens (-).", data_type: "str", param_type: "text", @@ 
-167,7 +167,7 @@ class Snapshots extends Component { { name: "snapshot_path", title: "Snapshot Path", - description: "Enter the path at which the snapshot should be stored in the provided S3 bucket.", + description: "Enter a path at which the snapshot should be stored in the provided S3 bucket.", help_text: "Snapshot path can only contain forward slashes, dots (.), exclamations (!), asterisks (*), single quotes ('), parentheses (), and hyphens (-).", data_type: "str", param_type: "text", @@ -192,8 +192,9 @@ class Snapshots extends Component { ref={this.listing} preferencesKey={"snapshots"} showPreferences={false} - title="Snapshots" - description="Environment snapshot management" + title="Created Snapshots" + variant="container" + description="Snapshots created from the environment" primaryAction={{ id: "create-snapshot", text: "Create Snapshot", @@ -236,11 +237,13 @@ class Snapshots extends Component { snapshotSelected: true, }); }} - onFetchRecords={() => { - return this.snapshotsClient().listSnapshots({ + onFetchRecords={async () => { + let result = await this.snapshotsClient().listSnapshots({ filters: this.getListing().getFilters(), paginator: this.getListing().getPaginator(), }); + result.listing?.sort((a,b) => new Date(b.created_on!).getTime() - new Date(a.created_on!).getTime()) + return result }} columnDefinitions={SNAPSHOT_TABLE_COLUMN_DEFINITIONS} /> @@ -249,38 +252,10 @@ class Snapshots extends Component { render() { return ( - - {this.buildCreateSnapshotForm()} - {this.buildListing()} - - } - /> + + {this.buildCreateSnapshotForm()} + {this.buildListing()} + ); } } diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-az-distribution.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-az-distribution.tsx index 9c7fa16..9f4f71a 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-az-distribution.tsx +++ 
b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-az-distribution.tsx @@ -12,7 +12,6 @@ */ import { Box } from "@cloudscape-design/components"; -import { AppContext } from "../../../common"; import VirtualDesktopBaseChart from "./virtual-desktop-base-chart"; import PieOrDonutChart from "../../../components/charts/pie-or-donut-chart"; @@ -36,68 +35,6 @@ class VirtualDesktopAZDistributionChart extends VirtualDesktopBaseChart { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - availability_zone: { - terms: { - field: "server.availability_zone.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let chartData: any = []; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let availability_zone = aggregations.availability_zone; - let buckets: any[] = availability_zone.buckets; - buckets.forEach((bucket) => { - chartData.push({ - title: bucket.key, - value: bucket.doc_count, - }); - }); - } - let hits: any = result.data?.hits; - this.setState({ - chartData: chartData, - total: `${hits.total.value}`, - statusType: "finished", - }); - }) - .catch((error) => { - console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } - render() { return ( { - constructor(props: VirtualDesktopBaseOSChartProps) { - super(props); - this.state = { - chartData: [], - total: "-", - statusType: "loading", - }; - } - - componentDidMount() { - this.loadChartData(); - } - - reload() { - this.loadChartData(); - } +class VirtualDesktopBaseOSChart extends VirtualDesktopBaseChart { + render() { + const states = this.props.sessions.reduce((eax: {[key: string]: number}, item: any) => { + eax[item.base_os] = (eax[item.base_os] || 0) + 1; + return eax; + }, {}) + + let chartData: {title: string, value: number}[] = Object.entries(states).map(([key, value]) => {return {title: key, value: value}}); - 
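The rewritten charts bucket sessions client-side with a `reduce` instead of querying OpenSearch. A minimal sketch of that counting step (the `Session` shape and `toChartData` name are simplified stand-ins for `VirtualDesktopSession` and the inline chart code):

```typescript
// Count sessions per bucket key, mirroring the reduce the charts now run
// locally; Session is a simplified stand-in for VirtualDesktopSession.
interface Session {
    base_os: string;
}

function toChartData(sessions: Session[]): { title: string; value: number }[] {
    const counts = sessions.reduce((acc: { [key: string]: number }, s) => {
        acc[s.base_os] = (acc[s.base_os] || 0) + 1;
        return acc;
    }, {});
    // One pie-chart datum per distinct key.
    return Object.entries(counts).map(([title, value]) => ({ title, value }));
}
```

The same pattern is reused across the base-OS, instance-type, project, software-stack, and state charts, with only the key accessor changing.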
loadChartData() { - this.setState( - { - statusType: "loading", - }, - () => { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - base_os: { - terms: { - field: "base_os.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let chartData: any = []; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let base_os = aggregations.base_os; - let buckets: any[] = base_os.buckets; - buckets.forEach((bucket) => { - chartData.push({ - title: Utils.getOsTitle(bucket.key), - value: bucket.doc_count, - }); - }); - } - let hits: any = result.data?.hits; - this.setState({ - chartData: chartData, - total: `${hits.total.value}`, - statusType: "finished", - }); - }) - .catch((error) => { - console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } + const statusType = this.props.loading ? 'loading' : 'finished' - render() { return ( No sessions available diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-instance-types-chart.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-instance-types-chart.tsx index 2e63aa1..e9246fe 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-instance-types-chart.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-instance-types-chart.tsx @@ -12,102 +12,34 @@ */ import { Box } from "@cloudscape-design/components"; -import { AppContext } from "../../../common"; import VirtualDesktopBaseChart from "./virtual-desktop-base-chart"; import PieOrDonutChart from "../../../components/charts/pie-or-donut-chart"; +import { VirtualDesktopSession } from '../../../client/data-model' export interface VirtualDesktopInstanceTypesChartProps { - indexName: string; + loading: boolean + sessions: 
VirtualDesktopSession[]; } -interface VirtualDesktopInstanceTypesChartState { - chartData: any; - total: string; - statusType: "loading" | "finished" | "error"; -} - -class VirtualDesktopInstanceTypesChart extends VirtualDesktopBaseChart { - constructor(props: VirtualDesktopInstanceTypesChartProps) { - super(props); - this.state = { - chartData: [], - total: "-", - statusType: "loading", - }; - } - - componentDidMount() { - this.loadChartData(); - } - - reload() { - this.loadChartData(); - } +class VirtualDesktopInstanceTypesChart extends VirtualDesktopBaseChart { + render() { + const states = this.props.sessions.reduce((eax: {[key: string]: number}, item: any) => { + eax[item.server.instance_type] = (eax[item.server.instance_type] || 0) + 1; + return eax; + }, {}) + + let chartData: {title: string, value: number}[] = Object.entries(states).map(([key, value]) => {return {title: key, value: value}}); - loadChartData() { - this.setState( - { - statusType: "loading", - }, - () => { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - instance_type: { - terms: { - field: "server.instance_type.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let chartData: any = []; - let total: number = 0; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let instance_type = aggregations.instance_type; - let buckets: any[] = instance_type.buckets; - buckets.forEach((bucket) => { - total += bucket.doc_count; - chartData.push({ - title: bucket.key, - value: bucket.doc_count, - }); - }); - } - this.setState({ - chartData: chartData, - total: `${total}`, - statusType: "finished", - }); - }) - .catch((error) => { - console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } + const statusType = this.props.loading ? 
'loading' : 'finished' - render() { return ( No sessions available diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-project-chart.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-project-chart.tsx index d4a9d88..7f1436c 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-project-chart.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-project-chart.tsx @@ -12,101 +12,34 @@ */ import { Box } from "@cloudscape-design/components"; -import { AppContext } from "../../../common"; import VirtualDesktopBaseChart from "./virtual-desktop-base-chart"; import PieOrDonutChart from "../../../components/charts/pie-or-donut-chart"; +import { VirtualDesktopSession } from '../../../client/data-model' export interface VirtualDesktopProjectChartProps { - indexName: string; + loading: boolean + sessions: VirtualDesktopSession[]; } -interface VirtualDesktopProjectChartState { - chartData: any; - total: string; - statusType: "loading" | "finished" | "error"; -} - -class VirtualDesktopProjectChart extends VirtualDesktopBaseChart { - constructor(props: VirtualDesktopProjectChartProps) { - super(props); - this.state = { - chartData: [], - total: "-", - statusType: "loading", - }; - } - - componentDidMount() { - this.loadChartData(); - } - - reload() { - this.loadChartData(); - } +class VirtualDesktopProjectChart extends VirtualDesktopBaseChart { + render() { + const states = this.props.sessions.reduce((eax: {[key: string]: number}, item: any) => { + eax[item.project.name] = (eax[item.project.name] || 0) + 1; + return eax; + }, {}) + + let chartData: {title: string, value: number}[] = Object.entries(states).map(([key, value]) => {return {title: key, value: value}}); - loadChartData() { - this.setState( - { - statusType: "loading", - }, - () => { - AppContext.get() - .client() - .analytics() - 
.queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - project: { - terms: { - field: "project.name.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let chartData: any = []; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let base_os = aggregations.project; - let buckets: any[] = base_os.buckets; - buckets.forEach((bucket) => { - chartData.push({ - title: bucket.key, - value: bucket.doc_count, - }); - }); - } - let hits: any = result.data?.hits; - this.setState({ - chartData: chartData, - total: `${hits.total.value}`, - statusType: "finished", - }); - }) - .catch((error) => { - console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } + const statusType = this.props.loading ? 'loading' : 'finished' - render() { return ( No sessions available diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-software-stack-chart.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-software-stack-chart.tsx index 233448c..55954a8 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-software-stack-chart.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/charts/virtual-desktop-software-stack-chart.tsx @@ -11,88 +11,25 @@ * and limitations under the License. 
*/ import VirtualDesktopBaseChart from "./virtual-desktop-base-chart"; -import { AppContext } from "../../../common"; import { BarChart, Box, Container, Header } from "@cloudscape-design/components"; +import { VirtualDesktopSession } from '../../../client/data-model' export interface VirtualDesktopSoftwareStackChartProps { - indexName: string; + loading: boolean + sessions: VirtualDesktopSession[]; } -interface VirtualDesktopSoftwareStackChartState { - series: any; - statusType: "loading" | "finished" | "error"; -} - -class VirtualDesktopSoftwareStackChart extends VirtualDesktopBaseChart { - constructor(props: VirtualDesktopSoftwareStackChartProps) { - super(props); - this.state = { - series: [], - statusType: "loading", - }; - } - - componentDidMount() { - this.loadChartData(); - } - - reload() { - this.loadChartData(); - } +class VirtualDesktopSoftwareStackChart extends VirtualDesktopBaseChart { + render() { + const states = this.props.sessions.reduce((eax: {[key: string]: number}, item: any) => { + eax[item.software_stack.name] = (eax[item.software_stack.name] || 0) + 1; + return eax; + }, {}) + + let chartData: {x: string, y: number}[] = Object.entries(states).map(([key, value]) => {return {x: key, y: value}}); - loadChartData() { - this.setState( - { - statusType: "loading", - }, - () => { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - software_stack: { - terms: { - field: "software_stack.name.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let series: any[] = []; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let software_stacks = aggregations.software_stack; - let buckets: any[] = software_stacks.buckets; - buckets.forEach((bucket) => { - series.push({ - x: bucket.key, - y: bucket.doc_count, - }); - }); - } - this.setState({ - series: series, - statusType: "finished", - }); - }) - .catch((error) => { - 
console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } + const statusType = this.props.loading ? 'loading' : 'finished' - render() { return ( { - constructor(props: VirtualDesktopStateChartProps) { - super(props); - this.state = { - chartData: [], - total: "-", - statusType: "loading", - }; - } - - componentDidMount() { - this.loadChartData(); - } - - reload() { - this.loadChartData(); - } +class VirtualDesktopStateChart extends VirtualDesktopBaseChart { + render() { + const states = this.props.sessions.reduce((eax: {[key: string]: number}, item: any) => { + eax[item.state] = (eax[item.state] || 0) + 1; + return eax; + }, {}) + + let chartData: {title: string, value: number}[] = Object.entries(states).map(([key, value]) => {return {title: key, value: value}}); - loadChartData() { - this.setState( - { - statusType: "loading", - }, - () => { - AppContext.get() - .client() - .analytics() - .queryOpenSearch({ - data: { - index: this.props.indexName, - body: { - size: 0, - aggs: { - state: { - terms: { - field: "state.raw", - }, - }, - }, - }, - }, - }) - .then((result) => { - let chartData: any = []; - if (result.data?.aggregations) { - const aggregations: any = result.data.aggregations; - let state = aggregations.state; - let buckets: any[] = state.buckets; - buckets.forEach((bucket) => { - chartData.push({ - title: bucket.key, - value: bucket.doc_count, - }); - }); - } - let hits: any = result.data?.hits; - this.setState({ - chartData: chartData, - total: `${hits.total.value}`, - statusType: "finished", - }); - }) - .catch((error) => { - console.error(error); - this.setState({ - statusType: "error", - }); - }); - } - ); - } + const statusType = this.props.loading ? 
'loading' : 'finished' - render() { return ( No sessions available diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/components/dcv-client-help-modal.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/components/dcv-client-help-modal.tsx index fa106a7..544f224 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/components/dcv-client-help-modal.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/components/dcv-client-help-modal.tsx @@ -34,8 +34,6 @@ function downloadDcvClient(os: string) { window.open(client_settings.linux.rhel_centos_rocky8.url); } else if (os === "linux-suse15") { window.open(client_settings.linux.suse15.url); - } else if (os === "ubuntu-ubuntu1804") { - window.open(client_settings.ubuntu.ubuntu1804.url); } else if (os === "ubuntu-ubuntu2004") { window.open(client_settings.ubuntu.ubuntu2004.url); } else if (os === "ubuntu-ubuntu2204") { @@ -59,8 +57,6 @@ function getDCVClientLabelForOSFlavor(os: string): string { return client_settings.linux.rhel_centos_rocky8.label; } else if (os === "linux-suse15") { return client_settings.linux.suse15.label; - } else if (os === "ubuntu-ubuntu1804") { - return client_settings.ubuntu.ubuntu1804.label; } else if (os === "ubuntu-ubuntu2004") { return client_settings.ubuntu.ubuntu2004.label; } else if (os === "ubuntu-ubuntu2204") { @@ -198,9 +194,6 @@ export function DcvClientHelpModal(props: DcvClientHelpModalProps) { Step 1) Download DCV Ubuntu Client.

    - diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/forms/virtual-desktop-create-session-form.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/forms/virtual-desktop-create-session-form.tsx index 31b6e94..ad4acb8 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/forms/virtual-desktop-create-session-form.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/forms/virtual-desktop-create-session-form.tsx @@ -25,7 +25,7 @@ export interface VirtualDesktopCreateSessionFormProps { defaultName?: string; maxRootVolumeMemory: number; isAdminView?: boolean; - onSubmit: (session_name: string, username: string, project_id: string, base_os: VirtualDesktopBaseOS, software_stack_id: string, session_type: VirtualDesktopSessionType, instance_type: string, storage_size: number, hibernation_enabled: boolean, vpc_subnet_id: string) => Promise; + onSubmit: (session_name: string, username: string, project_id: string, base_os: VirtualDesktopBaseOS, software_stack_id: string, session_type: VirtualDesktopSessionType, instance_type: string, storage_size: number, hibernation_enabled: boolean, vpc_subnet_id: string, session_tags: Record[]) => Promise; onDismiss: () => void; } @@ -575,6 +575,36 @@ class VirtualDesktopCreateSessionForm extends Component { + onSubmit={(session_name, username, project_id, base_os, software_stack_id, session_type, instance_type, storage_size, hibernation_enabled, vpc_subnet_id, tags) => { return this.getVirtualDesktopClient() .createSession({ session: { @@ -945,6 +945,7 @@ class MyVirtualDesktopSessions extends Component { diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-dashboard.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-dashboard.tsx index 328f619..1ae3ec6 100644 --- 
a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-dashboard.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-dashboard.tsx @@ -20,19 +20,16 @@ import { AppContext } from "../../common"; import VirtualDesktopBaseChart from "./charts/virtual-desktop-base-chart"; import VirtualDesktopStateChart from "./charts/virtual-desktop-state-chart"; import VirtualDesktopBaseOSChart from "./charts/virtual-desktop-baseos-chart"; -import VirtualDesktopAZDistributionChart from "./charts/virtual-desktop-az-distribution"; import VirtualDesktopSoftwareStackChart from "./charts/virtual-desktop-software-stack-chart"; -import { Constants } from "../../common/constants"; -import dot from "dot-object"; import VirtualDesktopProjectChart from "./charts/virtual-desktop-project-chart"; import { withRouter } from "../../navigation/navigation-utils"; +import { VirtualDesktopSession } from '../../client/data-model' export interface VirtualDesktopDashboardProps extends IdeaAppLayoutProps, IdeaSideNavigationProps {} export interface VirtualDesktopDashboardState { - moduleInfo: any; - settings: any; - settingsLoaded: boolean; + sessions: VirtualDesktopSession[]; + loading: boolean } class VirtualDesktopDashboard extends Component { @@ -41,7 +38,6 @@ class VirtualDesktopDashboard extends Component; baseOsChart: RefObject; projectChart: RefObject; - azDistributionChart: RefObject; softwareStackChart: RefObject; constructor(props: VirtualDesktopDashboardProps) { @@ -49,35 +45,44 @@ class VirtualDesktopDashboard extends Component { - let moduleInfo = AppContext.get().getClusterSettingsService().getModuleInfo(Constants.MODULE_VIRTUAL_DESKTOP_CONTROLLER); - this.setState({ - moduleInfo: moduleInfo, - settings: settings, - settingsLoaded: true, - }); - }); + this.loadSessionsData() } reloadAllCharts() { - this.allCharts.forEach((chartRef) => { - chartRef.current!.reload(); - }); + this.loadSessionsData() + } + + 
loadSessionsData() { + this.setState({loading: true}) + AppContext.get().client().virtualDesktopAdmin() + .listSessions({ + paginator: { page_size: 100 } + }) + .then((sessions) => { + this.setState({sessions: sessions.listing ?? [], loading: false}) + }) + .catch((error) => { + this.props.onFlashbarChange({ + items: [ + { + content: error.message, + type: "error", + dismissible: true, + }, + ], + }); + throw error; + }); } render() { @@ -136,12 +141,11 @@ class VirtualDesktopDashboard extends Component - {this.state.settingsLoaded && } - {this.state.settingsLoaded && } - {this.state.settingsLoaded && } - {this.state.settingsLoaded && } - {this.state.settingsLoaded && } - {this.state.settingsLoaded && } + + + + + } /> diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-sessions.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-sessions.tsx index 4d5e6b9..713e6b3 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-sessions.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-sessions.tsx @@ -601,7 +601,7 @@ class VirtualDesktopSessions extends Component { + onSubmit={(session_name, username, project_id, base_os, software_stack_id, session_type, instance_type, storage_size, hibernation_enabled, vpc_subnet_id, tags) => { return this.getVirtualDesktopAdminClient() .createSession({ session: { @@ -624,6 +624,7 @@ class VirtualDesktopSessions extends Component { diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stack-detail.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stack-detail.tsx index a66b180..e6dc559 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stack-detail.tsx +++ 
b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stack-detail.tsx @@ -48,6 +48,10 @@ class VirtualDesktopSoftwareStackDetail extends Component { this.setState({ diff --git a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stacks.tsx b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stacks.tsx index a9b35c4..c31c3d5 100644 --- a/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stacks.tsx +++ b/source/idea/idea-cluster-manager/webapp/src/pages/virtual-desktops/virtual-desktop-software-stacks.tsx @@ -42,7 +42,7 @@ const VIRTUAL_DESKTOP_SOFTWARE_STACKS_TABLE_COLUMN_DEFINITIONS: TableProps.Colum { id: "name", header: "Name", - cell: (e) => {e.name}, + cell: (e) => {e.name}, }, { id: "description", diff --git a/source/idea/idea-cluster-manager/webapp/src/service/auth-service.ts b/source/idea/idea-cluster-manager/webapp/src/service/auth-service.ts index 89c425a..4bda987 100644 --- a/source/idea/idea-cluster-manager/webapp/src/service/auth-service.ts +++ b/source/idea/idea-cluster-manager/webapp/src/service/auth-service.ts @@ -17,6 +17,7 @@ import IdeaException from "../common/exceptions"; import { AUTH_LOGIN_CHALLENGE, AUTH_PASSWORD_RESET_REQUIRED, UNAUTHORIZED_ACCESS } from "../common/error-codes"; import Utils from "../common/utils"; import { JwtTokenClaims } from "../common/token-utils"; +import { Constants } from "../common/constants"; import { IdeaClients } from "../client"; export interface AuthServiceProps { @@ -58,12 +59,12 @@ class AuthService { * @param username * @param password */ - login(username: string, password: string): Promise { + login(cognito_username: string, password: string): Promise { return this.props.clients .auth() .initiateAuth({ auth_flow: "USER_PASSWORD_AUTH", - username: username, + cognito_username: cognito_username, password: password, }) .then((result) => { @@ -88,7 
+89,7 @@ class AuthService { }) .catch((error) => { if (error.errorCode === AUTH_PASSWORD_RESET_REQUIRED) { - this.props.localStorage.setItem(KEY_FORGOT_PASSWORD_USERNAME, username); + this.props.localStorage.setItem(KEY_FORGOT_PASSWORD_USERNAME, cognito_username); } throw error; }); @@ -134,6 +135,13 @@ class AuthService { }); } + getCognitoUsername(): string { + if (this.claims == null) { + return ""; + } + return this.claims.cognito_username; + } + /** * get user name from the JWT token. * the JWT token is not validated or verified. @@ -147,7 +155,7 @@ class AuthService { if (this.claims == null) { return ""; } - return this.claims.username; + return this.claims.db_username; } getEmail(): string { @@ -195,16 +203,6 @@ class AuthService { return expiresIn; } - getGroups(): string[] { - if (this.claims == null) { - return []; - } - if (this.claims.groups == null) { - return []; - } - return this.claims.groups; - } - getAccessToken(): Promise { return this.props.clients.auth().getAccessToken(); } @@ -214,26 +212,9 @@ class AuthService { } isAdmin(): boolean { - const groups = this.getGroups(); - return groups.includes("administrators-cluster-group") || groups.includes("managers-cluster-group"); - } - - hasModuleAccess(moduleName: string): boolean { - if (this.isAdmin()) { - return true; - } - const moduleId = Utils.getModuleId(moduleName); - const groups = this.getGroups(); - return groups.includes(Utils.getUserGroupName(moduleId)) || groups.includes(Utils.getAdministratorGroup(moduleId)); - } - - isModuleAdmin(moduleName: string): boolean { - if (this.isAdmin()) { - return true; - } - const moduleId = Utils.getModuleId(moduleName); - const groups = this.getGroups(); - return groups.includes(`${moduleId}-administrators-module-group`); + if (this.claims == null) + return false + return this.claims.role == Constants.ADMIN_ROLE } /** diff --git a/source/idea/idea-cluster-manager/webapp/src/service/cluster-settings-service.ts 
b/source/idea/idea-cluster-manager/webapp/src/service/cluster-settings-service.ts index 5db266e..8778721 100644 --- a/source/idea/idea-cluster-manager/webapp/src/service/cluster-settings-service.ts +++ b/source/idea/idea-cluster-manager/webapp/src/service/cluster-settings-service.ts @@ -13,7 +13,9 @@ import { ClusterSettingsClient } from "../client"; import { Constants, ErrorCodes } from "../common/constants"; +import { UNAUTHORIZED_ACCESS } from "../common/error-codes"; import IdeaException from "../common/exceptions"; +import Utils from "../common/utils"; export interface ClusterSettingsServiceProps { clusterSettings: ClusterSettingsClient; @@ -66,6 +68,14 @@ class ClusterSettingsService { }) .catch((error) => { console.error(error); + // This is the first API call made from the client side. + // If the user is disabled and gets unauthorized access, + // redirect them to /sso, which will invoke sign out for the user. + if (error.errorCode == UNAUTHORIZED_ACCESS) { + if (Utils.isSsoEnabled()) { + window.location.href = "/sso" + } + } return false; }); } @@ -96,12 +106,15 @@ class ClusterSettingsService { } getModuleSet(): any { - return this.globalSettings.module_sets[this.getModuleSetId()]; + if (this.globalSettings != null) { + return this.globalSettings.module_sets[this.getModuleSetId()]; + } + return null; } getModuleId(name: string): string | null { const moduleSet = this.getModuleSet(); - if (name in moduleSet) { + if (moduleSet != null && name in moduleSet) { return moduleSet[name].module_id; } return null; @@ -223,14 +236,6 @@ class ClusterSettingsService { return this.isModuleDeployed(Constants.MODULE_BASTION_HOST); } - isAnalyticsEnabled(): boolean { - return this.isModuleEnabled(Constants.MODULE_ANALYTICS); - } - - isAnalyticsDeployed(): boolean { - return this.isModuleDeployed(Constants.MODULE_ANALYTICS); - } - isMetricsEnabled(): boolean { return this.isModuleEnabled(Constants.MODULE_METRICS); } @@ -264,10 +269,6 @@ class ClusterSettingsService { return
this.getModuleSettings(Constants.MODULE_SHARED_STORAGE); } - getAnalyticsSettings(): Promise { - return this.getModuleSettings(Constants.MODULE_ANALYTICS); - } - getMetricsSettings(): Promise { return this.getModuleSettings(Constants.MODULE_METRICS); } diff --git a/source/idea/idea-data-model/src/ideadatamodel/__init__.py b/source/idea/idea-data-model/src/ideadatamodel/__init__.py index b2d2d2c..1619463 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/__init__.py +++ b/source/idea/idea-data-model/src/ideadatamodel/__init__.py @@ -22,7 +22,6 @@ from .scheduler import * from .virtual_desktop import * from .cluster_settings import * -from .analytics import * from .email_templates import * from .notifications import * from .shared_filesystem import * diff --git a/source/idea/idea-data-model/src/ideadatamodel/api/api_model.py b/source/idea/idea-data-model/src/ideadatamodel/api/api_model.py index 3f315b7..bdc66c0 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/api/api_model.py +++ b/source/idea/idea-data-model/src/ideadatamodel/api/api_model.py @@ -11,6 +11,7 @@ __all__ = ( 'SocaPayload', + 'SocaPayloadType', 'get_payload_as', 'SocaListingPayload', 'SocaHeader', @@ -18,7 +19,9 @@ 'SocaEnvelope', 'SocaAnyPayload', 'SocaBatchResponsePayload', - 'IdeaOpenAPISpecEntry' + 'IdeaOpenAPISpecEntry', + 'ApiAuthorizationType', + 'ApiAuthorization' ) from ideadatamodel import (SocaBaseModel, SocaDateRange, SocaSortBy, SocaPaginator, SocaFilter) @@ -27,6 +30,7 @@ from typing import Optional, Union, TypeVar, Type, List, Any from types import SimpleNamespace from pydantic import Field +from enum import Enum SocaBaseModelType = TypeVar('SocaBaseModelType', bound='SocaBaseModel') @@ -134,3 +138,15 @@ class IdeaOpenAPISpecEntry(SocaBaseModel): result: Type[SocaBaseModel] is_listing: bool is_public: bool + +class ApiAuthorizationType(str, Enum): + ADMINISTRATOR = 'admin' + USER = 'user' + APP = 'app' + +class ApiAuthorization(SocaBaseModel): + type: 
ApiAuthorizationType + username: Optional[str] # will not exist for APP authorizations + client_id: Optional[str] + scopes: Optional[List[str]] # list of allowed oauth scopes + invocation_source: Optional[str] diff --git a/source/idea/idea-data-model/src/ideadatamodel/auth/auth_api.py b/source/idea/idea-data-model/src/ideadatamodel/auth/auth_api.py index 8b26087..a9a2de3 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/auth/auth_api.py +++ b/source/idea/idea-data-model/src/ideadatamodel/auth/auth_api.py @@ -13,6 +13,8 @@ 'CreateUserResult', 'GetUserRequest', 'GetUserResult', + 'GetUserByEmailRequest', + 'GetUserByEmailResult', 'ModifyUserRequest', 'ModifyUserResult', 'DeleteUserRequest', @@ -99,6 +101,15 @@ class GetUserResult(SocaPayload): user: Optional[User] +#GetUserByEmail + +class GetUserByEmailRequest(SocaPayload): + email: Optional[str] + +class GetUserByEmailResult(SocaPayload): + user: Optional[User] + + # ModifyUser class ModifyUserRequest(SocaPayload): @@ -155,7 +166,7 @@ class ListUsersResult(SocaListingPayload): class InitiateAuthRequest(SocaPayload): client_id: Optional[str] auth_flow: Optional[str] - username: Optional[str] + cognito_username: Optional[str] password: Optional[str] refresh_token: Optional[str] authorization_code: Optional[str] @@ -166,6 +177,8 @@ class InitiateAuthResult(SocaPayload): session: Optional[str] challenge_params: Optional[Dict] auth: Optional[AuthResult] + db_username: Optional[str] + role: Optional[str] # RespondToAuthChallenge @@ -426,6 +439,13 @@ class ConfigureSSOResult(SocaPayload): is_listing=False, is_public=False ), + IdeaOpenAPISpecEntry( + namespace='Accounts.GetUserByEmail', + request=GetUserByEmailRequest, + result=GetUserByEmailResult, + is_listing=False, + is_public=False, + ), IdeaOpenAPISpecEntry( namespace='Accounts.ModifyUser', request=ModifyUserRequest, diff --git a/source/idea/idea-data-model/src/ideadatamodel/cluster_resources/cluster_resources_model.py 
b/source/idea/idea-data-model/src/ideadatamodel/cluster_resources/cluster_resources_model.py index 6de420f..37bd522 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/cluster_resources/cluster_resources_model.py +++ b/source/idea/idea-data-model/src/ideadatamodel/cluster_resources/cluster_resources_model.py @@ -13,7 +13,6 @@ 'SocaClusterResource', 'SocaVPC', 'SocaCloudFormationStack', - 'SocaOpenSearchDomain', 'SocaDirectory', 'SocaSubnet', 'SocaFileSystem', @@ -101,9 +100,6 @@ def is_cluster(self) -> bool: def is_app(self) -> bool: return self.stack_type == constants.STACK_TYPE_APP - def is_analytics(self) -> bool: - return self.stack_type == constants.STACK_TYPE_ANALYTICS - def is_alb(self) -> bool: return self.stack_type == constants.STACK_TYPE_ALB @@ -111,30 +107,6 @@ def is_job(self) -> bool: return self.stack_type == constants.STACK_TYPE_JOB -class SocaOpenSearchDomain(SocaClusterResource): - - @property - def _domain_status(self) -> Optional[Dict]: - return ModelUtils.get_value_as_dict('DomainStatus', self.ref) - - @property - def vpc_id(self) -> Optional[str]: - vpc_options = ModelUtils.get_value_as_dict('VPCOptions', self._domain_status) - return ModelUtils.get_value_as_string('VPCId', vpc_options) - - @property - def endpoint(self) -> str: - return ModelUtils.get_value_as_string('Endpoint', self.ref) - - @property - def vpc_endpoint(self) -> Optional[str]: - endpoints = ModelUtils.get_value_as_dict('Endpoints', self._domain_status) - if endpoints is None: - return None - for key, value in endpoints.items(): - return value - - class SocaDirectory(SocaClusterResource): @property diff --git a/source/idea/idea-data-model/src/ideadatamodel/constants.py b/source/idea/idea-data-model/src/ideadatamodel/constants.py index 2d272e7..107eccf 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/constants.py +++ b/source/idea/idea-data-model/src/ideadatamodel/constants.py @@ -105,6 +105,10 @@ IDEA_TAG_PROJECT = IDEA_TAG_PREFIX + 'Project' 
IDEA_TAG_AMI_BUILDER = IDEA_TAG_PREFIX + 'AmiBuilder' +BI_TAG_PREFIX = 'bi:' + +BI_TAG_DEPLOYMENT = BI_TAG_PREFIX + 'Deployment' + IDEA_TAG_NAME = 'Name' IDEA_TAG_JOB_ID = IDEA_TAG_PREFIX + 'JobId' @@ -135,7 +139,6 @@ STACK_TYPE_BOOTSTRAP = 'bootstrap' STACK_TYPE_CLUSTER = 'cluster' -STACK_TYPE_ANALYTICS = 'analytics' STACK_TYPE_APP = 'app' STACK_TYPE_ALB = 'alb' STACK_TYPE_JOB = 'job' @@ -262,7 +265,6 @@ MODULE_IDENTITY_PROVIDER = 'identity-provider' MODULE_DIRECTORYSERVICE = 'directoryservice' MODULE_SHARED_STORAGE = 'shared-storage' -MODULE_ANALYTICS = 'analytics' MODULE_SCHEDULER = 'scheduler' MODULE_CLUSTER_MANAGER = 'cluster-manager' MODULE_VIRTUAL_DESKTOP_CONTROLLER = 'virtual-desktop-controller' @@ -275,7 +277,6 @@ MODULE_IDENTITY_PROVIDER, MODULE_DIRECTORYSERVICE, MODULE_SHARED_STORAGE, - MODULE_ANALYTICS, MODULE_SCHEDULER, MODULE_CLUSTER_MANAGER, MODULE_VIRTUAL_DESKTOP_CONTROLLER, @@ -302,14 +303,6 @@ GROUP_TYPE_EXTERNAL = "external" GROUP_TYPE_INTERNAL = "internal" -RES_ADMIN_GROUPS = [ - "cluster-manager-administrators-module-group", - "scheduler-administrators-module-group", - "vdc-administrators-module-group", - "managers-cluster-group", -] - -RES_USER_GROUPS = ["cluster-manager-users-module-group", "vdc-users-module-group", "scheduler-users-module-group"] ADMIN_ROLE = 'admin' USER_ROLE = 'user' @@ -328,7 +321,6 @@ SERVICE_ID_LEADER_ELECTION = 'leader-election' SERVICE_ID_DISTRIBUTED_LOCK = 'distributed-lock' SERVICE_ID_METRICS = 'metrics-service' -SERVICE_ID_ANALYTICS = 'analytics-service' # idea service account IDEA_SERVICE_ACCOUNT = 'ideaserviceaccount' @@ -414,3 +406,30 @@ # SSO SSO_IDP_PROVIDER_OIDC = 'OIDC' SSO_IDP_PROVIDER_SAML = 'SAML' +SSO_SOURCE_PROVIDER_NAME_REGEX = "^(?!^Cognito$)[\\w._:/-]{1,128}$" +SSO_SOURCE_PROVIDER_NAME_ERROR_MESSAGE = "Only use word characters or single characters from the list [\".\", \"_\", \":\", \"/\", \"-\"] for the SSO source provider name. " +\ + "Must be between 1 and 128 characters long."
+\ + "SourceProviderName may not be Cognito" + +# API Validation Regex and ErrorMessages +FILE_SYSTEM_NAME_REGEX = "^[a-z0-9_]{3,18}$" +FILE_SYSTEM_NAME_ERROR_MESSAGE = "Only use lowercase letters, numbers, and underscores (_) for the file system name. " +\ + "Must be between 3 and 18 characters long." + +MOUNT_DIRECTORY_REGEX = "^/[a-z0-9-]{3,18}$" +MOUNT_DIRECTORY_ERROR_MESSAGE = "Only use lowercase letters, numbers, " +\ + "and hyphens (-) for the mount directory. Must be between 3 and 18 characters long." + +MOUNT_DRIVE_REGEX = "^[ABD-Z]{1}$" +MOUNT_DRIVE_ERROR_MESSAGE = "Only use a single uppercase letter for the mount drive" +ONTAP_STORAGE_CAPACITY_RANGE = (1024, 196608) + +PROJECT_ID_REGEX = "^[a-z0-9-]{3,18}$" +PROJECT_ID_ERROR_MESSAGE = "Only use lowercase letters, numbers, and hyphens (-) for the project id. " +\ + "Must be between 3 and 18 characters long." + +SOFTWARE_STACK_NAME_REGEX = SESSION_NAME_REGEX = "^.{3,50}$" +SOFTWARE_STACK_NAME_ERROR_MESSAGE = SESSION_NAME_ERROR_MESSAGE = "Use any characters " +\ + "to form a name between 3 and 50 characters long, inclusive, for the software stack name."
+ +INVALID_RANGE_ERROR_MESSAGE = "Input out of permitted range" diff --git a/source/idea/idea-data-model/src/ideadatamodel/errorcodes.py b/source/idea/idea-data-model/src/ideadatamodel/errorcodes.py index 2eff77c..7574c8d 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/errorcodes.py +++ b/source/idea/idea-data-model/src/ideadatamodel/errorcodes.py @@ -92,9 +92,13 @@ JOB_SUBMISSION_FAILED = 'JOB_SUBMISSION_FAILED' JOB_NOT_FOUND = 'JOB_NOT_FOUND' CONNECTION_ERROR = 'CONNECTION_ERROR' +GID_NOT_FOUND = 'GID_NOT_FOUND' +UID_AND_GID_NOT_FOUND = 'UID_AND_GID_NOT_FOUND' +USER_NOT_AVAILABLE = 'USER_NOT_AVAILABLE' SERVER_IS_SHUTTING_DOWN = 'SERVER_IS_SHUTTING_DOWN' +AUTH_MULTIPLE_USERS_FOUND = 'MULTIPLE_USERS_FOUND' AUTH_USER_IS_DISABLED = 'USER_IS_DISABLED' AUTH_GROUP_IS_DISABLED = 'GROUP_IS_DISABLED' AUTH_USER_NOT_FOUND = 'AUTH_USER_NOT_FOUND' @@ -140,7 +144,15 @@ # ------- FILE SYSTEM ERROR CODES START ------- NO_SHARED_FILESYSTEM_FOUND = 'NO_SHARED_FILESYSTEM_FOUND' FILESYSTEM_NOT_FOUND = 'FILESYSTEM_NOT_FOUND' +FILESYSTEM_ALREADY_ONBOARDED = "FILESYSTEM_ALREADY_ONBOARDED" +FILESYSTEM_NOT_IN_VPC = "FILESYSTEM_NOT_IN_VPC" # ------- FILE SYSTEM ERROR CODES END ------- # Integration Tests INTEGRATION_TEST_FAILED = 'INTEGRATION_TEST_FAILED' + +# ------- APPLY SNAPSHOT ERROR CODES START ------- +TABLE_IMPORT_FAILED = "TABLE_IMPORT_FAILED" +TABLE_MERGE_FAILED = "TABLE_MERGE_FAILED" +TABLE_ROLLBACK_FAILED = "TABLE_ROLLBACK_FAILED" +# ------- APPLY SNAPSHOT ERROR CODES END ------- diff --git a/source/idea/idea-data-model/src/ideadatamodel/exceptions/exception_utils.py b/source/idea/idea-data-model/src/ideadatamodel/exceptions/exception_utils.py index 24f8786..0fcb767 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/exceptions/exception_utils.py +++ b/source/idea/idea-data-model/src/ideadatamodel/exceptions/exception_utils.py @@ -20,7 +20,10 @@ 'general_exception', 'app_not_found', 'invalid_session', - 'cluster_config_error' + 'cluster_config_error', + 
'table_import_failed', + 'table_merge_failed', + 'table_rollback_failed' ) from ideadatamodel.exceptions import SocaException @@ -133,3 +136,21 @@ def app_not_found(message: str) -> SocaException: error_code=errorcodes.APP_NOT_FOUND, message=message ) + +def table_import_failed(message: str) -> SocaException: + return SocaException( + error_code=errorcodes.TABLE_IMPORT_FAILED, + message=message + ) + +def table_merge_failed(message: str) -> SocaException: + return SocaException( + error_code=errorcodes.TABLE_MERGE_FAILED, + message=message + ) + +def table_rollback_failed(message: str) -> SocaException: + return SocaException( + error_code=errorcodes.TABLE_ROLLBACK_FAILED, + message=message + ) diff --git a/source/idea/idea-data-model/src/ideadatamodel/projects/projects_api.py b/source/idea/idea-data-model/src/ideadatamodel/projects/projects_api.py index 6e9916b..e4ea32c 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/projects/projects_api.py +++ b/source/idea/idea-data-model/src/ideadatamodel/projects/projects_api.py @@ -181,4 +181,11 @@ class ListFileSystemsForProjectResult(SocaListingPayload): is_listing=True, is_public=False ), + IdeaOpenAPISpecEntry( + namespace='Projects.DeleteProject', + request=DeleteProjectRequest, + result=DeleteProjectResult, + is_listing=False, + is_public=False + ) ] diff --git a/source/idea/idea-data-model/src/ideadatamodel/projects/projects_model.py b/source/idea/idea-data-model/src/ideadatamodel/projects/projects_model.py index 493ffec..50ddafa 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/projects/projects_model.py +++ b/source/idea/idea-data-model/src/ideadatamodel/projects/projects_model.py @@ -26,13 +26,39 @@ class Project(SocaBaseModel): title: Optional[str] description: Optional[str] enabled: Optional[bool] - ldap_groups: Optional[List[str]] + ldap_groups: Optional[List[str]] = [] + users: Optional[List[str]] enable_budgets: Optional[bool] budget: Optional[AwsProjectBudget] tags: 
Optional[List[SocaKeyValue]] created_on: Optional[datetime] updated_on: Optional[datetime] + def __eq__(self, other): + eq = True + eq = eq and self.name == other.name + eq = eq and self.title == other.title + eq = eq and self.description == other.description + eq = eq and self.enable_budgets == other.enable_budgets + eq = eq and self.budget == other.budget + + self_ldap_groups = self.ldap_groups if self.ldap_groups else [] + other_ldap_groups = other.ldap_groups if other.ldap_groups else [] + eq = eq and len(self_ldap_groups) == len(other_ldap_groups) + eq = eq and all(ldap_group in self_ldap_groups for ldap_group in other_ldap_groups) + + self_users = self.users if self.users else [] + other_users = other.users if other.users else [] + eq = eq and len(self_users) == len(other_users) + eq = eq and all(user in self_users for user in other_users) + + self_tags = self.tags if self.tags else [] + other_tags = other.tags if other.tags else [] + eq = eq and len(self_tags) == len(other_tags) + eq = eq and all(tag in self_tags for tag in other_tags) + + return eq + def is_enabled(self) -> bool: return ModelUtils.get_as_bool(self.enabled, False) diff --git a/source/idea/idea-data-model/src/ideadatamodel/shared_filesystem/shared_filesystem_api.py b/source/idea/idea-data-model/src/ideadatamodel/shared_filesystem/shared_filesystem_api.py index 1c3e4ab..f5d2939 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/shared_filesystem/shared_filesystem_api.py +++ b/source/idea/idea-data-model/src/ideadatamodel/shared_filesystem/shared_filesystem_api.py @@ -16,6 +16,7 @@ 'OnboardEFSFileSystemRequest', 'OnboardONTAPFileSystemRequest', 'OnboardFileSystemResult', + 'OffboardFileSystemRequest', "OPEN_API_SPEC_ENTRIES_FILESYSTEM", ) @@ -120,6 +121,11 @@ class OnboardONTAPFileSystemRequest(CommonOnboardFileSystemRequest): class OnboardFileSystemResult(SocaPayload): pass + +class OffboardFileSystemRequest(SocaPayload): + filesystem_name: str + + OPEN_API_SPEC_ENTRIES_FILESYSTEM = [ 
IdeaOpenAPISpecEntry( namespace="FileSystem.AddFileSystemToProject", diff --git a/source/idea/idea-data-model/src/ideadatamodel/snapshots/__init__.py b/source/idea/idea-data-model/src/ideadatamodel/snapshots/__init__.py index c216f53..d2f2320 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/snapshots/__init__.py +++ b/source/idea/idea-data-model/src/ideadatamodel/snapshots/__init__.py @@ -9,5 +9,5 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. -from .snapshot_model import Snapshot, SnapshotStatus +from .snapshot_model import Snapshot, SnapshotStatus, ApplySnapshot, ApplySnapshotStatus, TableKeys, TableName, RESVersion from .snapshots_api import * diff --git a/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshot_model.py b/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshot_model.py index 6995fcc..849a9de 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshot_model.py +++ b/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshot_model.py @@ -10,8 +10,13 @@ # and limitations under the License. 
__all__ = ( + 'ApplySnapshot', + 'ApplySnapshotStatus', 'Snapshot', - 'SnapshotStatus' + 'SnapshotStatus', + 'TableKeys', + 'TableName', + 'RESVersion' ) from ideadatamodel import SocaBaseModel @@ -25,8 +30,46 @@ class SnapshotStatus(str, Enum): COMPLETED = 'COMPLETED' FAILED = 'FAILED' + class Snapshot(SocaBaseModel): s3_bucket_name: Optional[str] snapshot_path: Optional[str] status: Optional[SnapshotStatus] created_on: Optional[datetime] + failure_reason: Optional[str] + + +class ApplySnapshotStatus(str, Enum): + IN_PROGRESS = 'IN_PROGRESS' + COMPLETED = 'COMPLETED' + FAILED = 'FAILED' + ROLLBACK_IN_PROGRESS = "ROLLBACK_IN_PROGRESS" + ROLLBACK_COMPLETE = "ROLLBACK_COMPLETE" + ROLLBACK_FAILED = "ROLLBACK_FAILED" + + +class ApplySnapshot(SocaBaseModel): + apply_snapshot_identifier: Optional[str] + s3_bucket_name: Optional[str] + snapshot_path: Optional[str] + status: Optional[ApplySnapshotStatus] + created_on: Optional[datetime] + failure_reason: Optional[str] + + +class TableKeys(SocaBaseModel): + partition_key: str + sort_key: Optional[str] + + +class TableName(str, Enum): + CLUSTER_SETTINGS_TABLE_NAME = "cluster-settings" + USERS_TABLE_NAME = "accounts.users" + PROJECTS_TABLE_NAME = "projects" + PERMISSION_PROFILES_TABLE_NAME = "vdc.controller.permission-profiles" + SOFTWARE_STACKS_TABLE_NAME = "vdc.controller.software-stacks" + + +class RESVersion(str, Enum): + v_2023_11 = "2023.11" + v_2024_01 = "2024.01" diff --git a/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshots_api.py b/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshots_api.py index 73c5eb3..6f5054f 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshots_api.py +++ b/source/idea/idea-data-model/src/ideadatamodel/snapshots/snapshots_api.py @@ -13,11 +13,15 @@ 'CreateSnapshotResult', 'ListSnapshotsRequest', 'ListSnapshotsResult', + 'ApplySnapshotRequest', + 'ApplySnapshotResult', + 'ListApplySnapshotRecordsRequest', + 'ListApplySnapshotRecordsResult',
'OPEN_API_SPEC_ENTRIES_SNAPSHOTS' ) from ideadatamodel.api import SocaPayload, IdeaOpenAPISpecEntry, SocaListingPayload -from ideadatamodel.snapshots.snapshot_model import Snapshot +from ideadatamodel.snapshots.snapshot_model import Snapshot, ApplySnapshot from typing import List, Optional @@ -28,6 +32,7 @@ class CreateSnapshotRequest(SocaPayload): class CreateSnapshotResult(SocaPayload): snapshot: Optional[Snapshot] + message: Optional[str] class ListSnapshotsRequest(SocaListingPayload): pass @@ -35,6 +40,23 @@ class ListSnapshotsRequest(SocaListingPayload): class ListSnapshotsResult(SocaListingPayload): listing: Optional[List[Snapshot]] + + +class ApplySnapshotRequest(SocaPayload): + snapshot: Optional[Snapshot] + + +class ApplySnapshotResult(SocaPayload): + snapshot: Optional[Snapshot] + message: Optional[str] + + +class ListApplySnapshotRecordsRequest(SocaListingPayload): + pass + + +class ListApplySnapshotRecordsResult(SocaListingPayload): + listing: Optional[List[ApplySnapshot]] OPEN_API_SPEC_ENTRIES_SNAPSHOTS = [ @@ -52,4 +74,11 @@ class ListSnapshotsResult(SocaListingPayload): is_listing=True, is_public=False ), + IdeaOpenAPISpecEntry( + namespace='Snapshots.ApplySnapshot', + request=ApplySnapshotRequest, + result=ApplySnapshotResult, + is_listing=False, + is_public=False + ), ] diff --git a/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_api.py b/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_api.py index 0af0713..59b94d2 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_api.py +++ b/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_api.py @@ -34,6 +34,8 @@ 'ListSessionsRequest', 'CreateSoftwareStackRequest', 'CreateSoftwareStackResponse', + 'DeleteSoftwareStackRequest', + 'DeleteSoftwareStackResponse', 'UpdateSoftwareStackRequest', 'UpdateSoftwareStackResponse', 'GetSoftwareStackInfoRequest', @@ -58,10 +60,6 @@ 
'ListAllowedInstanceTypesResponse', 'ListAllowedInstanceTypesForSessionRequest', 'ListAllowedInstanceTypesForSessionResponse', - 'ReIndexUserSessionsRequest', - 'ReIndexUserSessionsResponse', - 'ReIndexSoftwareStacksRequest', - 'ReIndexSoftwareStacksResponse', 'ListPermissionProfilesRequest', 'ListPermissionProfilesResponse', 'GetPermissionProfileRequest', @@ -70,6 +68,8 @@ 'CreatePermissionProfileRequest', 'UpdatePermissionProfileRequest', 'UpdatePermissionProfileResponse', + 'DeletePermissionProfileRequest', + 'DeletePermissionProfileResponse', 'GetBasePermissionsRequest', 'GetBasePermissionsResponse', 'UpdateSessionPermissionRequest', @@ -158,6 +158,7 @@ class GetSessionInfoResponse(SocaPayload): # VirtualDesktopAdmin.GetSoftwareStackInfo - Request class GetSoftwareStackInfoRequest(SocaPayload): stack_id: Optional[str] + base_os: Optional[str] # VirtualDesktopAdmin.GetSoftwareStackInfo - Response @@ -250,6 +251,16 @@ class CreateSoftwareStackResponse(SocaPayload): software_stack: Optional[VirtualDesktopSoftwareStack] +# VirtualDesktopAdmin.DeleteSoftwareStack - Request +class DeleteSoftwareStackRequest(SocaPayload): + software_stack: Optional[VirtualDesktopSoftwareStack] + + +# VirtualDesktopAdmin.DeleteSoftwareStack - Response +class DeleteSoftwareStackResponse(SocaPayload): + pass + + # VirtualDesktopAdmin.UpdateSoftwareStack - Request class UpdateSoftwareStackRequest(SocaPayload): software_stack: Optional[VirtualDesktopSoftwareStack] @@ -350,25 +361,6 @@ class ListAllowedInstanceTypesForSessionResponse(SocaListingPayload): # VirtualDesktopAdmin.CreateSharedSession # VirtualDesktopAdmin.DeleteSharedSession -# VirtualDesktopAdmin.ReIndexUserSessions - Request -class ReIndexUserSessionsRequest(SocaPayload): - pass - - -# VirtualDesktopAdmin.ReIndexUserSessions - Response -class ReIndexUserSessionsResponse(SocaPayload): - pass - - -# VirtualDesktopAdmin.ReIndexSoftwareStacks - Request -class ReIndexSoftwareStacksRequest(SocaPayload): - pass - - -# 
VirtualDesktopAdmin.ReIndexSoftwareStacks - Response -class ReIndexSoftwareStacksResponse(SocaPayload): - pass - # VirtualDesktopUtils.ListPermissionProfiles - Request class ListPermissionProfilesRequest(SocaListingPayload): @@ -410,6 +402,16 @@ class CreatePermissionProfileResponse(SocaPayload): profile: Optional[VirtualDesktopPermissionProfile] +# VirtualDesktopAdmin.DeletePermissionProfile - Request +class DeletePermissionProfileRequest(SocaPayload): + profile_id: Optional[str] + + +# VirtualDesktopAdmin.DeletePermissionProfile - Response +class DeletePermissionProfileResponse(SocaPayload): + pass + + # VirtualDesktopUtils.GetBasePermissions - Request class GetBasePermissionsRequest(SocaPayload): pass @@ -636,6 +638,13 @@ class ListPermissionsResponse(SocaListingPayload): is_listing=False, is_public=False ), + IdeaOpenAPISpecEntry( + namespace='VirtualDesktopAdmin.DeleteSoftwareStack', + request=DeleteSoftwareStackRequest, + result=DeleteSoftwareStackResponse, + is_listing=False, + is_public=False + ), IdeaOpenAPISpecEntry( namespace='VirtualDesktopAdmin.UpdateSoftwareStack', request=UpdateSoftwareStackRequest, @@ -678,6 +687,13 @@ class ListPermissionsResponse(SocaListingPayload): is_listing=False, is_public=False ), + IdeaOpenAPISpecEntry( + namespace='VirtualDesktopAdmin.DeletePermissionProfile', + request=DeletePermissionProfileRequest, + result=DeletePermissionProfileResponse, + is_listing=False, + is_public=False + ), IdeaOpenAPISpecEntry( namespace='VirtualDesktopAdmin.ListSessionPermissions', request=ListPermissionsRequest, @@ -699,20 +715,6 @@ class ListPermissionsResponse(SocaListingPayload): is_listing=False, is_public=False ), - IdeaOpenAPISpecEntry( - namespace='VirtualDesktopAdmin.ReIndexUserSessions', - request=ReIndexUserSessionsRequest, - result=ReIndexUserSessionsResponse, - is_listing=False, - is_public=False - ), - IdeaOpenAPISpecEntry( - namespace='VirtualDesktopAdmin.ReIndexSoftwareStacks', - request=ReIndexSoftwareStacksRequest, - 
result=ReIndexSoftwareStacksResponse, - is_listing=False, - is_public=False - ), # VirtualDesktopAdmin.* ENDS # # VirtualDesktopUtils.* STARTS # IdeaOpenAPISpecEntry( diff --git a/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_model.py b/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_model.py index 48f5ae9..7c3cc1d 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_model.py +++ b/source/idea/idea-data-model/src/ideadatamodel/virtual_desktop/virtual_desktop_model.py @@ -135,7 +135,24 @@ class VirtualDesktopSoftwareStack(SocaBaseModel): min_ram: Optional[SocaMemory] architecture: Optional[VirtualDesktopArchitecture] gpu: Optional[VirtualDesktopGPU] - projects: Optional[List[Project]] + projects: Optional[List[Project]] = [] + + def __eq__(self, other): + eq = True + eq = eq and self.base_os == other.base_os + eq = eq and self.name == other.name + eq = eq and self.description == other.description + eq = eq and self.ami_id == other.ami_id + eq = eq and self.min_storage == other.min_storage + eq = eq and self.min_ram == other.min_ram + eq = eq and self.gpu == other.gpu + + self_project_ids = [project.project_id for project in self.projects] if self.projects else [] + other_project_ids = [project.project_id for project in other.projects] if other.projects else [] + eq = eq and len(self_project_ids) == len(other_project_ids) + eq = eq and all(project_id in self_project_ids for project_id in other_project_ids) + + return eq class VirtualDesktopServer(SocaBaseModel): @@ -201,6 +218,19 @@ class VirtualDesktopPermissionProfile(SocaBaseModel): created_on: Optional[datetime] updated_on: Optional[datetime] + def __eq__(self, other): + eq = True + eq = eq and self.profile_id == other.profile_id + eq = eq and self.title == other.title + eq = eq and self.description == other.description + + self_permissions = self.permissions if self.permissions else [] + other_permissions = 
other.permissions if other.permissions else [] + eq = eq and len(self_permissions) == len(other_permissions) + eq = eq and all(permission in self_permissions for permission in other_permissions) + + return eq + def get_permission(self, permission_key: str) -> Optional[VirtualDesktopPermission]: if self.permissions is None: return None @@ -251,6 +281,7 @@ class VirtualDesktopSession(SocaBaseModel): hibernation_enabled: Optional[bool] is_launched_by_admin: Optional[bool] locked: Optional[bool] + tags: Optional[list[dict]] # Transient field, to be used for API responses only. failure_reason: Optional[str] diff --git a/source/idea/idea-data-model/src/ideadatamodel_meta/__init__.py b/source/idea/idea-data-model/src/ideadatamodel_meta/__init__.py index 604e9a8..c6d7c43 100644 --- a/source/idea/idea-data-model/src/ideadatamodel_meta/__init__.py +++ b/source/idea/idea-data-model/src/ideadatamodel_meta/__init__.py @@ -10,4 +10,4 @@ # and limitations under the License. __name__ = 'idea-data-model' -__version__ = '2023.11' +__version__ = '2024.01' diff --git a/source/idea/idea-sdk/src/ideasdk/analytics/analytics_service.py b/source/idea/idea-sdk/src/ideasdk/analytics/analytics_service.py deleted file mode 100644 index 5701732..0000000 --- a/source/idea/idea-sdk/src/ideasdk/analytics/analytics_service.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. 
-import threading -from enum import Enum -from typing import Optional, Dict, List - -from ideadatamodel import SocaBaseModel -from ideadatamodel.constants import SERVICE_ID_ANALYTICS -from ideasdk.aws.opensearch.aws_opensearch_client import AwsOpenSearchClient -from ideasdk.protocols import SocaContextProtocol, AnalyticsServiceProtocol -from ideasdk.service import SocaService -from ideasdk.utils import Utils - - -class EntryAction(str, Enum): - UPDATE_ENTRY = 'UPDATE_ENTRY' - DELETE_ENTRY = 'DELETE_ENTRY' - CREATE_ENTRY = 'CREATE_ENTRY' - - -class EntryContent(SocaBaseModel): - index_id: Optional[str] - entry_record: Optional[dict] - - -class AnalyticsEntry(SocaBaseModel): - entry_id: str - entry_action: EntryAction - entry_content: EntryContent - - -class AnalyticsService(SocaService, AnalyticsServiceProtocol): - - def __init__(self, context: SocaContextProtocol): - super().__init__(context=context) - self.MAX_BUFFER_SIZE = 20 - self.MAX_WAIT_TIME_MS = 1000 - self.context = context - self._logger = context.logger('analytics-service') - self._buffer_lock = threading.RLock() - self._buffer_size_limit_reached_condition = threading.Condition() - self._buffer: Optional[List[AnalyticsEntry]] = [] - self._buffer_processing_thread: Optional[threading.Thread] = None - self._exit = threading.Event() - self.os_client = AwsOpenSearchClient(self.context) - self._initialize() - - def service_id(self) -> str: - return SERVICE_ID_ANALYTICS - - def start(self): - pass - - def _initialize(self): - self._logger.info('Starting Analytics Service ...') - # until I get positive status from Open Search -> wait here. 
- self._buffer_processing_thread = threading.Thread( - name='buffer-processing-thread', - target=self._process_buffer - ) - self._buffer_processing_thread.start() - - def _process_buffer(self): - while not self._exit.is_set(): - self._buffer_size_limit_reached_condition.acquire() - self._buffer_size_limit_reached_condition.wait(timeout=self.MAX_WAIT_TIME_MS / 1000) - self._buffer_size_limit_reached_condition.release() - self._post_entries_to_kinesis() - - def _post_entries_to_kinesis(self): - with self._buffer_lock: - records = [] - if Utils.is_empty(self._buffer): - return - - for entry in self._buffer: - records.append({ - 'Data': Utils.to_bytes(Utils.to_json({ - 'index_id': entry.entry_content.index_id, - 'entry': entry.entry_content.entry_record, - 'document_id': entry.entry_id, - 'action': entry.entry_action, - 'timestamp': Utils.current_time_ms() - })), - 'PartitionKey': entry.entry_id - }) - - stream_name = self.context.config().get_string('analytics.kinesis.stream_name', required=True) - self._logger.info(f'posting {len(records)} record(s) to analytics stream...') - response = self.context.aws().kinesis().put_records( - Records=records, - StreamName=stream_name - ) - # TODO: handle failure/success - self._buffer = [] - - def _enforce_buffer_processing(self): - try: - self._buffer_size_limit_reached_condition.acquire() - self._buffer_size_limit_reached_condition.notify_all() - finally: - self._buffer_size_limit_reached_condition.release() - - def post_entry(self, document: AnalyticsEntry): - with self._buffer_lock: - self._buffer.append(document) - self._logger.debug(f'Added entry to buffer ... 
Buffer Size: {len(self._buffer)}') - - if len(self._buffer) >= self.MAX_BUFFER_SIZE: - self._enforce_buffer_processing() - - def initialize_template(self, template_name: str, template_body: Dict) -> int: - new_version = Utils.get_value_as_int('version', template_body, default=1) - self._logger.info(f'new template version for {template_name} is {new_version}') - current_template = self.os_client.get_template(template_name) - if Utils.is_not_empty(current_template): - self._logger.info(f'current template is not empty') - self._logger.debug(current_template) - current_version = Utils.get_value_as_int('version', Utils.get_value_as_dict(template_name, current_template, {}), default=-1) - self._logger.info(f'current version = {current_version}') - if int(new_version) <= int(current_version): - self._logger.info('Trying to add the same or older version of the template. Ignoring request.') - return new_version - - response = self.os_client.put_template(template_name, template_body) - if not Utils.get_value_as_bool('acknowledged', response, default=False): - self._logger.info(f'There is some error. Need to check later. response: {response}') - self._logger.error(f'new version being returned is {new_version}') - return new_version - - def stop(self): - self._logger.error('Stopping Analytics Service ...') - self._exit.set() - self._enforce_buffer_processing() - if self._buffer_processing_thread is not None and self._buffer_processing_thread.is_alive(): - self._buffer_processing_thread.join() diff --git a/source/idea/idea-sdk/src/ideasdk/api/api_invocation_context.py b/source/idea/idea-sdk/src/ideasdk/api/api_invocation_context.py index 8eab93c..2767898 100644 --- a/source/idea/idea-sdk/src/ideasdk/api/api_invocation_context.py +++ b/source/idea/idea-sdk/src/ideasdk/api/api_invocation_context.py @@ -10,11 +10,12 @@ # and limitations under the License. 
from ideasdk.protocols import SocaContextProtocol, ApiInvocationContextProtocol, SocaContextProtocolType -from ideasdk.auth import TokenService, ApiAuthorization, ApiAuthorizationType +from ideasdk.auth import TokenService, ApiAuthorizationServiceBase +from ideadatamodel.api.api_model import ApiAuthorization, ApiAuthorizationType from ideasdk.utils import Utils, GroupNameHelper from ideadatamodel import constants, get_payload_as, SocaEnvelope, SocaPayload, exceptions, errorcodes -from typing import Optional, Dict, Type, TypeVar, Union, List +from typing import Optional, Dict, Mapping, Type, TypeVar, Union, List import logging import re @@ -36,18 +37,22 @@ class ApiInvocationContext(ApiInvocationContextProtocol): def __init__(self, context: SocaContextProtocol, request: Union[Dict, SocaEnvelope], + http_headers: Optional[Mapping], invocation_source: str, group_name_helper: GroupNameHelper, logger: logging.Logger, token: Optional[Dict] = None, - token_service: Optional[TokenService] = None): + token_service: Optional[TokenService] = None, + api_authorization_service: Optional[ApiAuthorizationServiceBase] = None): self._context = context self._request = request + self._http_headers = http_headers self._invocation_source = invocation_source self._group_name_helper = group_name_helper self._logger = logger self._token = token self._token_service = token_service + self._api_authorization_service = api_authorization_service self._start_time = Utils.current_time_ms() self._total_time: Optional[int] = None @@ -97,9 +102,10 @@ def is_scope_authorized(self, scope: str) -> bool: access_token = self.access_token if Utils.is_empty(access_token): return False - if self._token_service is None: + if not self._api_authorization_service or not self._token_service: return False - return self._token_service.is_scope_authorized(access_token=access_token, scope=scope) + decoded_token = self._token_service.decode_token(token=access_token) + return 
self._api_authorization_service.is_scope_authorized(decoded_token=decoded_token, scope=scope) def is_unix_domain_socket_invocation(self) -> bool: """ @@ -110,49 +116,33 @@ def is_unix_domain_socket_invocation(self) -> bool: """ return self._invocation_source == 'unix-socket' - def is_administrator(self) -> bool: + def is_administrator(self, authorization: ApiAuthorization = None) -> bool: """ - check if user has "cluster administrator" access - cluster administrator == sudo user. - - a "module administrator" does not imply sudo access, and is equivalent to having a "cluster manager" access scoped to the current module. + check if user has "administrator" access """ - authorization = self.get_authorization() + if not authorization: + authorization = self.get_authorization() return authorization.type == ApiAuthorizationType.ADMINISTRATOR - def is_manager(self) -> bool: - """ - check if user has "cluster manager" access or "module administrator" access - """ - authorization = self.get_authorization() - - if authorization.type == ApiAuthorizationType.MANAGER: - return True - - if authorization.type == ApiAuthorizationType.USER: - module_administrators_group = self._group_name_helper.get_module_administrators_group(module_id=self._context.module_id()) - return module_administrators_group in authorization.groups - - return False - - def is_authenticated_user(self) -> bool: + def is_authenticated_user(self, authorization: ApiAuthorization = None) -> bool: """ allow any request as long as the token is issued to a valid user verify the token, but don't check for any scope access """ - authorization = self.get_authorization() + if not authorization: + authorization = self.get_authorization() return authorization.type in ( ApiAuthorizationType.USER, - ApiAuthorizationType.ADMINISTRATOR, - ApiAuthorizationType.MANAGER + ApiAuthorizationType.ADMINISTRATOR ) - def is_authenticated_app(self) -> bool: + def is_authenticated_app(self, authorization: ApiAuthorization = None) -> 
bool: """ allow any request as long as the token is issued to a valid app verify the token, but don't check for any scope access """ - authorization = self.get_authorization() + if not authorization: + authorization = self.get_authorization() return authorization.type == ApiAuthorizationType.APP def is_authenticated(self) -> bool: @@ -160,32 +150,26 @@ def is_authenticated(self) -> bool: allow any request as long as the token is valid verify the token, but don't check for any groups or scope access """ - self.get_authorization() + authorization = self.get_authorization() return True - def is_authorized_user(self) -> bool: + def is_authorized_user(self, authorization: ApiAuthorization = None) -> bool: """ - lock down API access such that only if a user is added to the -users-module-group, then grant access - administrators and managers get access implicitly + check if a user is authorized + admins are implicitly authorized """ - # administrator with sudo access - if self.is_administrator(): - return True - # cluster manager or module administrator - if self.is_manager(): + if not authorization: + authorization = self.get_authorization() + + # administrator with sudo access + if self.is_administrator(authorization=authorization): return True + return authorization.type == ApiAuthorizationType.USER - # regular user with module access - authorization = self.get_authorization() - if authorization.type == ApiAuthorizationType.USER: - module_users_group = self._group_name_helper.get_module_users_group(module_id=self._context.module_id()) - return module_users_group in authorization.groups - - return False - - def is_authorized_app(self) -> bool: - authorization = self.get_authorization() + def is_authorized_app(self, authorization: ApiAuthorization = None) -> bool: + if not authorization: + authorization = self.get_authorization() if authorization.type != ApiAuthorizationType.APP: return False if authorization.scopes is None or len(authorization.scopes) == 0: @@ -215,14 
+199,15 @@ def is_authorized(self, elevated_access: bool, scopes: Optional[List[str]] = Non :param List[str] scopes: the applicable scopes to be checked :return: """ + authorization = self.get_authorization() if elevated_access: - authorized_user = self.is_administrator() or self.is_manager() + authorized_user = self.is_administrator(authorization=authorization) else: - authorized_user = self.is_authorized_user() + authorized_user = self.is_authorized_user(authorization=authorization) authorized_scopes = False if Utils.is_not_empty(scopes): - authorized_app = self.is_authorized_app() + authorized_app = self.is_authorized_app(authorization=authorization) if authorized_app: all_scopes_authorized = True for scope in scopes: @@ -238,7 +223,7 @@ def get_username(self) -> Optional[str]: get username from the JWT access token should be used for all authenticated APIs to ensure username cannot be spoofed via any payload parameters """ - if self._token_service is None: + if not self._api_authorization_service: return None authorization = self.get_authorization() return authorization.username @@ -360,12 +345,26 @@ def get_authorization(self) -> ApiAuthorization: type=ApiAuthorizationType.ADMINISTRATOR, invocation_source=constants.API_INVOCATION_SOURCE_UNIX_SOCKET ) - else: - if not self.has_access_token(): - raise exceptions.unauthorized_access() + elif self.has_access_token(): decoded_token = self.get_decoded_token() - result = self._token_service.get_authorization(decoded_token) + result = self._api_authorization_service.get_authorization(decoded_token) result.invocation_source = constants.API_INVOCATION_SOURCE_HTTP + elif Utils.is_test_mode(): + username = Utils.get_value_as_string('X_RES_TEST_USERNAME', self._http_headers) + self._logger.warning(f'User {username} is accessing API without a valid token; ' + 'This is only possible because the environment variable RES_TEST_MODE is set to True. 
' + 'If this is not intended, remove the environment variable RES_TEST_MODE and restart the server.') + + if not username: + raise exceptions.unauthorized_access() + user = self._api_authorization_service.get_user_from_token_username(token_username=username) + result = ApiAuthorization( + username=username, + type=self._api_authorization_service.get_authorization_type(role=user.role), + invocation_source=constants.API_INVOCATION_SOURCE_HTTP + ) + else: + raise exceptions.unauthorized_access() self._authorization = result return self._authorization diff --git a/source/idea/idea-sdk/src/ideasdk/app/soca_app.py b/source/idea/idea-sdk/src/ideasdk/app/soca_app.py index 80823f6..d4129d3 100644 --- a/source/idea/idea-sdk/src/ideasdk/app/soca_app.py +++ b/source/idea/idea-sdk/src/ideasdk/app/soca_app.py @@ -208,10 +208,6 @@ def stop(self): self.app_stop() - analytics_service = self.context.service_registry().get_service(constants.SERVICE_ID_ANALYTICS) - if analytics_service is not None: - analytics_service.stop() - leader_election_service = self.context.service_registry().get_service(constants.SERVICE_ID_LEADER_ELECTION) if leader_election_service is not None: leader_election_service.stop() diff --git a/source/idea/idea-sdk/src/ideasdk/auth/__init__.py b/source/idea/idea-sdk/src/ideasdk/auth/__init__.py index acaa767..6607779 100644 --- a/source/idea/idea-sdk/src/ideasdk/auth/__init__.py +++ b/source/idea/idea-sdk/src/ideasdk/auth/__init__.py @@ -10,3 +10,4 @@ # and limitations under the License. from ideasdk.auth.token_service import * +from ideasdk.auth.api_authorization_service_base import * diff --git a/source/idea/idea-sdk/src/ideasdk/auth/api_authorization_service_base.py b/source/idea/idea-sdk/src/ideasdk/auth/api_authorization_service_base.py new file mode 100644 index 0000000..d1f9b0e --- /dev/null +++ b/source/idea/idea-sdk/src/ideasdk/auth/api_authorization_service_base.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideasdk.protocols import ApiAuthorizationServiceProtocol +from ideadatamodel.api.api_model import ApiAuthorization, ApiAuthorizationType +from ideadatamodel.auth import User +from ideadatamodel import exceptions, errorcodes +from typing import Optional, Dict +from abc import abstractmethod + +class ApiAuthorizationServiceBase(ApiAuthorizationServiceProtocol): + + @abstractmethod + def get_user_from_token_username(self, token_username: str) -> Optional[User]: + ... + + def get_authorization_type(self, role: Optional[str]) -> ApiAuthorizationType: + authorization_type = None + if role: + if role == ApiAuthorizationType.ADMINISTRATOR: + authorization_type = ApiAuthorizationType.ADMINISTRATOR + else: + authorization_type = ApiAuthorizationType.USER + if not authorization_type: + authorization_type = ApiAuthorizationType.USER + return authorization_type + + def get_authorization(self, decoded_token: Optional[Dict]) -> ApiAuthorization: + username = decoded_token.get('username') + token_scope = decoded_token.get('scope') + client_id = decoded_token.get('client_id') + authorization_type, role, db_username, scopes = None, None, None, None + if not username: + authorization_type = ApiAuthorizationType.APP + if token_scope: + scopes = token_scope.split(' ') + else: + user = self.get_user_from_token_username(username) + username = user.username + if not user.enabled: + raise exceptions.unauthorized_access(errorcodes.AUTH_USER_IS_DISABLED) + authorization_type = 
self.get_authorization_type(user.role) + + return ApiAuthorization( + type=authorization_type, + username=username, + scopes=scopes, + client_id=client_id + ) + + def is_scope_authorized(self, decoded_token: Optional[Dict], scope: str) -> bool: + if not decoded_token: + return False + if not scope: + return False + authorization = self.get_authorization(decoded_token) + if authorization.type != ApiAuthorizationType.APP: + return False + return bool(authorization.scopes and scope in authorization.scopes) + + def get_username(self, decoded_token: Optional[Dict]) -> Optional[str]: + if not decoded_token: + return None + token_username = decoded_token.get('username') + user = self.get_user_from_token_username(token_username) + return user.username diff --git a/source/idea/idea-sdk/src/ideasdk/auth/token_service.py b/source/idea/idea-sdk/src/ideasdk/auth/token_service.py index 5aed250..31fec86 100644 --- a/source/idea/idea-sdk/src/ideasdk/auth/token_service.py +++ b/source/idea/idea-sdk/src/ideasdk/auth/token_service.py @@ -11,14 +11,14 @@ __all__ = ( 'TokenServiceOptions', - 'ApiAuthorization', - 'ApiAuthorizationType', 'TokenService' ) from ideasdk.protocols import SocaContextProtocol, TokenServiceProtocol -from ideadatamodel import SocaBaseModel, AuthResult, exceptions, errorcodes +from ideadatamodel import SocaBaseModel, AuthResult, exceptions, errorcodes, constants from ideasdk.utils import Utils +from ideadatamodel.auth import User +from ideadatamodel.auth import GetUserByEmailRequest, GetUserRequest from typing import Optional, Dict, List import jwt @@ -55,22 +55,6 @@ class TokenServiceOptions(SocaBaseModel): managers_group_name: Optional[str] -class ApiAuthorizationType(str, Enum): - ADMINISTRATOR = 'admin' - MANAGER = 'manager' - USER = 'user' - APP = 'app' - - -class ApiAuthorization(SocaBaseModel): - type: ApiAuthorizationType - username: Optional[str] # will not exist for APP authorizations - client_id: Optional[str] - groups: Optional[List[str]] # list of all groups user is part of - 
scopes: Optional[List[str]] # list of allowed oauth scopes - invocation_source: Optional[str] - - class TokenService(TokenServiceProtocol): def __init__(self, context: SocaContextProtocol, options: TokenServiceOptions): @@ -313,6 +297,29 @@ def decode_token(self, token: str, verify_exp: Optional[bool] = True) -> Dict: self._logger.error(f'Invalid Token: {e}') raise exceptions.unauthorized_access(f'Invalid Token - {e}') + + + def get_email_from_token_username(self, token_username: str) -> Optional[str]: + """ + For a user with + 1. email = a@email.com + 2. SSO enabled with identity-provider-name = idp + Cognito creates a user as idp_a@email.com and that name is passed as username in access token. + This method gets the identity-provider-name prefix from the database and removes that from the username + to get the user email back. + + :param token_username + :return email + """ + + if self._context.config().get_bool('identity-provider.cognito.sso_enabled', required=True): + identity_provider_name = self._context.config().get_string('identity-provider.cognito.sso_idp_provider_name', required=True) + identity_provider_prefix = identity_provider_name + "_" + email = None + if token_username.startswith(identity_provider_prefix): + email = token_username.replace(identity_provider_prefix, "", 1) + return email + def is_token_expired(self, token: str) -> bool: """ check if the token is expired @@ -325,77 +332,6 @@ def is_token_expired(self, token: str) -> bool: except jwt.ExpiredSignatureError: return True - def get_authorization(self, decoded_token: Optional[Dict]) -> ApiAuthorization: - username = Utils.get_value_as_string('username', decoded_token) - groups = Utils.get_value_as_list('cognito:groups', decoded_token, []) - token_scope = Utils.get_value_as_string('scope', decoded_token) - client_id = Utils.get_value_as_string('client_id', decoded_token) - scopes = None - if Utils.is_not_empty(token_scope): - scopes = token_scope.split(' ') - - authorization_type = None - if 
Utils.is_not_empty(self.options.administrators_group_name) and self.options.administrators_group_name in groups: - authorization_type = ApiAuthorizationType.ADMINISTRATOR - elif Utils.is_not_empty(self.options.administrators_group_name) and self.options.managers_group_name in groups: - authorization_type = ApiAuthorizationType.MANAGER - - if Utils.is_empty(username): - authorization_type = ApiAuthorizationType.APP - - if authorization_type is None: - authorization_type = ApiAuthorizationType.USER - - return ApiAuthorization( - type=authorization_type, - username=username, - scopes=scopes, - groups=groups, - client_id=client_id - ) - - def is_scope_authorized(self, access_token: str, scope: str, verify_exp=True) -> bool: - access_token = access_token - if Utils.is_empty(access_token): - return False - if Utils.is_empty(scope): - return False - - decoded_token = self.decode_token(access_token, verify_exp=verify_exp) - authorization = self.get_authorization(decoded_token) - if authorization.type != ApiAuthorizationType.APP: - return False - return Utils.is_not_empty(authorization.scopes) and scope in authorization.scopes - - def is_administrator(self, access_token: str, verify_exp=True) -> bool: - access_token = access_token - if Utils.is_empty(access_token): - return False - administrators_group_name = self.options.administrators_group_name - if Utils.is_empty(administrators_group_name): - return False - decoded_token = self.decode_token(access_token, verify_exp=verify_exp) - authorization = self.get_authorization(decoded_token) - return authorization.type == ApiAuthorizationType.ADMINISTRATOR - - def is_manager(self, access_token: str, verify_exp=True) -> bool: - access_token = access_token - if Utils.is_empty(access_token): - return False - managers_group_name = self.options.managers_group_name - if Utils.is_empty(managers_group_name): - return False - - decoded_token = self.decode_token(access_token, verify_exp=verify_exp) - authorization = 
self.get_authorization(decoded_token) - return authorization.type == ApiAuthorizationType.MANAGER - - def get_username(self, access_token: str, verify_exp=True) -> Optional[str]: - if Utils.is_empty(access_token): - return None - decoded_token = self.decode_token(access_token, verify_exp=verify_exp) - return Utils.get_value_as_string('username', decoded_token) - def get_access_token(self, force_renewal=True) -> Optional[str]: auth_result = None diff --git a/source/idea/idea-sdk/src/ideasdk/aws/aws_client_provider.py b/source/idea/idea-sdk/src/ideasdk/aws/aws_client_provider.py index d25f6b0..dc26612 100644 --- a/source/idea/idea-sdk/src/ideasdk/aws/aws_client_provider.py +++ b/source/idea/idea-sdk/src/ideasdk/aws/aws_client_provider.py @@ -36,7 +36,6 @@ AWS_CLIENT_CLOUDWATCH = 'cloudwatch' AWS_CLIENT_CLOUDWATCHLOGS = 'logs' AWS_CLIENT_ES = 'es' -AWS_CLIENT_OPENSEARCH = 'opensearch' AWS_CLIENT_DS = 'ds' AWS_CLIENT_FSX = 'fsx' AWS_CLIENT_EFS = 'efs' @@ -72,7 +71,6 @@ AWS_CLIENT_CLOUDWATCH, AWS_CLIENT_CLOUDWATCHLOGS, AWS_CLIENT_ES, - AWS_CLIENT_OPENSEARCH, AWS_CLIENT_EVENTS, AWS_CLIENT_DS, AWS_CLIENT_FSX, @@ -323,9 +321,6 @@ def ds(self): def es(self): return self.get_client(service_name=AWS_CLIENT_ES) - def opensearch(self): - return self.get_client(service_name=AWS_CLIENT_OPENSEARCH) - def sts(self): return self.get_client(service_name=AWS_CLIENT_STS) diff --git a/source/idea/idea-sdk/src/ideasdk/aws/aws_resources.py b/source/idea/idea-sdk/src/ideasdk/aws/aws_resources.py index be44f89..8bec77d 100644 --- a/source/idea/idea-sdk/src/ideasdk/aws/aws_resources.py +++ b/source/idea/idea-sdk/src/ideasdk/aws/aws_resources.py @@ -14,7 +14,6 @@ from ideadatamodel.cluster_resources import ( SocaVPC, SocaCloudFormationStack, - SocaOpenSearchDomain, SocaDirectory, SocaSubnet, SocaFileSystem, @@ -61,12 +60,6 @@ def get_subnets(self) -> Optional[List[SocaSubnet]]: def set_subnets(self, subnets: List[SocaSubnet]): return self._db.set('aws.ec2.subnets', subnets) - def 
get_opensearch_domains(self) -> Optional[List[SocaOpenSearchDomain]]: - return self._db.get('aws.opensearch.domains') - - def set_opensearch_domains(self, domains: List[SocaOpenSearchDomain]): - return self._db.set('aws.opensearch.domains', domains) - def get_directories(self) -> Optional[List[SocaDirectory]]: return self._db.get('aws.directories') @@ -217,39 +210,6 @@ def result_cb(result) -> List[SocaVPC]: except Exception as e: self.aws_util.handle_aws_exception(e) - def get_opensearch_clusters(self, vpc_id: str, refresh: bool = False) -> List[SocaOpenSearchDomain]: - try: - if not refresh: - domains = self._db.get_opensearch_domains() - if domains is not None: - return domains - - domains = [] - list_domain_names_result = self.aws.es().list_domain_names() - domain_names = Utils.get_value_as_list('DomainNames', list_domain_names_result, []) - for entry in domain_names: - domain_name = Utils.get_value_as_string('DomainName', entry) - describe_domain_result = self.aws.es().describe_elasticsearch_domain(DomainName=domain_name) - domain_status = describe_domain_result['DomainStatus'] - vpc_options = Utils.get_value_as_dict('VPCOptions', domain_status) - if vpc_options is None: - continue - domain_vpc_id = vpc_options['VPCId'] - if domain_vpc_id != vpc_id: - continue - elasticsearch_version = domain_status['ElasticsearchVersion'] - title = f'{domain_name} ({elasticsearch_version})' - domains.append(SocaOpenSearchDomain( - type='aws.opensearch.domain', - title=title, - ref=describe_domain_result - )) - - self._db.set_opensearch_domains(domains) - return domains - except Exception as e: - self.aws_util.handle_aws_exception(e) - def get_directories(self, vpc_id: str, refresh: bool = False) -> List[SocaDirectory]: try: if not refresh: diff --git a/source/idea/idea-sdk/src/ideasdk/aws/aws_util.py b/source/idea/idea-sdk/src/ideasdk/aws/aws_util.py index be9829e..3457b89 100644 --- a/source/idea/idea-sdk/src/ideasdk/aws/aws_util.py +++ 
b/source/idea/idea-sdk/src/ideasdk/aws/aws_util.py @@ -1001,6 +1001,84 @@ def dynamodb_create_table(self, create_table_request: Dict, wait: bool = False, ) return True + + def dynamodb_import_table(self, import_table_request: Dict, wait: bool = False): + """ + Import table data from S3 into a DynamoDB table + + :param import_table_request: request parameters for the DynamoDB import_table API + """ + + table_creation_parameters = import_table_request.get('TableCreationParameters') + if not table_creation_parameters: + raise exceptions.invalid_params("import_table_request must include 'TableCreationParameters'") + + table_name = table_creation_parameters.get("TableName") + if not table_name: + raise exceptions.invalid_params("import_table_request must include 'TableCreationParameters.TableName'") + + s3_bucket_source = import_table_request.get('S3BucketSource') + if not s3_bucket_source: + raise exceptions.invalid_params("import_table_request must include 'S3BucketSource'") + + s3_bucket_name = s3_bucket_source.get('S3Bucket') + s3_key_prefix = s3_bucket_source.get('S3KeyPrefix') + + if not s3_bucket_name or not s3_key_prefix: + raise exceptions.invalid_params("import_table_request must include 'S3BucketSource.S3Bucket' and 'S3BucketSource.S3KeyPrefix'. One or both were not provided") + + dynamodb_kms_key_id = self._context.config().get_string('cluster.dynamodb.kms_key_id') + if dynamodb_kms_key_id is not None: + import_table_request['TableCreationParameters']['SSESpecification'] = { + 'Enabled': True, + 'SSEType': 'KMS', + 'KMSMasterKeyId': dynamodb_kms_key_id + } + + self._logger.info(f'importing data from S3 bucket {s3_bucket_name} (prefix: {s3_key_prefix}) into dynamodb table: {table_name} ...') + + # The import_table call takes about 5 seconds to respond; a successful response confirms that the table creation process has started.
+ res = self.aws().dynamodb().import_table(**import_table_request) + + return res + + def dynamodb_check_import_completed_successfully(self, import_arn: str) -> bool: + while True: + describe_import_result = self.aws().dynamodb().describe_import(ImportArn=import_arn) + + result = describe_import_result['ImportTableDescription'] + if result['ImportStatus'] == 'COMPLETED': + return True + + elif result['ImportStatus'] == 'CANCELLED': + self._logger.error(f'Import attempt {import_arn} was CANCELLED') + return False + + elif result['ImportStatus'] == 'FAILED': + self._logger.error(f"Import attempt {import_arn} FAILED with code: {result['FailureCode']} and message: {result['FailureMessage']}") + return False + + time.sleep(10) + + + def dynamodb_delete_table(self, table_name: str, wait: bool = True) -> bool: + """ + Deletes a DynamoDB table if it exists. Typically used to delete temporary tables created during the ApplySnapshot process + + :param table_name: Name of the DDB table that should be deleted + :param wait: wait for any in-progress table creation to reach ACTIVE status before deleting + """ + + if wait: + try: + self.dynamodb_check_table_exists(table_name, wait) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] == 'ResourceNotFoundException': + return True + + self.aws().dynamodb().delete_table(TableName=table_name) + + return True def create_s3_presigned_url(self, key: str, expires_in=3600) -> str: return self.aws().s3().generate_presigned_url( diff --git a/source/idea/idea-sdk/src/ideasdk/aws/instance_metadata_util.py b/source/idea/idea-sdk/src/ideasdk/aws/instance_metadata_util.py index a02fe86..61fa8ed 100644 --- a/source/idea/idea-sdk/src/ideasdk/aws/instance_metadata_util.py +++ b/source/idea/idea-sdk/src/ideasdk/aws/instance_metadata_util.py @@ -9,7 +9,7 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License.
-from ideasdk.protocols import SocaContextProtocol, InstanceMetadataUtilProtocol +from ideasdk.protocols import InstanceMetadataUtilProtocol from ideadatamodel import exceptions, errorcodes, EC2InstanceIdentityDocument from ideasdk.utils import Utils diff --git a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/aws_opensearch_client.py b/source/idea/idea-sdk/src/ideasdk/aws/opensearch/aws_opensearch_client.py deleted file mode 100644 index e4adf93..0000000 --- a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/aws_opensearch_client.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. 
-import opensearchpy - -from ideasdk.aws.opensearch.opensearch_filters import ( - TermFilter, - RangeFilter, - FreeTextFilter, - SortFilter -) -from ideasdk.protocols import SocaContextProtocol - -from typing import Optional, Dict, List -from opensearchpy import OpenSearch, helpers - -from ideasdk.utils import Utils - - -class AwsOpenSearchClient: - def __init__(self, context: SocaContextProtocol): - self.context = context - self._logger = context.logger('opensearch-client') - - domain_endpoint = self.context.config().get_string('analytics.opensearch.domain_endpoint', required=True) - if not domain_endpoint.startswith('https://'): - domain_endpoint = f'https://{domain_endpoint}' - - self.os_client = OpenSearch( - hosts=[domain_endpoint], - port=443, - use_ssl=True, - verify_certs=True - ) - - def add_index_entry(self, index_name: str, doc_id: str, body: str, **kwargs): - timeout = kwargs.get('timeout', '10s') - response = self.os_client.index( - index=index_name, - id=doc_id, - body=body, - timeout=timeout - ) - return response['result'] == 'created' - - def bulk_index(self, index_name: str, docs: Dict[str, Dict], **kwargs): - - items = [] - - for doc_id, doc in docs.items(): - items.append({ - '_index': index_name, - '_id': doc_id, - '_source': doc - }) - - helpers.bulk( - client=self.os_client, - actions=items - ) - - return True - - def search(self, index: str, term_filters: Optional[List[TermFilter]] = None, range_filter: Optional[RangeFilter] = None, free_text_filter: Optional[FreeTextFilter] = None, sort_filter: Optional[SortFilter] = None, size: Optional[int] = None, start_from: Optional[int] = None, source: bool = True) -> dict: - query = {"bool": {}} - if Utils.is_not_empty(term_filters): - query["bool"]["must"] = Utils.get_value_as_list("must", query["bool"], default=[]) - for term_filter in term_filters: - query["bool"]["must"].append(term_filter.get_term_filter()) - - if Utils.is_not_empty(free_text_filter): - query["bool"]["must"] = 
Utils.get_value_as_list("must", query["bool"], default=[]) - query["bool"]["must"].append(free_text_filter.get_free_text_filter()) - - if Utils.is_not_empty(range_filter): - query["bool"]["filter"] = Utils.get_value_as_list("filter", query["bool"], default=[]) - query["bool"]["filter"].append(range_filter.get_range_filter()) - - sort_by = None - if Utils.is_not_empty(sort_filter): - sort_by = sort_filter.get_sort_filter() - - return self._search( - index=index, - body={'query': query}, - sort_by=sort_by, - size=size, - start_from=start_from, - source=source - ) - - def _search(self, index: str, body: dict, sort_by: Optional[str] = None, size: Optional[int] = None, start_from: Optional[int] = None, source: bool = True) -> dict: - try: - result = self.os_client.search( - index=index, - body=body, - sort=sort_by, - size=size, - from_=start_from, - _source=source - ) - except opensearchpy.exceptions.NotFoundError: - # if index does not exist, return False - # this allows new index to be created automatically - return {} - return result - - def exists(self, index: str, body: dict): - result = self._search(index, body, source=False) - return Utils.get_value_as_int('value', Utils.get_value_as_dict('total', Utils.get_value_as_dict('hits', result, {}), {}), 0) > 0 - - def get_template(self, name: str) -> Optional[Dict]: - try: - return self.os_client.indices.get_template(name=name) - except opensearchpy.exceptions.NotFoundError: - return None - - def delete_alias_and_index(self, name: str): - try: - response = self.os_client.indices.get_alias(name=name) - except opensearchpy.exceptions.NotFoundError as _: - response = None - - if Utils.is_empty(response): - return - - for index_name in response: - try: - self.os_client.indices.delete_alias( - index=index_name, - name=name - ) - except opensearchpy.exceptions.NotFoundError as _: - pass - self._logger.warning(f'Alias: {name} deleted ...') - self.delete_index(index_name) - - def delete_index(self, name: str): - try: - 
self.os_client.indices.delete( - index=name - ) - except opensearchpy.exceptions.NotFoundError as _: - pass - self._logger.warning(f'Index: {name} deleted ...') - - def delete_template(self, name: str): - try: - self.os_client.indices.delete_template(name=name) - except opensearchpy.exceptions.NotFoundError as _: - pass - self._logger.warning(f'Template: {name} deleted ...') - - def put_template(self, name: str, body: Dict) -> Dict: - return self.os_client.indices.put_template(name=name, body=body) diff --git a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearch_filters.py b/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearch_filters.py deleted file mode 100644 index 688cc57..0000000 --- a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearch_filters.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. 
- -__all__ = ( - 'TermFilter', - 'RangeFilter', - 'FreeTextFilter', - 'SortFilter', - 'SortOrder' -) - -from enum import Enum -from typing import List, Union, Dict - - -class TermFilter: - key: str - value: List[str] - - def __init__(self, key: str, value: Union[str, List[str]]): - self.key = key - if isinstance(value, str): - self.value = [value] - else: - self.value = value - - def get_term_filter(self) -> Dict: - return { - 'terms': { - self.key: self.value - } - } - - -class RangeFilter: - key: str - start: str - end: str - - def __init__(self, key: str, start: str, end: str): - self.key = key - self.start = start - self.end = end - - def get_range_filter(self) -> Dict: - return { - 'range': { - self.key: { - 'gte': self.start, - 'lt': self.end - } - } - } - - -class FreeTextFilter: - text: str - - def __init__(self, text: str): - self.text = text - - def get_free_text_filter(self) -> Dict: - return { - 'query_string': { - "fields": [], - "query": self.text - } - } - - -class SortOrder(str, Enum): - ASC = 'asc' - DESC = 'desc' - - -class SortFilter: - key: str - order: SortOrder - - def __init__(self, key: str, order: SortOrder): - self.key = key - self.order = order - - def get_sort_filter(self) -> str: - return f'{self.key}:{self.order}' diff --git a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearchable_db.py b/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearchable_db.py deleted file mode 100644 index 06368e0..0000000 --- a/source/idea/idea-sdk/src/ideasdk/aws/opensearch/opensearchable_db.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. 
This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. -from abc import abstractmethod -from logging import Logger -from typing import Dict, Optional -import json - -from ideadatamodel import SocaListingPayload, SocaSortOrder, SocaSortBy -from ideasdk.aws.opensearch.aws_opensearch_client import AwsOpenSearchClient -from ideasdk.aws.opensearch.opensearch_filters import FreeTextFilter, TermFilter, RangeFilter, SortOrder, SortFilter -from ideasdk.context import SocaContext -from ideasdk.utils import Utils - - -class OpenSearchableDB: - FREE_TEXT_FILTER_KEY = '$all' - - def __init__(self, context: SocaContext, logger: Logger, term_filter_map: Optional[Dict] = None, date_range_filter_map: Optional[Dict] = None, default_page_size: int = 10, free_text_search_support: bool = True): - self.context = context - self._logger = logger - self._term_filter_map = term_filter_map - self._date_range_filter_map = date_range_filter_map - self._default_page_size = default_page_size - self._os_client = AwsOpenSearchClient(self.context) - self._free_text_search_support = free_text_search_support - - def list_from_opensearch(self, options: SocaListingPayload): - # documentation to read - https://opensearch.org/docs/latest/opensearch/query-dsl/ - - term_filters = [] - free_text_filter_value = None - if Utils.is_not_empty(options.filters): - for listing_filter in options.filters: - - if Utils.is_empty(listing_filter.key): - continue - - if Utils.is_empty(listing_filter.value): - continue - - if listing_filter.key == self.FREE_TEXT_FILTER_KEY: - free_text_filter_value = listing_filter.value.lower() - continue - - if listing_filter.key not in self._term_filter_map.keys(): - continue - - term_filters.append(TermFilter( - key=self._term_filter_map[listing_filter.key], - value=listing_filter.value - )) - - range_filter = None - if 
Utils.is_not_empty(options.date_range) and Utils.is_not_empty(self._date_range_filter_map): - date_range = options.date_range - if date_range.key in self._date_range_filter_map.keys(): - range_filter = RangeFilter( - key=self._date_range_filter_map[date_range.key], - start=f'{Utils.to_milliseconds(date_range.start)}', - end=f'{Utils.to_milliseconds(date_range.end)}' - ) - - if Utils.is_empty(options.sort_by): - options.sort_by = self.get_default_sort() - - sort_filter = SortFilter( - key=options.sort_by.key, - order=SortOrder.DESC if options.sort_by.order == SocaSortOrder.DESC else SortOrder.ASC - ) - - response = self._os_client.search( - index=self.get_index_name(), - term_filters=term_filters, - range_filter=range_filter, - sort_filter=sort_filter, - start_from=0, - size=None, - ) - - response_hits = (response.get('hits') or {}).get('hits') - - if response_hits and free_text_filter_value: - free_text_filtered_responses = [] - for hit in response_hits: - json_dump = json.dumps(hit.get('_source')).lower() - - if free_text_filter_value in json_dump: - free_text_filtered_responses.append(hit) - - response['hits']['hits'] = free_text_filtered_responses - - return response - - @abstractmethod - def get_index_name(self) -> str: - ... - - @abstractmethod - def get_default_sort(self) -> SocaSortBy: - ... 
diff --git a/source/idea/idea-sdk/src/ideasdk/client/accounts_client.py b/source/idea/idea-sdk/src/ideasdk/client/accounts_client.py index c93aa8b..0c9e9a9 100644 --- a/source/idea/idea-sdk/src/ideasdk/client/accounts_client.py +++ b/source/idea/idea-sdk/src/ideasdk/client/accounts_client.py @@ -20,6 +20,8 @@ ListUsersInGroupResult, GetUserRequest, GetUserResult, + GetUserByEmailRequest, + GetUserByEmailResult, GetGroupRequest, GetGroupResult ) @@ -66,6 +68,14 @@ def get_user(self, request: GetUserRequest) -> GetUserResult: access_token=self.get_access_token() ) + def get_user_by_email(self, request: GetUserByEmailRequest) -> GetUserByEmailResult: + return self.client.invoke_alt( + namespace='Accounts.GetUserByEmail', + payload=request, + result_as=GetUserByEmailResult, + access_token=self.get_access_token() + ) + def get_group(self, request: GetGroupRequest) -> GetGroupResult: return self.client.invoke_alt( namespace='Accounts.GetUserGroup', diff --git a/source/idea/idea-sdk/src/ideasdk/client/soca_client.py b/source/idea/idea-sdk/src/ideasdk/client/soca_client.py index 4209f54..4cbbc6b 100644 --- a/source/idea/idea-sdk/src/ideasdk/client/soca_client.py +++ b/source/idea/idea-sdk/src/ideasdk/client/soca_client.py @@ -127,7 +127,7 @@ def is_enable_logging(self) -> bool: def timeout(self) -> float: return Utils.get_as_float(self.options.timeout, DEFAULT_TIMEOUT_SECONDS) - def invoke(self, request: SocaEnvelope, result_as: Optional[Type[T]] = SocaAnyPayload, access_token: str = None) -> T: + def invoke(self, request: SocaEnvelope, result_as: Optional[Type[T]] = SocaAnyPayload, access_token: Optional[str] = None) -> T: try: header = request.header request_id = header.request_id @@ -192,7 +192,7 @@ def invoke(self, request: SocaEnvelope, result_as: Optional[Type[T]] = SocaAnyPa def invoke_alt(self, namespace: str, payload: Optional[Any], result_as: Optional[Type[T]] = SocaAnyPayload, - access_token: str = None) -> T: + access_token: Optional[str] = None) -> T: request 
= SocaEnvelope( header=SocaHeader( namespace=namespace, diff --git a/source/idea/idea-sdk/src/ideasdk/client/vdc_client.py b/source/idea/idea-sdk/src/ideasdk/client/vdc_client.py new file mode 100644 index 0000000..2b1ae2f --- /dev/null +++ b/source/idea/idea-sdk/src/ideasdk/client/vdc_client.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from ideasdk.protocols import SocaContextProtocol +from ideasdk.client.soca_client import SocaClient, SocaClientOptions +from ideasdk.auth import TokenService +from ideasdk.utils import Utils +from ideadatamodel import exceptions +from ideadatamodel import ( + ListSessionsRequest, + ListSessionsResponse, + ListSoftwareStackRequest, + ListSoftwareStackResponse, + GetBasePermissionsRequest, + GetBasePermissionsResponse, + CreatePermissionProfileRequest, + CreatePermissionProfileResponse, + DeletePermissionProfileRequest, + DeletePermissionProfileResponse, + GetPermissionProfileRequest, + GetPermissionProfileResponse, + CreateSoftwareStackRequest, + CreateSoftwareStackResponse, + DeleteSoftwareStackRequest, + DeleteSoftwareStackResponse, + GetSoftwareStackInfoResponse, + VirtualDesktopSoftwareStack, + VirtualDesktopSession, + VirtualDesktopPermission, + VirtualDesktopPermissionProfile, + SocaFilter, +) + +from abc import abstractmethod +from typing import Optional + + +class AbstractVirtualDesktopControllerClient: + @abstractmethod + def list_sessions_by_project_id(self, project_id: str) -> 
list[VirtualDesktopSession]: + ... + + @abstractmethod + def list_software_stacks_by_project_id(self, project_id: str) -> list[VirtualDesktopSoftwareStack]: + ... + + @abstractmethod + def get_base_permissions(self) -> list[VirtualDesktopPermission]: + ... + + @abstractmethod + def create_permission_profile(self, profile: VirtualDesktopPermissionProfile) -> VirtualDesktopPermissionProfile: + ... + + @abstractmethod + def delete_permission_profile(self, profile_id: str) -> None: + ... + + @abstractmethod + def get_permission_profile(self, profile_id: str) -> VirtualDesktopPermissionProfile: + ... + + @abstractmethod + def get_software_stacks_by_name(self, stack_name: str) -> list[VirtualDesktopSoftwareStack]: + ... + + @abstractmethod + def create_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> VirtualDesktopSoftwareStack: + ... + + @abstractmethod + def delete_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> None: + ... + + +class VirtualDesktopControllerClient(AbstractVirtualDesktopControllerClient): + + def __init__(self, context: SocaContextProtocol, + options: SocaClientOptions, token_service: Optional[TokenService]): + """ + :param context: Application Context + :param options: Client Options + """ + self.context = context + self.logger = context.logger('vdc-client') + self.client = SocaClient(context=context, options=options) + + if Utils.is_empty(options.unix_socket) and token_service is None: + raise exceptions.invalid_params('token_service is required for http client') + self.token_service = token_service + + def get_access_token(self) -> Optional[str]: + if self.token_service is None: + return None + return self.token_service.get_access_token() + + def list_sessions_by_project_id(self, project_id: str) -> list[VirtualDesktopSession]: + result = self.client.invoke_alt( + namespace='VirtualDesktop.ListSessions', + payload=ListSessionsRequest(), + result_as=ListSessionsResponse, + access_token=self.get_access_token(), 
+ ) + + return [session for session in result.listing if session.project.project_id == project_id] if result.listing else [] + + def list_software_stacks_by_project_id(self, project_id: str) -> list[VirtualDesktopSoftwareStack]: + result = self.client.invoke_alt( + namespace='VirtualDesktop.ListSoftwareStacks', + payload=ListSoftwareStackRequest(project_id=project_id), + result_as=ListSoftwareStackResponse, + access_token=self.get_access_token(), + ) + + return result.listing if result.listing else [] + + def get_base_permissions(self) -> list[VirtualDesktopPermission]: + result = self.client.invoke_alt( + namespace='VirtualDesktopUtils.GetBasePermissions', + payload=GetBasePermissionsRequest(), + result_as=GetBasePermissionsResponse, + access_token=self.get_access_token(), + ) + + return result.permissions if result.permissions else [] + + def create_permission_profile(self, profile: VirtualDesktopPermissionProfile) -> VirtualDesktopPermissionProfile: + result = self.client.invoke_alt( + namespace='VirtualDesktopAdmin.CreatePermissionProfile', + payload=CreatePermissionProfileRequest(profile=profile), + result_as=CreatePermissionProfileResponse, + access_token=self.get_access_token(), + ) + + return result.profile + + def delete_permission_profile(self, profile_id: str) -> None: + self.client.invoke_alt( + namespace='VirtualDesktopAdmin.DeletePermissionProfile', + payload=DeletePermissionProfileRequest(profile_id=profile_id), + result_as=DeletePermissionProfileResponse, + access_token=self.get_access_token(), + ) + + def get_permission_profile(self, profile_id: str) -> VirtualDesktopPermissionProfile: + result = self.client.invoke_alt( + namespace='VirtualDesktopUtils.GetPermissionProfile', + payload=GetPermissionProfileRequest(profile_id=profile_id), + result_as=GetPermissionProfileResponse, + access_token=self.get_access_token(), + ) + + return result.profile + + def get_software_stacks_by_name(self, stack_name: str) -> list[VirtualDesktopSoftwareStack]: + 
result = self.client.invoke_alt( + namespace='VirtualDesktopAdmin.ListSoftwareStacks', + payload=ListSoftwareStackRequest( + filters=[ + SocaFilter( + key="name", + value=stack_name, + ) + ] + ), + result_as=ListSoftwareStackResponse, + access_token=self.get_access_token(), + ) + + return result.listing if result.listing else [] + + def create_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> VirtualDesktopSoftwareStack: + result = self.client.invoke_alt( + namespace='VirtualDesktopAdmin.CreateSoftwareStack', + payload=CreateSoftwareStackRequest(software_stack=software_stack), + result_as=CreateSoftwareStackResponse, + access_token=self.get_access_token(), + ) + + return result.software_stack + + def delete_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> None: + self.client.invoke_alt( + namespace='VirtualDesktopAdmin.DeleteSoftwareStack', + payload=DeleteSoftwareStackRequest( + software_stack=software_stack + ), + result_as=DeleteSoftwareStackResponse, + access_token=self.get_access_token(), + ) + + def destroy(self): + self.client.close() diff --git a/source/idea/idea-sdk/src/ideasdk/context/arn_builder.py b/source/idea/idea-sdk/src/ideasdk/context/arn_builder.py index 73d126d..268a61d 100644 --- a/source/idea/idea-sdk/src/ideasdk/context/arn_builder.py +++ b/source/idea/idea-sdk/src/ideasdk/context/arn_builder.py @@ -241,6 +241,11 @@ def get_sqs_arn(self, queue_name_suffix: str) -> str: def get_route53_hostedzone_arn(self) -> str: return f'arn:{self.config.get_string("cluster.aws.partition", required=True)}:route53:::hostedzone/*' + + def get_iam_arn(self, role_name_suffix: str) -> str: + return self.get_arn(service='iam', + resource=f'role/{self.config.get_string("cluster.cluster_name")}-{role_name_suffix}-{self.config.get_string("cluster.aws.region")}', + aws_region='') @property def kms_secretsmanager_key_arn(self) -> str: @@ -277,18 +282,6 @@ def kms_backup_key_arn(self) -> str: 
resource=f'key/{self.config.get_string("cluster.backups.backup_vault.kms_key_id")}', aws_region=self.config.get_string("cluster.aws.region"))) - @property - def kms_opensearch_key_arn(self) -> str: - return(self.get_arn(service='kms', - resource=f'key/{self.config.get_string("analytics.opensearch.kms_key_id")}', - aws_region=self.config.get_string("cluster.aws.region"))) - - @property - def kms_kinesis_key_arn(self) -> str: - return(self.get_arn(service='kms', - resource=f'key/{self.config.get_string("analytics.kinesis.kms_key_id")}', - aws_region=self.config.get_string("cluster.aws.region"))) - @property def kms_key_arn(self) -> List[str]: kms_key_arns = [] @@ -299,8 +292,6 @@ def kms_key_arn(self) -> List[str]: 'dynamodb': 'cluster.dynamodb.kms_key_id', 'ebs': 'cluster.ebs.kms_key_id', 'backup': 'cluster.backups.backup_vault.kms_key_id', - 'opensearch': 'analytics.opensearch.kms_key_id', - 'kinesis': 'analytics.kinesis.kms_key_id' } service_kms_key_arns = { 'secretsmanager': self.kms_secretsmanager_key_arn, @@ -309,8 +300,6 @@ def kms_key_arn(self) -> List[str]: 'dynamodb': self.kms_dynamodb_key_arn, 'ebs': self.kms_ebs_key_arn, 'backup': self.kms_backup_key_arn, - 'opensearch': self.kms_opensearch_key_arn, - 'kinesis': self.kms_kinesis_key_arn } for service in service_kms_key_arns.keys(): if self.config.get_string(service_kms_key_ids[service]) is not None: diff --git a/source/idea/idea-sdk/src/ideasdk/context/bootstrap_context.py b/source/idea/idea-sdk/src/ideasdk/context/bootstrap_context.py index 238030e..d5af670 100644 --- a/source/idea/idea-sdk/src/ideasdk/context/bootstrap_context.py +++ b/source/idea/idea-sdk/src/ideasdk/context/bootstrap_context.py @@ -124,9 +124,6 @@ def eval_project() -> bool: if 'project' not in context_vars: return False projects = Utils.get_value_as_list('projects', shared_storage, []) - # empty list = allow all - if Utils.is_empty(projects): - return True return self.vars.project in projects def eval_module() -> bool: diff --git 
a/source/idea/idea-sdk/src/ideasdk/context/soca_context.py b/source/idea/idea-sdk/src/ideasdk/context/soca_context.py index 3036522..97ca611 100644 --- a/source/idea/idea-sdk/src/ideasdk/context/soca_context.py +++ b/source/idea/idea-sdk/src/ideasdk/context/soca_context.py @@ -14,7 +14,6 @@ ) import ideasdk -from ideasdk.analytics.analytics_service import AnalyticsService from ideasdk.protocols import ( SocaContextProtocol, CacheProviderProtocol, SocaPubSubProtocol, SocaServiceRegistryProtocol, SocaConfigType @@ -48,7 +47,6 @@ class SocaContextOptions(SocaBaseModel): module_id: Optional[str] module_set: Optional[str] - enable_analytics: Optional[bool] enable_metrics: Optional[bool] metrics_namespace: Optional[str] @@ -84,7 +82,6 @@ def default() -> 'SocaContextOptions': use_vpc_endpoints=False, enable_distributed_lock=False, enable_leader_election=False, - enable_analytics=False, config=None ) @@ -109,7 +106,6 @@ def __init__(self, options: SocaContextOptions = None): self._distributed_lock: Optional[DistributedLock] = None self._metrics_service: Optional[MetricsService] = None self._leader_election: Optional[LeaderElection] = None - self._analytics_service: Optional[AnalyticsService] = None is_app_server = Utils.get_as_bool(options.is_app_server, False) if is_app_server: @@ -216,10 +212,6 @@ def __init__(self, options: SocaContextOptions = None): if options.enable_leader_election: self._leader_election = LeaderElection(context=self) - # analytics - if options.enable_analytics: - self._analytics_service = AnalyticsService(context=self) - # metrics if options.enable_metrics: self._metrics_service = MetricsService(context=self, default_namespace=options.metrics_namespace) @@ -229,8 +221,6 @@ def __init__(self, options: SocaContextOptions = None): self._distributed_lock.stop() if self._leader_election is not None: self._leader_election.stop() - if self._analytics_service is not None: - self._analytics_service.stop() if self._metrics_service is not None: 
self._metrics_service.stop() if self._config is not None and isinstance(self._config, ClusterConfig): @@ -325,6 +315,3 @@ def is_leader(self) -> bool: if self._leader_election is None: return True return self._leader_election.is_leader() - - def analytics_service(self) -> AnalyticsService: - return self._analytics_service diff --git a/source/idea/idea-sdk/src/ideasdk/protocols/__init__.py b/source/idea/idea-sdk/src/ideasdk/protocols/__init__.py index 0e5a449..0bc7b79 100644 --- a/source/idea/idea-sdk/src/ideasdk/protocols/__init__.py +++ b/source/idea/idea-sdk/src/ideasdk/protocols/__init__.py @@ -27,6 +27,7 @@ SocaJob, AuthResult ) +from ideadatamodel.api.api_model import ApiAuthorization from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple, List, Callable, Any, Union, Hashable, Set, TypeVar @@ -584,10 +585,6 @@ def distributed_lock(self) -> DistributedLockProtocol: def is_leader(self) -> bool: ... - @abstractmethod - def analytics_service(self) -> AnalyticsServiceProtocol: - ... - class SocaServiceProtocol(SocaBaseProtocol): @@ -666,6 +663,21 @@ def context(self) -> SocaContextProtocolType: ... +class ApiAuthorizationServiceProtocol(SocaBaseProtocol): + + @abstractmethod + def get_authorization(self, decoded_token: Optional[Dict]) -> Optional[ApiAuthorization]: + ... + + @abstractmethod + def is_scope_authorized(self, decoded_token: str, scope: str) -> bool: + ... + + @abstractmethod + def get_username(self, decoded_token: str) -> Optional[str]: + ... + + class TokenServiceProtocol(SocaBaseProtocol): @abstractmethod @@ -680,10 +692,6 @@ def decode_token(self, token: str, verify_exp: Optional[bool] = True) -> Dict: def is_token_expired(self, token: str) -> bool: ... - @abstractmethod - def get_username(self, access_token: str, verify_exp=True) -> Optional[str]: - ... 
- class ApiInvokerProtocol(SocaBaseProtocol): @@ -691,6 +699,10 @@ def get_token_service(self) -> Optional[TokenServiceProtocol]: ... + @abstractmethod + def get_api_authorization_service(self) -> Optional[ApiAuthorizationServiceProtocol]: + ... + @abstractmethod def invoke(self, context: ApiInvocationContextProtocol): ... @@ -721,10 +733,3 @@ def acquire(self, key: str): @abstractmethod def release(self, key: str): pass - - -class AnalyticsServiceProtocol(SocaBaseProtocol): - - @abstractmethod - def post_entry(self, document): - ... diff --git a/source/idea/idea-sdk/src/ideasdk/server/soca_server.py b/source/idea/idea-sdk/src/ideasdk/server/soca_server.py index e151938..76718de 100644 --- a/source/idea/idea-sdk/src/ideasdk/server/soca_server.py +++ b/source/idea/idea-sdk/src/ideasdk/server/soca_server.py @@ -432,7 +432,8 @@ def is_authenticated_request(self, http_request) -> bool: def get_username(self, http_request) -> Optional[str]: try: token_service = self.api_invoker.get_token_service() - if token_service is None: + api_authorization_service = self.api_invoker.get_api_authorization_service() + if not token_service or not api_authorization_service: return None token = self.get_token(http_request) token_type = Utils.get_value_as_string('token_type', token) @@ -441,7 +442,8 @@ def get_username(self, http_request) -> Optional[str]: if token_type != 'Bearer': return None access_token = Utils.get_value_as_string('token', token) - return token_service.get_username(access_token) + decoded_token = token_service.decode_token(access_token) + return api_authorization_service.get_username(decoded_token) except Exception: # noqa return None @@ -462,11 +464,13 @@ def _invoke(self, http_request) -> Dict: invocation_context = ApiInvocationContext( context=self.context, request=request, + http_headers=http_request.headers, invocation_source=invocation_source, group_name_helper=self.group_name_helper, logger=self.logger,
token=self.get_token(http_request), token_service=self.api_invoker.get_token_service(), + api_authorization_service=self.api_invoker.get_api_authorization_service(), ) # validate request prior to logging diff --git a/source/idea/idea-sdk/src/ideasdk/utils/__init__.py b/source/idea/idea-sdk/src/ideasdk/utils/__init__.py index 0f47d1a..ea5c751 100644 --- a/source/idea/idea-sdk/src/ideasdk/utils/__init__.py +++ b/source/idea/idea-sdk/src/ideasdk/utils/__init__.py @@ -13,7 +13,9 @@ from ideasdk.utils.environment_utils import EnvironmentUtils from ideasdk.utils.utils import Utils +from ideasdk.utils.api_utils import ApiUtils from ideasdk.utils.datetime_utils import DateTimeUtils from ideasdk.utils.group_name_helper import GroupNameHelper from ideasdk.utils.jinja2_utils import Jinja2Utils from ideasdk.utils.module_metadata import * +from ideasdk.utils.fetch_records_from_db_util import * diff --git a/source/idea/idea-sdk/src/ideasdk/utils/api_utils.py b/source/idea/idea-sdk/src/ideasdk/utils/api_utils.py new file mode 100644 index 0000000..3f9b811 --- /dev/null +++ b/source/idea/idea-sdk/src/ideasdk/utils/api_utils.py @@ -0,0 +1,26 @@ +import re + +from ideadatamodel import ( + exceptions, + errorcodes, +) + +from ideadatamodel.constants import INVALID_RANGE_ERROR_MESSAGE + +class ApiUtils: + + @staticmethod + def validate_input(input_string: str, validation_regex: str, error_message=""): + if not re.match(validation_regex, input_string): + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message=error_message + ) + + @staticmethod + def validate_input_range(value: int, value_range: tuple, error_message=INVALID_RANGE_ERROR_MESSAGE): + if value < value_range[0] or value > value_range[1]: + raise exceptions.soca_exception( + error_code=errorcodes.INVALID_PARAMS, + message=error_message + ) \ No newline at end of file diff --git a/source/idea/idea-sdk/src/ideasdk/utils/environment_utils.py b/source/idea/idea-sdk/src/ideasdk/utils/environment_utils.py index
a66f9be..0ad0358 100644 --- a/source/idea/idea-sdk/src/ideasdk/utils/environment_utils.py +++ b/source/idea/idea-sdk/src/ideasdk/utils/environment_utils.py @@ -86,3 +86,11 @@ def aws_default_region(required=False, default=None): required=required, default=default ) + + @staticmethod + def res_test_mode(required=False, default=None): + return EnvironmentUtils.get_environment_variable( + key='RES_TEST_MODE', + required=required, + default=default + ) diff --git a/source/idea/idea-sdk/src/ideasdk/utils/fetch_records_from_db_util.py b/source/idea/idea-sdk/src/ideasdk/utils/fetch_records_from_db_util.py new file mode 100644 index 0000000..ebd9aec --- /dev/null +++ b/source/idea/idea-sdk/src/ideasdk/utils/fetch_records_from_db_util.py @@ -0,0 +1,43 @@ +from ideasdk.utils import Utils +from typing import Dict, TypeVar + +TRequest = TypeVar('TRequest') + +def scan_db_records(request: TRequest, table) -> Dict: + scan_request = {} + + cursor = request.cursor + last_evaluated_key = None + if Utils.is_not_empty(cursor): + last_evaluated_key = Utils.from_json(Utils.base64_decode(cursor)) + if last_evaluated_key is not None: + scan_request['ExclusiveStartKey'] = last_evaluated_key + + scan_filter = None + if Utils.is_not_empty(request.filters): + scan_filter = {} + for filter_ in request.filters: + if filter_.value == '$all': + continue + + if filter_.eq is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.eq], + 'ComparisonOperator': 'EQ' + } + if filter_.value is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.value], + 'ComparisonOperator': 'CONTAINS' + } + if filter_.like is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.like], + 'ComparisonOperator': 'CONTAINS' + } + if scan_filter is not None: + scan_request['ScanFilter'] = scan_filter + + scan_result = table.scan(**scan_request) + + return scan_result \ No newline at end of file diff --git
a/source/idea/idea-sdk/src/ideasdk/utils/module_metadata.py b/source/idea/idea-sdk/src/ideasdk/utils/module_metadata.py index d4520be..8dc3190 100644 --- a/source/idea/idea-sdk/src/ideasdk/utils/module_metadata.py +++ b/source/idea/idea-sdk/src/ideasdk/utils/module_metadata.py @@ -31,7 +31,6 @@ class ModuleMetadata(SocaBaseModel): ModuleMetadata(name=constants.MODULE_GLOBAL_SETTINGS, title='Global Settings', type=constants.MODULE_TYPE_CONFIG, deployment_priority=0), ModuleMetadata(name=constants.MODULE_BOOTSTRAP, title='Bootstrap', type=constants.MODULE_TYPE_STACK, deployment_priority=1), ModuleMetadata(name=constants.MODULE_CLUSTER, title='Cluster', type=constants.MODULE_TYPE_STACK, deployment_priority=2), - ModuleMetadata(name=constants.MODULE_ANALYTICS, title='Analytics', type=constants.MODULE_TYPE_STACK, deployment_priority=3), ModuleMetadata(name=constants.MODULE_METRICS, title='Metrics & Monitoring', type=constants.MODULE_TYPE_STACK, deployment_priority=3), ModuleMetadata(name=constants.MODULE_IDENTITY_PROVIDER, title='Identity Provider', type=constants.MODULE_TYPE_STACK, deployment_priority=3), ModuleMetadata(name=constants.MODULE_DIRECTORYSERVICE, title='Directory Service', type=constants.MODULE_TYPE_STACK, deployment_priority=3), diff --git a/source/idea/idea-sdk/src/ideasdk/utils/utils.py b/source/idea/idea-sdk/src/ideasdk/utils/utils.py index a1a8d38..7650d11 100644 --- a/source/idea/idea-sdk/src/ideasdk/utils/utils.py +++ b/source/idea/idea-sdk/src/ideasdk/utils/utils.py @@ -664,6 +664,17 @@ def convert_tags_dict_to_aws_tags(tags: Dict) -> List[Dict]: }) return result + @staticmethod + def convert_tags_list_of_dict_to_tags_dict(tags_list: list[dict]) -> dict: + result = {} + if not Utils.is_empty(tags_list): + for tag_pair in tags_list: + key = tag_pair.get("session_tags_keys") + value = tag_pair.get("session_tags_values") + if key and value is not None: + result[key] = value + return result + @staticmethod def create_boto_session(aws_region: str, aws_profile:
Optional[str] = None): """ @@ -734,3 +745,7 @@ def flatten_dict(dictionary: dict, parent_key='', separator='.'): else: items.append((new_key, value)) return dict(items) + + @staticmethod + def is_test_mode() -> bool: + return Utils.get_as_bool(EnvironmentUtils.res_test_mode(), False) diff --git a/source/idea/idea-sdk/src/ideasdk_meta/__init__.py b/source/idea/idea-sdk/src/ideasdk_meta/__init__.py index 2460420..a1ced91 100644 --- a/source/idea/idea-sdk/src/ideasdk_meta/__init__.py +++ b/source/idea/idea-sdk/src/ideasdk_meta/__init__.py @@ -12,4 +12,4 @@ # pkgconfig for soca-sdk. no dependencies # noqa __name__ = 'idea-sdk' -__version__ = '2023.11' +__version__ = '2024.01' diff --git a/source/idea/idea-data-model/src/ideadatamodel/analytics/__init__.py b/source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/__init__.py similarity index 94% rename from source/idea/idea-data-model/src/ideadatamodel/analytics/__init__.py rename to source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/__init__.py index e0e0263..6d8d18a 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/analytics/__init__.py +++ b/source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/__init__.py @@ -8,5 +8,3 @@ # or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. 
- -from .analytics_api import * diff --git a/source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/mock_api_authorization_service.py b/source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/mock_api_authorization_service.py new file mode 100644 index 0000000..642fb97 --- /dev/null +++ b/source/idea/idea-test-utils/src/ideatestutils/api_authorization_service/mock_api_authorization_service.py @@ -0,0 +1,25 @@ +from ideasdk.auth.api_authorization_service_base import ApiAuthorizationServiceBase +from ideadatamodel.api.api_model import ApiAuthorizationType, ApiAuthorization +from ideadatamodel.auth import User +from typing import Optional, Dict + + +class MockApiAuthorizationService(ApiAuthorizationServiceBase): + + def get_user_from_token_username(self, token_username: Optional[str]): + return User( + username=token_username, + role='user', + enabled=True + ) + + def get_authorization(self, decoded_token: Optional[Dict]) -> Optional[ApiAuthorization]: + return ApiAuthorization( + type=ApiAuthorizationType.USER, + ) + + def is_scope_authorized(self, decoded_token: str, scope: str) -> bool: + return False + + def get_username(self, decoded_token: str) -> Optional[str]: + return None diff --git a/source/idea/idea-test-utils/src/ideatestutils/config/templates/default.yml b/source/idea/idea-test-utils/src/ideatestutils/config/templates/default.yml index b733947..8e09e1b 100644 --- a/source/idea/idea-test-utils/src/ideatestutils/config/templates/default.yml +++ b/source/idea/idea-test-utils/src/ideatestutils/config/templates/default.yml @@ -53,6 +53,7 @@ identity-provider: sso_idp_provider_name: test-provider sso_idp_provider_type: SAML-OIDC sso_idp_provider_email_attribute: email + sso_enabled: false provider: cognito-idp directoryservice: @@ -84,6 +85,7 @@ directoryservice: tls_certificate_secret_arn: arn:aws:secretsmanager:us-east-1:123456789012:secret:{{ context.cluster_name }}-directoryservice-certificate-FN5MSq
tls_private_key_secret_arn: arn:aws:secretsmanager:us-east-1:123456789012:secret:{{ context.cluster_name }}-directoryservice-private-key-FRWbNY volume_size: 200 + sssd.ldap_id_mapping: True scheduler: @@ -161,8 +163,6 @@ global-settings: default: cluster: module_id: cluster - analytics: - module_id: analytics identity-provider: module_id: identity-provider directoryservice: diff --git a/source/idea/idea-test-utils/src/ideatestutils/token_service/__init__.py b/source/idea/idea-test-utils/src/ideatestutils/token_service/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/idea/idea-test-utils/src/ideatestutils/token_service/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
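The admin and DCV API diffs that follow replace each flat `namespace → handler` map with an `acl` dict pairing every handler with a required `<module_id>/read` or `<module_id>/write` scope, and `invoke()` now dispatches only after `context.is_authorized(...)` passes for that scope. A minimal, self-contained sketch of this dispatch pattern (the `FakeContext` and `UnauthorizedAccess` classes and the `'vdc'` module id are illustrative stand-ins, not RES code; only the acl-dict shape and the `invoke()` flow mirror the diff):

```python
class UnauthorizedAccess(Exception):
    pass


class FakeContext:
    """Stand-in for ApiInvocationContext: carries a namespace and the token's scopes."""

    def __init__(self, namespace: str, token_scopes: set):
        self.namespace = namespace
        self._token_scopes = token_scopes
        self.result = None

    def is_authorized(self, elevated_access: bool, scopes: list) -> bool:
        # Authorized when any required scope is present in the access token.
        return any(scope in self._token_scopes for scope in scopes)

    def success(self, payload):
        self.result = payload


class AdminAPI:
    def __init__(self, module_id: str):
        self.SCOPE_READ = f'{module_id}/read'
        self.SCOPE_WRITE = f'{module_id}/write'
        # Each namespace maps to the scope it requires and the handler to run.
        self.acl = {
            'VirtualDesktopAdmin.ListSessions': {'scope': self.SCOPE_READ, 'method': self.list_sessions},
            'VirtualDesktopAdmin.DeleteSessions': {'scope': self.SCOPE_WRITE, 'method': self.delete_sessions},
        }

    def list_sessions(self, context):
        context.success('listed')

    def delete_sessions(self, context):
        context.success('deleted')

    def invoke(self, context):
        acl_entry = self.acl.get(context.namespace)
        if acl_entry is None:
            # Unknown namespaces are rejected instead of silently ignored.
            raise UnauthorizedAccess()
        if context.is_authorized(elevated_access=True, scopes=[acl_entry['scope']]):
            acl_entry['method'](context)
        else:
            raise UnauthorizedAccess()


api = AdminAPI(module_id='vdc')

# A token holding only the read scope can invoke a read namespace...
read_only = FakeContext('VirtualDesktopAdmin.ListSessions', token_scopes={'vdc/read'})
api.invoke(read_only)  # read_only.result == 'listed'

# ...but a write namespace rejects the same read-only token.
try:
    api.invoke(FakeContext('VirtualDesktopAdmin.DeleteSessions', token_scopes={'vdc/read'}))
except UnauthorizedAccess:
    pass
```

Splitting namespaces into read and write scopes lets machine-to-machine callers be granted listing access without also receiving mutation rights, which appears to be the intent of the ACL restructuring below.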
diff --git a/source/idea/idea-test-utils/src/ideatestutils/token_service/mock_token_service.py b/source/idea/idea-test-utils/src/ideatestutils/token_service/mock_token_service.py new file mode 100644 index 0000000..ff27be5 --- /dev/null +++ b/source/idea/idea-test-utils/src/ideatestutils/token_service/mock_token_service.py @@ -0,0 +1,19 @@ +from ideasdk.protocols import TokenServiceProtocol, AuthResult +from ideadatamodel.api.api_model import ApiAuthorizationType +from typing import Optional, Dict + + +class MockTokenService(TokenServiceProtocol): + + def get_access_token_using_client_credentials(self, cached=True) -> AuthResult: + return AuthResult() + + def decode_token(self, token: str, verify_exp: Optional[bool] = True) -> Dict: + return {} + + def is_token_expired(self, token: str) -> bool: + return False + + def get_username(self, access_token: str, verify_exp=True) -> Optional[str]: + return None + diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_admin_api.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_admin_api.py index b72828b..161bfb3 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_admin_api.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_admin_api.py @@ -16,6 +16,8 @@ GetProjectRequest, CreateSessionRequest, CreateSessionResponse, + DeleteSoftwareStackRequest, + DeleteSoftwareStackResponse, BatchCreateSessionRequest, BatchCreateSessionResponse, GetSessionConnectionInfoRequest, @@ -45,54 +47,125 @@ ListPermissionsRequest, CreateSoftwareStackFromSessionRequest, CreateSoftwareStackFromSessionResponse, - ReIndexUserSessionsResponse, - ReIndexSoftwareStacksResponse, CreatePermissionProfileResponse, CreatePermissionProfileRequest, UpdatePermissionProfileRequest, UpdatePermissionProfileResponse, + 
DeletePermissionProfileRequest, + DeletePermissionProfileResponse, UpdateSessionPermissionRequest, UpdateSessionPermissionResponse, VirtualDesktopSession, VirtualDesktopArchitecture, VirtualDesktopSoftwareStack ) -from ideadatamodel import errorcodes, exceptions +from ideadatamodel import errorcodes, exceptions, constants from ideasdk.api import ApiInvocationContext -from ideasdk.utils import Utils +from ideasdk.utils import Utils, ApiUtils from ideavirtualdesktopcontroller.app.api.virtual_desktop_api import VirtualDesktopAPI - class VirtualDesktopAdminAPI(VirtualDesktopAPI): def __init__(self, context: ideavirtualdesktopcontroller.AppContext): super().__init__(context) self.context = context self._logger = context.logger('virtual-desktop-admin-api') - self.namespace_handler_map: Dict[str, ()] = { - 'VirtualDesktopAdmin.CreateSession': self.create_session, - 'VirtualDesktopAdmin.BatchCreateSessions': self.batch_create_sessions, - 'VirtualDesktopAdmin.UpdateSession': self.update_session, - 'VirtualDesktopAdmin.DeleteSessions': self.delete_sessions, - 'VirtualDesktopAdmin.GetSessionInfo': self.get_session_info, - 'VirtualDesktopAdmin.ListSessions': self.list_sessions, - 'VirtualDesktopAdmin.StopSessions': self.stop_sessions, - 'VirtualDesktopAdmin.RebootSessions': self.reboot_sessions, - 'VirtualDesktopAdmin.ResumeSessions': self.resume_sessions, - 'VirtualDesktopAdmin.GetSessionScreenshot': self.get_session_screenshots, - 'VirtualDesktopAdmin.GetSessionConnectionInfo': self.get_session_connection_info, - 'VirtualDesktopAdmin.CreateSoftwareStack': self.create_software_stack, - 'VirtualDesktopAdmin.UpdateSoftwareStack': self.update_software_stack, - 'VirtualDesktopAdmin.GetSoftwareStackInfo': self.get_software_stack_info, - 'VirtualDesktopAdmin.ListSoftwareStacks': self.list_software_stacks, - 'VirtualDesktopAdmin.CreateSoftwareStackFromSession': self.create_software_stack_from_session, - 'VirtualDesktopAdmin.CreatePermissionProfile': self.create_permission_profile, - 
'VirtualDesktopAdmin.UpdatePermissionProfile': self.update_permission_profile, - 'VirtualDesktopAdmin.ListSessionPermissions': self.list_session_permissions, - 'VirtualDesktopAdmin.ListSharedPermissions': self.list_shared_permissions, - 'VirtualDesktopAdmin.UpdateSessionPermissions': self.update_session_permission, - 'VirtualDesktopAdmin.ReIndexUserSessions': self.re_index_user_sessions, - 'VirtualDesktopAdmin.ReIndexSoftwareStacks': self.re_index_software_stacks, + self.SCOPE_WRITE = f'{self.context.module_id()}/write' + self.SCOPE_READ = f'{self.context.module_id()}/read' + + self.acl = { + 'VirtualDesktopAdmin.CreateSession': { + 'scope': self.SCOPE_WRITE, + 'method': self.create_session, + }, + 'VirtualDesktopAdmin.BatchCreateSessions': { + 'scope': self.SCOPE_WRITE, + 'method': self.batch_create_sessions, + }, + 'VirtualDesktopAdmin.UpdateSession': { + 'scope': self.SCOPE_WRITE, + 'method': self.update_session, + }, + 'VirtualDesktopAdmin.DeleteSessions': { + 'scope': self.SCOPE_WRITE, + 'method': self.delete_sessions, + }, + 'VirtualDesktopAdmin.GetSessionInfo': { + 'scope': self.SCOPE_READ, + 'method': self.get_session_info, + }, + 'VirtualDesktopAdmin.ListSessions': { + 'scope': self.SCOPE_READ, + 'method': self.list_sessions, + }, + 'VirtualDesktopAdmin.StopSessions': { + 'scope': self.SCOPE_WRITE, + 'method': self.stop_sessions, + }, + 'VirtualDesktopAdmin.RebootSessions': { + 'scope': self.SCOPE_WRITE, + 'method': self.reboot_sessions, + }, + 'VirtualDesktopAdmin.ResumeSessions': { + 'scope': self.SCOPE_WRITE, + 'method': self.resume_sessions, + }, + 'VirtualDesktopAdmin.GetSessionScreenshot': { + 'scope': self.SCOPE_READ, + 'method': self.get_session_screenshots, + }, + 'VirtualDesktopAdmin.GetSessionConnectionInfo': { + 'scope': self.SCOPE_READ, + 'method': self.get_session_connection_info, + }, + 'VirtualDesktopAdmin.CreateSoftwareStack': { + 'scope': self.SCOPE_WRITE, + 'method': self.create_software_stack, + }, + 
'VirtualDesktopAdmin.UpdateSoftwareStack': { + 'scope': self.SCOPE_WRITE, + 'method': self.update_software_stack, + }, + 'VirtualDesktopAdmin.GetSoftwareStackInfo': { + 'scope': self.SCOPE_READ, + 'method': self.get_software_stack_info, + }, + 'VirtualDesktopAdmin.ListSoftwareStacks': { + 'scope': self.SCOPE_READ, + 'method': self.list_software_stacks, + }, + 'VirtualDesktopAdmin.CreateSoftwareStackFromSession': { + 'scope': self.SCOPE_WRITE, + 'method': self.create_software_stack_from_session, + }, + 'VirtualDesktopAdmin.DeleteSoftwareStack': { + 'scope': self.SCOPE_WRITE, + 'method': self.delete_software_stack, + }, + 'VirtualDesktopAdmin.CreatePermissionProfile': { + 'scope': self.SCOPE_WRITE, + 'method': self.create_permission_profile, + }, + 'VirtualDesktopAdmin.UpdatePermissionProfile': { + 'scope': self.SCOPE_WRITE, + 'method': self.update_permission_profile, + }, + 'VirtualDesktopAdmin.DeletePermissionProfile': { + 'scope': self.SCOPE_WRITE, + 'method': self.delete_permission_profile, + }, + 'VirtualDesktopAdmin.ListSessionPermissions': { + 'scope': self.SCOPE_READ, + 'method': self.list_session_permissions, + }, + 'VirtualDesktopAdmin.ListSharedPermissions': { + 'scope': self.SCOPE_READ, + 'method': self.list_shared_permissions, + }, + 'VirtualDesktopAdmin.UpdateSessionPermissions': { + 'scope': self.SCOPE_WRITE, + 'method': self.update_session_permission, + }, } def _validate_resume_session_request(self, session: VirtualDesktopSession) -> (VirtualDesktopSession, bool): @@ -114,18 +187,6 @@ def _validate_create_session_request(self, session: VirtualDesktopSession) -> (V session.failure_reason = 'Missing Create Session Info' return session, False - # validate if the user belongs within allowed group - - # Check is done for admin only since, virtual-desktop-user-api already enforces the group - # check for api-call - response = self.context.accounts_client.list_users_in_group(ListUsersInGroupRequest( - group_names=self.VDI_GROUPS - )) - - group = 
next(filter(lambda user: user.username == session.owner, response.listing), None) - if not group: - session.failure_reason = f'User {session.owner} is not authorized to create sessions.' - return session, False - return self.validate_create_session_request(session) @staticmethod @@ -147,10 +208,6 @@ def _validate_update_software_stack_request(software_stack: VirtualDesktopSoftwa software_stack.failure_reason = 'software_stack.base_os missing' return software_stack, False - if Utils.is_empty(software_stack.projects): - software_stack.failure_reason = 'software_stack.projects missing' - return software_stack, False - for project in software_stack.projects: if Utils.is_empty(project.project_id): software_stack.failure_reason = 'software_stack.project.project_id missing' @@ -196,6 +253,10 @@ def _validate_create_software_stack_request(self, software_stack: VirtualDesktop def create_session(self, context: ApiInvocationContext): session = context.get_request_payload_as(CreateSessionRequest).session + if session.name: + ApiUtils.validate_input(session.name, + constants.SESSION_NAME_REGEX, + constants.SESSION_NAME_ERROR_MESSAGE) session, is_valid = self._validate_create_session_request(session) if not is_valid: @@ -272,6 +333,7 @@ def get_session_screenshots(self, context: ApiInvocationContext): def get_software_stack_info(self, context: ApiInvocationContext): stack_id = context.get_request_payload_as(GetSoftwareStackInfoRequest).stack_id + base_os = context.get_request_payload_as(GetSoftwareStackInfoRequest).base_os if Utils.is_empty(stack_id): context.fail( error_code=errorcodes.INVALID_PARAMS, @@ -280,14 +342,15 @@ def get_software_stack_info(self, context: ApiInvocationContext): return context.success(GetSoftwareStackInfoResponse( - software_stack=self._get_software_stack_info(stack_id) + software_stack=self._get_software_stack_info(stack_id, base_os) )) def get_session_info(self, context: ApiInvocationContext): - session = 
context.get_request_payload_as(GetSessionInfoRequest).session + request = context.get_request_payload_as(GetSessionInfoRequest) + session = request.session self.validate_get_session_info_request(session) session = self.complete_get_session_info_request(session, context) - session = self._get_session_info(session) + session = self.session_db.get_from_db(session.owner, session.idea_session_id) if Utils.is_empty(session.failure_reason): context.success(GetSessionInfoResponse( session=session @@ -391,8 +454,7 @@ def list_software_stacks(self, context: ApiInvocationContext): def list_sessions(self, context: ApiInvocationContext): request = context.get_request_payload_as(ListSessionsRequest) - - result = self.session_db.list_from_index(request) + result = self.session_db.list_all_from_db(request) context.success(result) def update_session(self, context: ApiInvocationContext): @@ -481,6 +543,12 @@ def update_software_stack(self, context: ApiInvocationContext): def create_software_stack(self, context: ApiInvocationContext): software_stack = context.get_request_payload_as(CreateSoftwareStackRequest).software_stack + + if software_stack.name: + ApiUtils.validate_input(software_stack.name, + constants.SOFTWARE_STACK_NAME_REGEX, + constants.SOFTWARE_STACK_NAME_ERROR_MESSAGE) + software_stack, is_valid = self._validate_create_software_stack_request(software_stack) if not is_valid: context.fail( @@ -527,6 +595,27 @@ def create_software_stack_from_session(self, context: ApiInvocationContext): error_code=errorcodes.CREATED_SOFTWARE_STACK_FROM_SESSION_FAILED ) + def delete_software_stack(self, context: ApiInvocationContext): + software_stack = context.get_request_payload_as(DeleteSoftwareStackRequest).software_stack + if Utils.is_empty(software_stack.stack_id) or Utils.is_empty(software_stack.base_os): + context.fail( + error_code=errorcodes.INVALID_PARAMS, + message='Stack ID and base OS are required' + ) + + return + + sessions =
self.session_db.list_all_for_software_stack(ListSessionsRequest(), software_stack) + if sessions.listing: + session_ids_by_software_stack_id = [session.dcv_session_id for session in sessions.listing] + context.fail(error_code=errorcodes.GENERAL_ERROR, message=f'Software stack is still in use by virtual desktop sessions. ' + f'Stack ID: {software_stack.stack_id}, Session IDs: {session_ids_by_software_stack_id}') + + return + + self._delete_software_stack(software_stack) + context.success(DeleteSoftwareStackResponse()) + def update_permission_profile(self, context: ApiInvocationContext): permission_profile = context.get_request_payload_as(UpdatePermissionProfileRequest).profile existing_profile = self.permission_profile_db.get(profile_id=permission_profile.profile_id) @@ -556,43 +645,22 @@ def create_permission_profile(self, context: ApiInvocationContext): profile=permission_profile )) - def re_index_software_stacks(self, context: ApiInvocationContext): - # got a request to reindex everything again. - request = ListSoftwareStackRequest() - request.disabled_also = True - response = self.software_stack_db.list_all_from_db(request) - - while True: - for software_stack in response.listing: - self.software_stack_utils.index_software_stack_entry_to_opensearch(software_stack=software_stack) - - if Utils.is_empty(response.cursor): - # this was the last page, - break - - request.paginator = response.paginator - response = self.software_stack_db.list_all_from_db(request) - - context.success(ReIndexSoftwareStacksResponse()) - - def re_index_user_sessions(self, context: ApiInvocationContext): - # got a request to reindex everything again. 
- - request = ListSessionsRequest() - response = self.session_db.list_all_from_db(request) - - while True: - for session in response.listing: - self.session_utils.index_session_entry_to_opensearch(session=session) - - if Utils.is_empty(response.cursor): - # this was the last page, - break - - request.paginator = response.paginator - response = self.session_db.list_all_from_db(request) + def delete_permission_profile(self, context: ApiInvocationContext): + profile_id = context.get_request_payload_as(DeletePermissionProfileRequest).profile_id + if not profile_id: + context.fail( + message="permission profile ID is required", + error_code=errorcodes.INVALID_PARAMS, + payload=DeletePermissionProfileResponse( + )) + return + existing_profile = self.permission_profile_db.get(profile_id=profile_id) + if not existing_profile: + context.success(DeletePermissionProfileResponse()) + return - context.success(ReIndexUserSessionsResponse()) + self.permission_profile_db.delete(profile_id) + context.success(DeletePermissionProfileResponse()) def get_session_connection_info(self, context: ApiInvocationContext): self._logger.info(f'received get session connection info request from user: {context.get_username()}') @@ -654,10 +722,16 @@ def list_session_permissions(self, context: ApiInvocationContext): context.success(response) def invoke(self, context: ApiInvocationContext): + namespace = context.namespace - if not context.is_authorized(elevated_access=True): + acl_entry = self.acl.get(namespace) + if acl_entry is None: raise exceptions.unauthorized_access() - namespace = context.namespace - if namespace in self.namespace_handler_map: - self.namespace_handler_map[namespace](context) + acl_entry_scope = acl_entry.get('scope') + is_authorized = context.is_authorized(elevated_access=True, scopes=[acl_entry_scope]) + + if is_authorized: + acl_entry['method'](context) + else: + raise exceptions.unauthorized_access() diff --git 
a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api.py index 6c098c8..b4cf609 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api.py @@ -96,7 +96,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext): ) self.DEFAULT_ROOT_VOL_IOPS = '100' - self.VDI_GROUPS = [self.controller_utils.get_virtual_desktop_users_group(), self.controller_utils.get_virtual_desktop_admin_group()] def get_session_if_owner(self, username: str, idea_session_id: str) -> Optional[VirtualDesktopSession]: return self.session_db.get_from_db(idea_session_id=idea_session_id, idea_session_owner=username) @@ -404,8 +403,8 @@ def complete_delete_session_request(session: VirtualDesktopSession, context: Api return session - def _get_software_stack_info(self, software_stack_id: str) -> VirtualDesktopSoftwareStack: - software_stack_response = self.software_stack_db.get_from_index(stack_id=software_stack_id) + def _get_software_stack_info(self, software_stack_id: str, software_stack_base_os: str) -> VirtualDesktopSoftwareStack: + software_stack_response = self.software_stack_db.get_with_project_info(stack_id=software_stack_id, base_os=software_stack_base_os) if Utils.is_empty(software_stack_response): software_stack = VirtualDesktopSoftwareStack( stack_id=software_stack_id, @@ -415,14 +414,6 @@ def _get_software_stack_info(self, software_stack_id: str) -> VirtualDesktopSoft software_stack = software_stack_response return software_stack - def _get_session_info(self, session: VirtualDesktopSession) -> VirtualDesktopSession: - session_response = self.session_db.get_from_index(idea_session_id=session.idea_session_id) - if Utils.is_empty(session_response): 
-            session.failure_reason = f'invalid session.res_session_id: {session.idea_session_id} for {session.name} owned by user: {session.owner}'
-        else:
-            session = session_response
-        return session
-
     @staticmethod
     def validate_resume_session_request(session: VirtualDesktopSession) -> (VirtualDesktopSession, bool):
         if Utils.is_empty(session.idea_session_id):
@@ -445,10 +436,6 @@ def validate_create_software_stack_request(software_stack: VirtualDesktopSoftwar
             software_stack.failure_reason = 'missing software_stack.name'
             return software_stack, False
 
-        if Utils.is_empty(software_stack.projects):
-            software_stack.failure_reason = 'missing software_stack.projects'
-            return software_stack, False
-
         return software_stack, True
 
     @staticmethod
@@ -528,7 +515,6 @@ def complete_create_session_request(self, session: VirtualDesktopSession, contex
         # self.default_instance_profile_arn = self.context.app_config.virtual_desktop_dcv_host_profile_arn
         # self.default_security_group = self.context.app_config.virtual_desktop_dcv_host_security_group_id
 
-
         if Utils.is_empty(session.server.key_pair_name):
             session.server.key_pair_name = self.context.config().get_string('cluster.network.ssh_key_pair', required=True)
@@ -706,6 +692,9 @@ def _create_software_stack_from_session(self, session: VirtualDesktopSession, ne
 
         return new_software_stack
 
+    def _delete_software_stack(self, software_stack: VirtualDesktopSoftwareStack):
+        self.software_stack_utils.delete_software_stack(software_stack)
+
     def _list_session_permissions(self, idea_session_id: str, request: ListPermissionsRequest) -> ListPermissionsResponse:
         session_filter_found = False
         if Utils.is_empty(request.filters):
@@ -722,7 +711,7 @@ def _list_session_permissions(self, idea_session_id: str, request: ListPermissio
                 value=idea_session_id
             ))
 
-        return self.session_permissions_db.list_from_index(request)
+        return self.session_permissions_db.list_session_permissions(request)
 
     def _list_shared_permissions(self, username: str, request: ListPermissionsRequest) -> ListPermissionsResponse:
         actor_filter_found = False
@@ -740,10 +729,10 @@ def _list_shared_permissions(self, username: str, request: ListPermissionsReques
                 value=username
             ))
 
-        return self.session_permissions_db.list_from_index(request)
+        return self.session_permissions_db.list_session_permissions(request)
 
     def _list_software_stacks(self, request: ListSoftwareStackRequest) -> ListSoftwareStackResponse:
-        return self.software_stack_db.list_from_index(request)
+        return self.software_stack_db.list_all_from_db(request)
 
     def _get_session_connection_info(self, connection_info_request: VirtualDesktopSessionConnectionInfo, context: ApiInvocationContext) -> VirtualDesktopSessionConnectionInfo:
         connection_info = VirtualDesktopSessionConnectionInfo()
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api_invoker.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api_invoker.py
index b0571d1..2bbad52 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api_invoker.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_api_invoker.py
@@ -12,13 +12,13 @@
 import ideavirtualdesktopcontroller
 from ideasdk.api import ApiInvocationContext
 from ideasdk.app import SocaAppAPI
-from ideasdk.auth import TokenService
+from ideasdk.auth import TokenService, ApiAuthorizationServiceBase
 from ideasdk.protocols import ApiInvokerProtocol
 from ideavirtualdesktopcontroller.app.api.virtual_desktop_admin_api import VirtualDesktopAdminAPI
 from ideavirtualdesktopcontroller.app.api.virtual_desktop_dcv_api import VirtualDesktopDCVAPI
 from ideavirtualdesktopcontroller.app.api.virtual_desktop_user_api import VirtualDesktopUserAPI
 from ideavirtualdesktopcontroller.app.api.virtual_desktop_utils_api import VirtualDesktopUtilsAPI
-
+from typing import Optional
 
 class VirtualDesktopApiInvoker(ApiInvokerProtocol):
@@ -32,8 +32,11 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext):
             'App': SocaAppAPI(context)
         }
 
-    def get_token_service(self) -> TokenService:
+    def get_token_service(self) -> Optional[TokenService]:
         return self._context.token_service
+
+    def get_api_authorization_service(self) -> Optional[ApiAuthorizationServiceBase]:
+        return self._context.api_authorization_service
 
     def invoke(self, context: ApiInvocationContext):
         namespace = context.namespace
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_dcv_api.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_dcv_api.py
index fccd440..e3f8583 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_dcv_api.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_dcv_api.py
@@ -25,9 +25,16 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext):
         super().__init__(context)
         self.context = context
         self._logger = context.logger('virtual-desktop-dcv-api')
-        self.namespace_handler_map = {
-            'VirtualDesktopDCV.DescribeServers': self.describe_dcv_servers,
-            'VirtualDesktopDCV.DescribeSessions': self.describe_dcv_sessions
+        self.SCOPE_READ = f'{self.context.module_id()}/read'
+        self.acl = {
+            'VirtualDesktopDCV.DescribeServers': {
+                'scope': self.SCOPE_READ,
+                'method': self.describe_dcv_servers,
+            },
+            'VirtualDesktopDCV.DescribeSessions': {
+                'scope': self.SCOPE_READ,
+                'method': self.describe_dcv_sessions,
+            },
         }
 
     def describe_dcv_servers(self, context: ApiInvocationContext):
@@ -44,10 +51,16 @@ def describe_dcv_sessions(self, context: ApiInvocationContext):
         ))
 
     def invoke(self, context: ApiInvocationContext):
+        namespace = context.namespace
 
-        if not context.is_authorized(elevated_access=True):
+        acl_entry = self.acl.get(namespace)
+        if acl_entry is None:
             raise exceptions.unauthorized_access()
 
-        namespace = context.namespace
-        if namespace in self.namespace_handler_map:
-            self.namespace_handler_map[namespace](context)
+        acl_entry_scope = acl_entry.get('scope')
+        is_authorized = context.is_authorized(elevated_access=True, scopes=[acl_entry_scope])
+
+        if is_authorized:
+            acl_entry['method'](context)
+        else:
+            raise exceptions.unauthorized_access()
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_user_api.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_user_api.py
index 1eccc30..927ec07 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_user_api.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_user_api.py
@@ -44,7 +44,7 @@
 from ideasdk.api import ApiInvocationContext
 from ideasdk.utils import Utils
 from ideavirtualdesktopcontroller.app.api.virtual_desktop_api import VirtualDesktopAPI
-from ideavirtualdesktopcontroller.app.software_stacks.constants import SOFTWARE_STACK_DB_FILTER_PROJECT_ID_KEY
+from ideavirtualdesktopcontroller.app.software_stacks.constants import SOFTWARE_STACK_DB_FILTER_PROJECT_ID_KEY, SOFTWARE_STACK_DB_PROJECTS_KEY
 
 
 class VirtualDesktopUserAPI(VirtualDesktopAPI):
@@ -53,22 +53,66 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext):
         super().__init__(context)
         self.context = context
         self._logger = context.logger('virtual-desktop-user-api')
-
-        self.namespace_handler_map: Dict[str, ()] = {
-            'VirtualDesktop.CreateSession': self.create_session,
-            'VirtualDesktop.UpdateSession': self.update_session,
-            'VirtualDesktop.DeleteSessions': self.delete_sessions,
-            'VirtualDesktop.GetSessionInfo': self.get_session_info,
-            'VirtualDesktop.GetSessionScreenshot': self.get_session_screenshots,
-            'VirtualDesktop.GetSessionConnectionInfo': self.get_session_connection_info,
-            'VirtualDesktop.ListSessions': self.list_sessions,
-            'VirtualDesktop.StopSessions': self.stop_sessions,
-            'VirtualDesktop.ResumeSessions': self.resume_sessions,
-            'VirtualDesktop.RebootSessions': self.reboot_sessions,
-            'VirtualDesktop.ListSoftwareStacks': self.list_software_stacks,
-            'VirtualDesktop.ListSharedPermissions': self.list_shared_permissions,
-            'VirtualDesktop.ListSessionPermissions': self.list_session_permissions,
-            'VirtualDesktop.UpdateSessionPermissions': self.update_session_permission
+        self.SCOPE_WRITE = f'{self.context.module_id()}/write'
+        self.SCOPE_READ = f'{self.context.module_id()}/read'
+
+        self.acl = {
+            'VirtualDesktop.CreateSession': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.create_session,
+            },
+            'VirtualDesktop.UpdateSession': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.update_session,
+            },
+            'VirtualDesktop.DeleteSessions': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.delete_sessions,
+            },
+            'VirtualDesktop.GetSessionInfo': {
+                'scope': self.SCOPE_READ,
+                'method': self.get_session_info,
+            },
+            'VirtualDesktop.GetSessionScreenshot': {
+                'scope': self.SCOPE_READ,
+                'method': self.get_session_screenshots,
+            },
+            'VirtualDesktop.GetSessionConnectionInfo': {
+                'scope': self.SCOPE_READ,
+                'method': self.get_session_connection_info,
+            },
+            'VirtualDesktop.ListSessions': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_sessions,
+            },
+            'VirtualDesktop.StopSessions': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.stop_sessions,
+            },
+            'VirtualDesktop.ResumeSessions': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.resume_sessions,
+            },
+            'VirtualDesktop.RebootSessions': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.reboot_sessions,
+            },
+            'VirtualDesktop.ListSoftwareStacks': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_software_stacks,
+            },
+            'VirtualDesktop.ListSharedPermissions': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_shared_permissions,
+            },
+            'VirtualDesktop.ListSessionPermissions': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_session_permissions,
+            },
+            'VirtualDesktop.UpdateSessionPermissions': {
+                'scope': self.SCOPE_WRITE,
+                'method': self.update_session_permission
+            },
         }
 
     @staticmethod
@@ -284,7 +328,7 @@ def get_session_info(self, context: ApiInvocationContext):
             return
 
         session = self.complete_get_session_info_request(session, context)
-        session = self._get_session_info(session)
+        session = self.session_db.get_from_db(session.owner, session.idea_session_id)
 
         if Utils.is_empty(session.failure_reason):
             context.success(GetSessionInfoResponse(
@@ -437,7 +481,7 @@ def list_software_stacks(self, context: ApiInvocationContext):
 
         if not project_filter_found:
             request.add_filter(SocaFilter(
-                key=SOFTWARE_STACK_DB_FILTER_PROJECT_ID_KEY,
+                key=SOFTWARE_STACK_DB_PROJECTS_KEY,
                 value=project_id
             ))
 
@@ -499,10 +543,16 @@ def update_session_permission(self, context: ApiInvocationContext):
         context.success(response)
 
     def invoke(self, context: ApiInvocationContext):
-
-        if not context.is_authorized_user():
+        namespace = context.namespace
+
+        acl_entry = self.acl.get(namespace)
+        if acl_entry is None:
             raise exceptions.unauthorized_access()
 
-        namespace = context.namespace
-        if namespace in self.namespace_handler_map:
-            self.namespace_handler_map[namespace](context)
+        acl_entry_scope = acl_entry.get('scope')
+        is_authorized = context.is_authorized(elevated_access=False, scopes=[acl_entry_scope])
+        self._logger.info(f'Is authorized: {is_authorized}')
+        if is_authorized:
+            acl_entry['method'](context)
+        else:
+            raise exceptions.unauthorized_access()
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_utils_api.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_utils_api.py
index bad7eb4..e3fc429 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_utils_api.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/api/virtual_desktop_utils_api.py
@@ -37,16 +37,41 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext):
         super().__init__(context)
         self.context = context
         self._logger = context.logger('virtual-desktop-utils-api')
-
-        self.namespace_handler_map = {
-            'VirtualDesktopUtils.ListSupportedOS': self.list_supported_os,
-            'VirtualDesktopUtils.ListSupportedGPU': self.list_supported_gpu,
-            'VirtualDesktopUtils.ListScheduleTypes': self.list_schedule_types,
-            'VirtualDesktopUtils.ListAllowedInstanceTypes': self.list_allowed_instance_types,
-            'VirtualDesktopUtils.ListAllowedInstanceTypesForSession': self.list_allowed_instance_types_for_session,
-            'VirtualDesktopUtils.GetBasePermissions': self.get_base_permissions,
-            'VirtualDesktopUtils.ListPermissionProfiles': self.list_permission_profiles,
-            'VirtualDesktopUtils.GetPermissionProfile': self.get_permission_profile
+        self.SCOPE_READ = f'{self.context.module_id()}/read'
+
+        self.acl = {
+            'VirtualDesktopUtils.ListSupportedOS': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_supported_os,
+            },
+            'VirtualDesktopUtils.ListSupportedGPU': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_supported_gpu,
+            },
+            'VirtualDesktopUtils.ListScheduleTypes': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_schedule_types,
+            },
+            'VirtualDesktopUtils.ListAllowedInstanceTypes': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_allowed_instance_types,
+            },
+            'VirtualDesktopUtils.ListAllowedInstanceTypesForSession': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_allowed_instance_types_for_session,
+            },
+            'VirtualDesktopUtils.GetBasePermissions': {
+                'scope': self.SCOPE_READ,
+                'method': self.get_base_permissions,
+            },
+            'VirtualDesktopUtils.ListPermissionProfiles': {
+                'scope': self.SCOPE_READ,
+                'method': self.list_permission_profiles,
+            },
+            'VirtualDesktopUtils.GetPermissionProfile': {
+                'scope': self.SCOPE_READ,
+                'method': self.get_permission_profile,
+            },
         }
 
     def get_base_permissions(self, context: ApiInvocationContext):
@@ -166,10 +191,16 @@ def list_supported_os(context: ApiInvocationContext):
         ))
 
     def invoke(self, context: ApiInvocationContext):
+        namespace = context.namespace
 
-        if not context.is_authorized_user():
+        acl_entry = self.acl.get(namespace)
+        if acl_entry is None:
             raise exceptions.unauthorized_access()
 
-        namespace = context.namespace
-        if namespace in self.namespace_handler_map:
-            self.namespace_handler_map[namespace](context)
+        acl_entry_scope = acl_entry.get('scope')
+        is_authorized = context.is_authorized(elevated_access=False, scopes=[acl_entry_scope])
+
+        if is_authorized:
+            acl_entry['method'](context)
+        else:
+            raise exceptions.unauthorized_access()
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_context.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_context.py
index 02a8cab..a6a53c1 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_context.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_context.py
@@ -13,7 +13,7 @@
 from typing import Optional
 
 from ideadatamodel import constants
-from ideasdk.auth import TokenService
+from ideasdk.auth import TokenService, ApiAuthorizationServiceBase
 from ideasdk.client import NotificationsAsyncClient, ProjectsClient, AccountsClient
 from ideasdk.context import SocaContext, SocaContextOptions
 from ideasdk.service import SocaService
@@ -30,6 +30,7 @@ def __init__(self, options: SocaContextOptions):
         )
 
         self.token_service: Optional[TokenService] = None
+        self.api_authorization_service: Optional[ApiAuthorizationServiceBase] = None
         self.dcv_broker_client: Optional[DCVClientProtocol] = None
         self.event_queue_monitor_service: Optional[SocaService] = None
         self.controller_queue_monitor_service: Optional[SocaService] = None
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_main.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_main.py
index d1af1e5..5a87b96 100644
--- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_main.py
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/app_main.py
@@ -55,7 +55,6 @@ def main(**kwargs):
                 enable_leader_election=True,
                 enable_metrics=True,
                 metrics_namespace=f'{cluster_name}/{module_id}/controller',
-                enable_analytics=True
             )
         ),
         **kwargs
diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/__init__.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/__init__.py
new file mode 100644
index 0000000..59d9e03
--- /dev/null
+++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/__init__.py
@@ -0,0 +1,10 @@
+#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+#  with the License. A copy of the License is located at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+#  OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+#  and limitations under the License.
\ No newline at end of file diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/api_authorization_service.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/api_authorization_service.py new file mode 100644 index 0000000..ae7b1f0 --- /dev/null +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/auth/api_authorization_service.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +from ideasdk.auth.api_authorization_service_base import ApiAuthorizationServiceBase +from ideasdk.auth.token_service import TokenService +from ideasdk.client import AccountsClient +from ideadatamodel.auth import GetUserByEmailRequest, GetUserRequest, User +from ideadatamodel import exceptions, errorcodes +from typing import Optional + +class VdcApiAuthorizationService(ApiAuthorizationServiceBase): + def __init__(self, accounts_client: AccountsClient, token_service: TokenService): + self.accounts_client = accounts_client + self.token_service = token_service + + def get_user_from_token_username(self, token_username: str) -> Optional[User]: + if not token_username: + raise exceptions.unauthorized_access() + + email = self.token_service.get_email_from_token_username(token_username=token_username) + user = None + if email: + user = self.accounts_client.get_user_by_email(request=GetUserByEmailRequest(email=email)).user + else: + # This is for clusteradmin + user = self.accounts_client.get_user(request=GetUserRequest(username=token_username)).user + if not user: + exception_string = f'email: {email}' if email else f'username: {username}' + raise exceptions.SocaException( + error_code=errorcodes.AUTH_USER_NOT_FOUND, + message=f'User not found with {exception_string}' + ) + return user diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_created_event_handler.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_created_event_handler.py index c49ec9b..d8e7b19 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_created_event_handler.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_created_event_handler.py @@ -32,13 +32,11 @@ 
def __init__(self, context: ideavirtualdesktopcontroller.AppContext): } def _handle_software_stack_created(self, _: str, __: str, new_value: dict, ___: str): - software_stack = self.software_stack_db.convert_db_dict_to_software_stack_object(new_value) - self.software_stack_utils.index_software_stack_entry_to_opensearch(software_stack) + self._logger.debug(f'created entry for {new_value}. No=OP. Returning') def _handle_user_session_created(self, _: str, __: str, new_value: dict, ___: str): session = self.session_db.convert_db_dict_to_session_object(new_value) self._notify_session_owner_of_state_update(session) - self.session_utils.index_session_entry_to_opensearch(session=session) def _handle_permission_profile_created(self, _: str, __: str, ___: dict, table_name: str): self._logger.debug(f'created entry for {table_name} not handled. No=OP. Returning') @@ -51,8 +49,7 @@ def _handle_dcv_host_created(self, _: str, __: str, ___: dict, table_name: str): def _handle_session_permission_created(self, _: str, __: str, new_value: dict, ___: str): session_permission = self.session_permissions_db.convert_db_dict_to_session_permission_object(new_value) - self.session_permission_utils.index_session_permission_to_opensearch(session_permission=session_permission) - + notifications_enabled = self.context.config().get_bool('virtual-desktop-controller.dcv_session.notifications.session-shared.enabled', required=True) if not notifications_enabled: return diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_deleted_event_handler.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_deleted_event_handler.py index a06cfbf..b9fb3b9 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_deleted_event_handler.py +++ 
b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_deleted_event_handler.py @@ -32,8 +32,7 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext): } def _handle_software_stack_deleted(self, _: str, __: str, deleted_value: dict, ___: str): - software_stack = self.software_stack_db.convert_db_dict_to_software_stack_object(deleted_value) - self.software_stack_utils.delete_software_stack_entry_from_opensearch(software_stack.stack_id) + self._logger.debug(f'deleted entry for {deleted_value}. No=OP. Returning') def _handle_permission_profile_deleted(self, _: str, __: str, ___: dict, table_name: str): self._logger.debug(f'deleted entry for {table_name} not handled. No=OP. Returning') @@ -44,14 +43,12 @@ def _handle_schedule_deleted(self, _: str, __: str, ___: dict, table_name: str): def _handle_user_session_deleted(self, _: str, __: str, deleted_value: dict, ___: str): session = self.session_db.convert_db_dict_to_session_object(deleted_value) self._notify_session_owner_of_state_update(session, deleted=True) - self.session_utils.delete_session_entry_from_opensearch(session.idea_session_id) def _handle_dcv_host_deleted(self, _: str, __: str, ___: dict, table_name: str): self._logger.debug(f'deleted entry for {table_name} not handled. No=OP. 
Returning') def _handle_session_permission_deleted(self, _: str, __: str, deleted_value: dict, ___: str): deleted_session_permission = self.session_permissions_db.convert_db_dict_to_session_permission_object(deleted_value) - self.session_permission_utils.delete_session_entry_from_opensearch(deleted_session_permission) notifications_enabled = self.context.config().get_bool('virtual-desktop-controller.dcv_session.notifications.session-permission-expired.enabled', required=True) if not notifications_enabled: diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_updated_event_handler.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_updated_event_handler.py index a87890e..4a187de 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_updated_event_handler.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/events/handlers/db_entry_event_handlers/db_entry_updated_event_handler.py @@ -35,7 +35,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext): def _handle_software_stack_updated(self, _: str, __: str, old_value: dict, new_value: dict, ___: str): new_software_stack = self.software_stack_db.convert_db_dict_to_software_stack_object(new_value) - self.software_stack_utils.update_software_stack_entry_to_opensearch(new_software_stack) old_software_stack = self.software_stack_db.convert_db_dict_to_software_stack_object(old_value) if new_software_stack.name != old_software_stack.name or new_software_stack.description != old_software_stack.description: @@ -46,7 +45,7 @@ def _handle_software_stack_updated(self, _: str, __: str, old_value: dict, new_v key=USER_SESSION_DB_FILTER_SOFTWARE_STACK_ID_KEY, value=new_software_stack.stack_id )) - response = 
self.session_db.list_from_index(request) + response = self.session_db.list_all_from_db(request) for session in response.listing: idea_session_info.add((session.idea_session_id, session.owner)) @@ -62,8 +61,6 @@ def _handle_schedule_updated(self, _: str, __: str, ___: dict, ____: dict, table def _handle_user_session_updated(self, _: str, __: str, old_value: dict, new_value: dict, ___: str): new_session = self.session_db.convert_db_dict_to_session_object(new_value) - self.session_utils.update_session_entry_to_opensearch(new_session) - old_session = self.session_db.convert_db_dict_to_session_object(old_value) publish_permission_update_event = False if old_session.state != new_session.state: @@ -101,7 +98,7 @@ def _handle_permission_profile_updated(self, _: str, __: str, old_value: dict, n break if is_permission_updated: - response = self.session_permissions_db.list_from_index(ListPermissionsRequest( + response = self.session_permissions_db.list_session_permissions(ListPermissionsRequest( filters=[SocaFilter( key=SESSION_PERMISSIONS_FILTER_PERMISSION_PROFILE_ID_KEY, value=new_permission_profile.profile_id @@ -119,8 +116,7 @@ def _handle_permission_profile_updated(self, _: str, __: str, old_value: dict, n def _handle_session_permission_updated(self, _: str, __: str, old_value: dict, new_value: dict, ___: str): new_session_permission = self.session_permissions_db.convert_db_dict_to_session_permission_object(new_value) - self.session_permission_utils.update_session_entry_to_opensearch(new_session_permission) - + notifications_enabled = self.context.config().get_bool('virtual-desktop-controller.dcv_session.notifications.session-permission-updated.enabled', required=True) if not notifications_enabled: return diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_db.py 
b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_db.py index 01f8408..5ca32a8 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_db.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_db.py @@ -15,7 +15,6 @@ import ideavirtualdesktopcontroller from ideadatamodel import ( exceptions, - SocaSortBy, VirtualDesktopSessionPermission, VirtualDesktopBaseOS, VirtualDesktopSessionState, @@ -26,15 +25,13 @@ ListPermissionsResponse, SocaPaginator ) -from ideadatamodel.common.common_model import SocaSortOrder -from ideasdk.aws.opensearch.opensearchable_db import OpenSearchableDB from ideasdk.utils import Utils from ideavirtualdesktopcontroller.app.session_permissions import constants as session_permissions_constants from ideavirtualdesktopcontroller.app.virtual_desktop_notifiable_db import VirtualDesktopNotifiableDB -class VirtualDesktopSessionPermissionDB(VirtualDesktopNotifiableDB, OpenSearchableDB): +class VirtualDesktopSessionPermissionDB(VirtualDesktopNotifiableDB): DEFAULT_PAGE_SIZE = 10 def __init__(self, context: ideavirtualdesktopcontroller.AppContext): @@ -43,23 +40,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext): self._table_obj = None self._ddb_client = self.context.aws().dynamodb_table() VirtualDesktopNotifiableDB.__init__(self, context=context, table_name=self.table_name, logger=self._logger) - OpenSearchableDB.__init__( - self, context=context, logger=self._logger, - term_filter_map={ - session_permissions_constants.SESSION_PERMISSIONS_FILTER_ACTOR_KEY: 'actor_name.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_ID_KEY: 'idea_session_id.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_NAME_KEY: 
'idea_session_name.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_BASE_OS_KEY: 'idea_session_base_os.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_INSTANCE_TYPE_KEY: 'idea_session_instance_type.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_STATE_KEY: 'idea_session_state.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_TYPE_KEY: 'idea_session_type.raw', - session_permissions_constants.SESSION_PERMISSIONS_FILTER_PERMISSION_PROFILE_ID_KEY: 'permission_profile_id.raw' - }, - date_range_filter_map={ - session_permissions_constants.SESSION_PERMISSIONS_DB_IDEA_SESSION_CREATED_ON_KEY: 'idea_session_created_on' - }, - default_page_size=self.DEFAULT_PAGE_SIZE - ) @property def _table(self): @@ -102,15 +82,6 @@ def initialize(self): wait=True ) - def get_index_name(self) -> str: - return f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias', required=True)}-{self.context.session_permission_template_version}" - - def get_default_sort(self) -> SocaSortBy: - return SocaSortBy( - key=session_permissions_constants.SESSION_PERMISSIONS_FILTER_SESSION_CREATED_ON_KEY, - order=SocaSortOrder.ASC - ) - @staticmethod def convert_db_dict_to_session_permission_object(db_entry: Dict) -> Optional[VirtualDesktopSessionPermission]: return VirtualDesktopSessionPermission( @@ -273,19 +244,49 @@ def list_all_from_db(self, cursor: str) -> (List[VirtualDesktopSessionPermission self._logger.error(e) finally: return permissions, response_cursor - - def list_from_index(self, options: ListPermissionsRequest) -> ListPermissionsResponse: - response = self.list_from_opensearch(options) - permissions: List[VirtualDesktopSessionPermission] = [] - permissions_response = Utils.get_value_as_list('hits', Utils.get_value_as_dict('hits', response, default={}), default=[]) - for permission in permissions_response: - index_object = Utils.get_value_as_dict('_source', 
permission, default={}) - permissions.append(self.convert_db_dict_to_session_permission_object(index_object)) - + + def list_session_permissions(self, request: ListPermissionsRequest) -> ListPermissionsResponse: + list_request = {} + + exclusive_start_key = None + if Utils.is_not_empty(request.cursor): + exclusive_start_key = Utils.from_json(Utils.base64_decode(request.cursor)) + if exclusive_start_key is not None: + list_request['ExclusiveStartKey'] = exclusive_start_key + + scan_filter = None + if Utils.is_not_empty(request.filters): + scan_filter = {} + for filter_ in request.filters: + if filter_.eq is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.eq], + 'ComparisonOperator': 'EQ' + } + if filter_.value is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.value], + 'ComparisonOperator': 'CONTAINS' + } + if filter_.like is not None: + scan_filter[filter_.key] = { + 'AttributeValueList': [filter_.like], + 'ComparisonOperator': 'CONTAINS' + } + if scan_filter is not None: + list_request['ScanFilter'] = scan_filter + + list_result = self._table.scan(**list_request) + + session_permissions_entries = list_result.get('Items', []) + result = [self.convert_db_dict_to_session_permission_object(session_permission) for session_permission in session_permissions_entries] + + exclusive_start_key = list_result.get("LastEvaluatedKey") + response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) if exclusive_start_key else None + return ListPermissionsResponse( - listing=permissions, - paginator={}, - filters=options.filters, - date_range=options.date_range, - sort_by=options.sort_by + listing=result, + paginator=SocaPaginator( + cursor=response_cursor + ) ) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_utils.py 
b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_utils.py index 5549986..e795549 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_utils.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/session_permissions/virtual_desktop_session_permission_utils.py @@ -12,7 +12,6 @@ import ideavirtualdesktopcontroller from ideadatamodel import VirtualDesktopSessionPermission, VirtualDesktopSession, VirtualDesktopBaseOS, UpdateSessionPermissionRequest, UpdateSessionPermissionResponse -from ideasdk.analytics.analytics_service import AnalyticsEntry, EntryAction, EntryContent from ideasdk.utils import Utils from ideavirtualdesktopcontroller.app.events.events_utils import EventsUtils from ideavirtualdesktopcontroller.app.permission_profiles.virtual_desktop_permission_profile_db import VirtualDesktopPermissionProfileDB @@ -36,38 +35,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext, db: Virtual def _generate_entry_id(session_permission: VirtualDesktopSessionPermission) -> str: return f'permission-{session_permission.idea_session_id}-{session_permission.actor_name}' - def index_session_permission_to_opensearch(self, session_permission: VirtualDesktopSessionPermission): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias', required=True)}-{self.context.session_permission_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=self._generate_entry_id(session_permission), - entry_action=EntryAction.CREATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=self._session_permission_db.convert_session_permission_object_to_db_dict(session_permission) - ) - )) - - def delete_session_entry_from_opensearch(self, 
session_permission: VirtualDesktopSessionPermission): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias', required=True)}-{self.context.session_permission_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=self._generate_entry_id(session_permission), - entry_action=EntryAction.DELETE_ENTRY, - entry_content=EntryContent( - index_id=index_name - ) - )) - - def update_session_entry_to_opensearch(self, session_permission: VirtualDesktopSessionPermission): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias', required=True)}-{self.context.session_permission_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=self._generate_entry_id(session_permission), - entry_action=EntryAction.UPDATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=self._session_permission_db.convert_session_permission_object_to_db_dict(session_permission) - ) - )) - def _retrieve_permission_profile_values(self, profile_id) -> (List[str], List[str]): allow_permissions: List[str] = [] deny_permissions: List[str] = [] @@ -96,10 +63,7 @@ def generate_permissions_for_session(self, session: VirtualDesktopSession, for_b new_line_char = self.WINDOWS_POWERSHELL_NEW_LINE permission = f'[groups]{new_line_char}' - if session.software_stack.base_os == VirtualDesktopBaseOS.WINDOWS: - permission += f'group:ideaadmin=user:{admin_username}{new_line_char}' - else: - permission += f'group:ideaadmin=user:{admin_username}, osgroup:{self.controller_utils.get_virtual_desktop_admin_group()}{new_line_char}' + permission += f'group:ideaadmin=user:{admin_username}{new_line_char}' permission += f'[aliases]{new_line_char}' permission_profiles = set() diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/constants.py 
b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/constants.py index 626159d..bddc6e8 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/constants.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/constants.py @@ -11,6 +11,7 @@ USER_SESSION_DB_DESCRIPTION_KEY = 'description' USER_SESSION_DB_DCV_SESSION_ID_KEY = 'dcv_session_id' USER_SESSION_DB_SESSION_TYPE_KEY = 'session_type' +USER_SESSION_DB_SESSION_TAGS_KEY = 'session_tags' USER_SESSION_DB_STATE_KEY = 'state' USER_SESSION_DB_SCHEDULE_KEYS = { DayOfWeek.MONDAY: 'monday_schedule', diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_db.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_db.py index 140db0d..803793b 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_db.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_db.py @@ -8,7 +8,7 @@ # or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. 
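The replacement `list_*` implementations in this patch all share one pattern: a DynamoDB `Scan` whose `LastEvaluatedKey` is round-tripped to the client as a base64-encoded JSON cursor, with the legacy `ScanFilter` syntax used for equality/contains filters. A minimal standalone sketch of that pattern follows — the `scan_page` helper and its argument names are illustrative, not part of RES:

```python
import base64
import json


def scan_page(table, cursor=None, scan_filter=None):
    """Fetch one page from a boto3 DynamoDB Table using the cursor scheme above.

    `cursor` is the opaque string handed back from a previous page: it is the
    raw LastEvaluatedKey, JSON-serialized and base64-encoded, exactly as the
    list_* methods in this diff produce it.
    """
    kwargs = {}
    if cursor:
        # Decode the opaque cursor back into the native LastEvaluatedKey dict.
        kwargs['ExclusiveStartKey'] = json.loads(base64.b64decode(cursor))
    if scan_filter:
        # Legacy filter syntax, e.g.
        # {'owner': {'AttributeValueList': ['alice'], 'ComparisonOperator': 'EQ'}}
        kwargs['ScanFilter'] = scan_filter

    result = table.scan(**kwargs)
    last_key = result.get('LastEvaluatedKey')
    next_cursor = (base64.b64encode(json.dumps(last_key).encode()).decode()
                   if last_key else None)
    return result.get('Items', []), next_cursor
```

Note that `Scan` reads the table and applies `ScanFilter` after the read, so a page may contain fewer items than the page size; callers must keep paging until the returned cursor is `None`.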
-from typing import Union, Dict, List, Optional +from typing import Dict, Optional import ideavirtualdesktopcontroller from ideadatamodel import ( @@ -16,21 +16,19 @@ VirtualDesktopBaseOS, VirtualDesktopSessionType, VirtualDesktopSessionState, + VirtualDesktopSoftwareStack, VirtualDesktopWeekSchedule, Project, ListSessionsRequest, SocaFilter, ListSessionsResponse, SocaPaginator, - SocaSortBy, SocaListingPayload, DayOfWeek ) from ideadatamodel import exceptions -from ideadatamodel.common.common_model import SocaSortOrder -from ideasdk.aws.opensearch.opensearchable_db import OpenSearchableDB -from ideasdk.utils import Utils, DateTimeUtils +from ideasdk.utils import Utils, scan_db_records from ideavirtualdesktopcontroller.app.schedules.virtual_desktop_schedule_db import VirtualDesktopScheduleDB from ideavirtualdesktopcontroller.app.servers.virtual_desktop_server_db import VirtualDesktopServerDB from ideavirtualdesktopcontroller.app.software_stacks.virtual_desktop_software_stack_db import VirtualDesktopSoftwareStackDB @@ -39,7 +37,7 @@ from ideavirtualdesktopcontroller.app.sessions import constants as sessions_constants -class VirtualDesktopSessionDB(VirtualDesktopNotifiableDB, OpenSearchableDB): +class VirtualDesktopSessionDB(VirtualDesktopNotifiableDB): DEFAULT_PAGE_SIZE = 10 def __init__(self, context: ideavirtualdesktopcontroller.AppContext, server_db: VirtualDesktopServerDB, software_stack_db: VirtualDesktopSoftwareStackDB, schedule_db: VirtualDesktopScheduleDB): @@ -54,23 +52,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext, server_db: self._controller_utils = VirtualDesktopControllerUtils(self.context) self._ddb_client = self.context.aws().dynamodb_table() VirtualDesktopNotifiableDB.__init__(self, context=context, table_name=self.table_name, logger=self._logger) - OpenSearchableDB.__init__( - self, context=context, logger=self._logger, - term_filter_map={ - sessions_constants.USER_SESSION_DB_FILTER_SOFTWARE_STACK_ID_KEY: 
'software_stack.stack_id.raw', - sessions_constants.USER_SESSION_DB_FILTER_BASE_OS_KEY: 'base_os.raw', - sessions_constants.USER_SESSION_DB_FILTER_OWNER_KEY: 'owner.raw', - sessions_constants.USER_SESSION_DB_FILTER_IDEA_SESSION_ID_KEY: 'idea_session_id.raw', - sessions_constants.USER_SESSION_DB_FILTER_STATE_KEY: 'state.raw', - sessions_constants.USER_SESSION_DB_FILTER_SESSION_TYPE_KEY: 'session_type.raw', - sessions_constants.USER_SESSION_DB_FILTER_INSTANCE_TYPE_KEY: 'server.instance_type.raw' - }, - date_range_filter_map={ - sessions_constants.USER_SESSION_DB_FILTER_CREATED_ON_KEY: 'created_on', - sessions_constants.USER_SESSION_DB_FILTER_UPDATED_ON_KEY: 'updated_on' - }, - default_page_size=self.DEFAULT_PAGE_SIZE - ) @property def _table(self): @@ -154,7 +135,8 @@ def convert_db_dict_to_session_object(self, db_entry: Dict) -> Optional[VirtualD project_id=Utils.get_value_as_string(sessions_constants.USER_SESSION_DB_PROJECT_ID_KEY, Utils.get_value_as_dict(sessions_constants.USER_SESSION_DB_PROJECT_KEY, db_entry, {}), None), name=Utils.get_value_as_string(sessions_constants.USER_SESSION_DB_PROJECT_NAME_KEY, Utils.get_value_as_dict(sessions_constants.USER_SESSION_DB_PROJECT_KEY, db_entry, {}), None), title=Utils.get_value_as_string(sessions_constants.USER_SESSION_DB_PROJECT_TITLE_KEY, Utils.get_value_as_dict(sessions_constants.USER_SESSION_DB_PROJECT_KEY, db_entry, {}), None) - ) + ), + tags=Utils.get_value_as_list(sessions_constants.USER_SESSION_DB_SESSION_TAGS_KEY, db_entry) ) def convert_session_object_to_db_dict(self, session: VirtualDesktopSession) -> Dict: @@ -191,66 +173,12 @@ def convert_session_object_to_db_dict(self, session: VirtualDesktopSession) -> D sessions_constants.USER_SESSION_DB_SCHEDULE_KEYS[DayOfWeek.THURSDAY]: self._schedule_db.convert_schedule_object_to_db_dict(session.schedule.thursday), sessions_constants.USER_SESSION_DB_SCHEDULE_KEYS[DayOfWeek.FRIDAY]: self._schedule_db.convert_schedule_object_to_db_dict(session.schedule.friday), 
sessions_constants.USER_SESSION_DB_SCHEDULE_KEYS[DayOfWeek.SATURDAY]: self._schedule_db.convert_schedule_object_to_db_dict(session.schedule.saturday), - sessions_constants.USER_SESSION_DB_SCHEDULE_KEYS[DayOfWeek.SUNDAY]: self._schedule_db.convert_schedule_object_to_db_dict(session.schedule.sunday) + sessions_constants.USER_SESSION_DB_SCHEDULE_KEYS[DayOfWeek.SUNDAY]: self._schedule_db.convert_schedule_object_to_db_dict(session.schedule.sunday), + sessions_constants.USER_SESSION_DB_SESSION_TAGS_KEY: session.tags } return db_dict - def convert_session_object_to_index_dict(self, session: VirtualDesktopSession) -> Dict: - db_dict = self.convert_session_object_to_db_dict(session) - response = self._ec2_client.describe_instances( - InstanceIds=[session.server.instance_id] - ) - reservation = Utils.get_value_as_list('Reservations', response, default=[]) - if Utils.is_empty(reservation): - return db_dict - - db_dict["server"]["reservation_id"] = Utils.get_value_as_string('ReservationId', reservation[0], 'N/A') - instances = Utils.get_value_as_list('Instances', reservation[0], default=[]) - - if Utils.is_empty(instances): - return db_dict - instance = instances[0] - db_dict["server"]["private_ip"] = Utils.get_value_as_string('PrivateIpAddress', instance, default=None) - db_dict["server"]["public_ip"] = Utils.get_value_as_string('PublicIpAddress', instance, default=None) - - launch_time = Utils.get_value_as_string('LaunchTime', instance, default=None) - if Utils.is_not_empty(launch_time): - launch_time = Utils.to_milliseconds(DateTimeUtils.to_utc_datetime_from_iso_format(launch_time)) - db_dict["server"]["launch_time"] = launch_time - - db_dict["server"]["tags"] = [] - tags = Utils.get_value_as_list('Tags', instance, default=[]) - for tag in tags: - db_dict["server"]["tags"].append({ - "key": Utils.get_value_as_string("Key", tag, default=None), - "value": Utils.get_value_as_string("Value", tag, default=None) - }) - - placement = Utils.get_value_as_dict('Placement', instance, 
default={}) - db_dict["server"]["availability_zone"] = Utils.get_value_as_string('AvailabilityZone', placement, default=None) - db_dict["server"]["tenancy"] = Utils.get_value_as_string('Tenancy', placement, default=None) - - instance_type = Utils.get_value_as_string('InstanceType', instance, default=None) - db_dict["server"]["instance_type"] = instance_type - - instance_type_info = self._controller_utils.get_instance_type_info(instance_type) - default_vcpus = Utils.get_value_as_int("DefaultVCpus", Utils.get_value_as_dict("VCpuInfo", instance_type_info, default={}), default=0) - memory_size_in_mb = Utils.get_value_as_int("SizeInMiB", Utils.get_value_as_dict("MemoryInfo", instance_type_info, default={}), default=0) - total_gpu_memory_in_mb = Utils.get_value_as_int("TotalGpuMemoryInMiB", Utils.get_value_as_dict("GpuInfo", instance_type_info, default={}), default=0) - - db_dict["server"]["default_vcpus"] = default_vcpus - db_dict["server"]["memory_size_in_mb"] = memory_size_in_mb - db_dict["server"]["total_gpu_memory_in_mb"] = total_gpu_memory_in_mb - return db_dict - - def convert_index_dict_to_session_object(self, index_entry: Dict) -> VirtualDesktopSession: - session = self.convert_db_dict_to_session_object(index_entry) - server_dict = Utils.get_value_as_dict('server', index_entry, {}) - session.server.private_ip = Utils.get_value_as_string('private_ip', server_dict, None) - session.server.public_ip = Utils.get_value_as_string('public_ip', server_dict, None) - return session - def create(self, session: VirtualDesktopSession) -> VirtualDesktopSession: db_entry = self.convert_session_object_to_db_dict(session) db_entry[sessions_constants.USER_SESSION_DB_CREATED_ON_KEY] = Utils.current_time_ms() @@ -326,43 +254,6 @@ def get_from_db(self, idea_session_owner: str, idea_session_id: str) -> Optional raise e return self.convert_db_dict_to_session_object(session_db_entry) - def get_from_index(self, idea_session_id: str) -> Union[VirtualDesktopSession, None]: - request = 
ListSessionsRequest() - request.add_filter(SocaFilter( - key=sessions_constants.USER_SESSION_DB_FILTER_IDEA_SESSION_ID_KEY, - value=idea_session_id - )) - response = self.list_from_index(request) - if Utils.is_empty(response.listing): - return None - return response.listing[0] - - def list_from_index(self, options: ListSessionsRequest) -> ListSessionsResponse: - response = self._list_from_index(options) - - sessions: List[VirtualDesktopSession] = [] - session_responses = Utils.get_value_as_list('hits', Utils.get_value_as_dict('hits', response, default={}), default=[]) - for session_response in session_responses: - index_object = Utils.get_value_as_dict('_source', session_response, default={}) - sessions.append(self.convert_index_dict_to_session_object(index_object)) - - return ListSessionsResponse( - listing=sessions, - paginator={}, - filters=options.filters, - date_range=options.date_range, - sort_by=options.sort_by - ) - - def get_default_sort(self) -> SocaSortBy: - return SocaSortBy( - key='created_on', - order=SocaSortOrder.ASC - ) - - def get_index_name(self) -> str: - return f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias', required=True)}-{self.context.sessions_template_version}" - def get_session_count_for_user(self, username: str) -> int: count_request = { 'Select': 'COUNT', @@ -379,25 +270,12 @@ def get_session_count_for_user(self, username: str) -> int: return Utils.get_value_as_int('Count', response) def list_all_from_db(self, request: ListSessionsRequest) -> SocaListingPayload: - list_request = {} - - exclusive_start_key = None - if Utils.is_not_empty(request.cursor): - exclusive_start_key = Utils.from_json(Utils.base64_decode(request.cursor)) + list_result = scan_db_records(request, self._table) + session_entries = list_result.get('Items', []) + result = [self.convert_db_dict_to_session_object(session) for session in session_entries] - if exclusive_start_key is not None: - list_request['ExclusiveStartKey'] 
= exclusive_start_key - - list_result = self._table.scan(**list_request) - session_entries = Utils.get_value_as_list('Items', list_result, []) - result = [] - for session_entry in session_entries: - result.append(self.convert_db_dict_to_session_object(session_entry)) - - response_cursor = None - exclusive_start_key = Utils.get_any_value('LastEvaluatedKey', list_result) - if exclusive_start_key is not None: - response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) + exclusive_start_key = list_result.get("LastEvaluatedKey") + response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) if exclusive_start_key else None return SocaListingPayload( listing=result, @@ -418,20 +296,17 @@ def list_all_for_user(self, request: ListSessionsRequest, username: str) -> List key=sessions_constants.USER_SESSION_DB_FILTER_OWNER_KEY, value=username )) - return self.list_from_index(request) - - def _list_from_index(self, options: ListSessionsRequest) -> Dict: - if Utils.is_not_empty(options.filters): - new_filters: List[SocaFilter] = [] - for listing_filter in options.filters: - if listing_filter.key == sessions_constants.USER_SESSION_DB_FILTER_BASE_OS_KEY and listing_filter.value == '$all': - # needs to see all OS's no point adding a filter for OS at this point - continue - elif listing_filter.key == sessions_constants.USER_SESSION_DB_FILTER_BASE_OS_KEY and listing_filter.value == 'linux': - listing_filter.value = [VirtualDesktopBaseOS.AMAZON_LINUX2.value, VirtualDesktopBaseOS.RHEL7.value, VirtualDesktopBaseOS.RHEL8.value, VirtualDesktopBaseOS.RHEL9.value, VirtualDesktopBaseOS.CENTOS7.value] + return self.list_all_from_db(request) - new_filters.append(listing_filter) + def list_all_for_software_stack(self, request: ListSessionsRequest, software_stack: VirtualDesktopSoftwareStack) -> ListSessionsResponse: + if Utils.is_empty(request): + request = ListSessionsRequest() - options.filters = new_filters + if Utils.is_empty(request.filters): + request.filters 
= [] - return self.list_from_opensearch(options) + request.filters.append(SocaFilter( + key=sessions_constants.USER_SESSION_DB_SOFTWARE_STACK_KEY, + eq=self._software_stack_db.convert_software_stack_object_to_db_dict(software_stack) + )) + return self.list_all_from_db(request) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_utils.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_utils.py index 82b983a..ca9d568 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_utils.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/sessions/virtual_desktop_session_utils.py @@ -12,7 +12,6 @@ import ideavirtualdesktopcontroller from ideadatamodel import VirtualDesktopSession, VirtualDesktopServer, VirtualDesktopSessionState -from ideasdk.analytics.analytics_service import AnalyticsEntry, EntryAction, EntryContent from ideasdk.utils import Utils, DateTimeUtils from ideavirtualdesktopcontroller.app.events.events_utils import EventsUtils from ideavirtualdesktopcontroller.app.permission_profiles.virtual_desktop_permission_profile_db import VirtualDesktopPermissionProfileDB @@ -258,37 +257,4 @@ def terminate_sessions(self, sessions: List[VirtualDesktopSession]) -> (List[Vir session_map[session.dcv_session_id].failure_reason = session.failure_reason fail_response_list.append(session_map[session.dcv_session_id]) - return success_response_list, fail_response_list - - def delete_session_entry_from_opensearch(self, idea_session_id: str): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias', required=True)}-{self.context.sessions_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=idea_session_id, - entry_action=EntryAction.DELETE_ENTRY, - 
entry_content=EntryContent( - index_id=index_name - ) - )) - - def update_session_entry_to_opensearch(self, session: VirtualDesktopSession): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias', required=True)}-{self.context.sessions_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=session.idea_session_id, - entry_action=EntryAction.UPDATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=self._session_db.convert_session_object_to_db_dict(session) - ) - )) - - def index_session_entry_to_opensearch(self, session: VirtualDesktopSession): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias', required=True)}-{self.context.sessions_template_version}" - index_dict = self._session_db.convert_session_object_to_index_dict(session) - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=session.idea_session_id, - entry_action=EntryAction.CREATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=index_dict - ) - )) + return success_response_list, fail_response_list \ No newline at end of file diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_db.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_db.py index 37497cc..4c17a78 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_db.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_db.py @@ -8,7 +8,7 @@ # or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions # and limitations under the License. -from typing import List, Optional, Dict +from typing import Optional, Dict import yaml @@ -22,22 +22,18 @@ SocaMemoryUnit, Project, GetProjectRequest, - SocaSortBy, - SocaSortOrder, ListSoftwareStackRequest, ListSoftwareStackResponse, - SocaFilter, SocaListingPayload, SocaPaginator ) from ideadatamodel.virtual_desktop.virtual_desktop_model import VirtualDesktopGPU -from ideasdk.aws.opensearch.opensearchable_db import OpenSearchableDB -from ideasdk.utils import Utils +from ideasdk.utils import Utils, scan_db_records from ideavirtualdesktopcontroller.app.virtual_desktop_notifiable_db import VirtualDesktopNotifiableDB from ideavirtualdesktopcontroller.app.software_stacks import constants as software_stacks_constants -class VirtualDesktopSoftwareStackDB(VirtualDesktopNotifiableDB, OpenSearchableDB): +class VirtualDesktopSoftwareStackDB(VirtualDesktopNotifiableDB): DEFAULT_PAGE_SIZE = 10 def __init__(self, context: ideavirtualdesktopcontroller.AppContext): @@ -48,20 +44,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext): self._ddb_client = self.context.aws().dynamodb_table() VirtualDesktopNotifiableDB.__init__(self, context=self.context, table_name=self.table_name, logger=self._logger) - OpenSearchableDB.__init__( - self, context=self.context, logger=self._logger, - term_filter_map={ - software_stacks_constants.SOFTWARE_STACK_DB_STACK_ID_KEY: 'stack_id.raw', - software_stacks_constants.SOFTWARE_STACK_DB_FILTER_BASE_OS_KEY: 'base_os.raw', - software_stacks_constants.SOFTWARE_STACK_DB_FILTER_PROJECT_ID_KEY: 'projects.project_id.raw', - '$all': '$all' - }, - date_range_filter_map={ - software_stacks_constants.SOFTWARE_STACK_DB_FILTER_CREATED_ON_KEY: 'created_on', - software_stacks_constants.SOFTWARE_STACK_DB_FILTER_UPDATED_ON_KEY: 'updated_on' - }, - default_page_size=self.DEFAULT_PAGE_SIZE - ) def initialize(self): exists = 
self.context.aws_util().dynamodb_check_table_exists(self.table_name, True) @@ -249,33 +231,6 @@ def convert_db_dict_to_software_stack_object(db_entry: dict) -> Optional[Virtual return software_stack - def convert_software_stack_object_to_index_dict(self, software_stack: VirtualDesktopSoftwareStack) -> Dict: - index_dict = self.convert_software_stack_object_to_db_dict(software_stack) - - project_ids = Utils.get_value_as_list(software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY, index_dict, []) - index_dict[software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY] = [] - - for project_id in project_ids: - project = self.context.projects_client.get_project(GetProjectRequest(project_id=project_id)).project - index_dict[software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY].append({ - software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_ID_KEY: project_id, - software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_NAME_KEY: project.name, - software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_TITLE_KEY: project.title - }) - return index_dict - - def convert_index_dict_to_software_stack_object(self, index_dict: Dict): - ss_projects = Utils.get_value_as_list(software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY, index_dict, []) - index_dict[software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY] = [] - software_stack = self.convert_db_dict_to_software_stack_object(index_dict) - for project in ss_projects: - software_stack.projects.append(Project( - project_id=Utils.get_value_as_string(software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_ID_KEY, project, None), - name=Utils.get_value_as_string(software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_NAME_KEY, project, None), - title=Utils.get_value_as_string(software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_TITLE_KEY, project, None), - )) - return software_stack - @staticmethod def convert_software_stack_object_to_db_dict(software_stack: VirtualDesktopSoftwareStack) -> Dict: if Utils.is_empty(software_stack): @@ -306,15 
+261,6 @@ def convert_software_stack_object_to_db_dict(software_stack: VirtualDesktopSoftw db_dict[software_stacks_constants.SOFTWARE_STACK_DB_PROJECTS_KEY] = project_ids return db_dict - def get_index_name(self) -> str: - return f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias', required=True)}-{self.context.software_stack_template_version}" - - def get_default_sort(self) -> SocaSortBy: - return SocaSortBy( - key='created_on', - order=SocaSortOrder.ASC - ) - def create(self, software_stack: VirtualDesktopSoftwareStack) -> VirtualDesktopSoftwareStack: db_entry = self.convert_software_stack_object_to_db_dict(software_stack) db_entry[software_stacks_constants.SOFTWARE_STACK_DB_CREATED_ON_KEY] = Utils.current_time_ms() @@ -325,16 +271,39 @@ def create(self, software_stack: VirtualDesktopSoftwareStack) -> VirtualDesktopS self.trigger_create_event(db_entry[software_stacks_constants.SOFTWARE_STACK_DB_HASH_KEY], db_entry[software_stacks_constants.SOFTWARE_STACK_DB_RANGE_KEY], new_entry=db_entry) return self.convert_db_dict_to_software_stack_object(db_entry) - def get_from_index(self, stack_id: str) -> Optional[VirtualDesktopSoftwareStack]: - request = ListSoftwareStackRequest() - request.add_filter(SocaFilter( - key=software_stacks_constants.SOFTWARE_STACK_DB_STACK_ID_KEY, - value=stack_id - )) - response = self.list_from_index(request) - if Utils.is_empty(response.listing): - return None - return response.listing[0] + def get_with_project_info(self, stack_id: str, base_os: str) -> Optional[VirtualDesktopSoftwareStack]: + software_stack_db_entry = None + if Utils.is_empty(stack_id) or Utils.is_empty(base_os): + self._logger.error(f'invalid values for stack_id: {stack_id} and/or base_os: {base_os}') + else: + try: + result = self._table.get_item( + Key={ + software_stacks_constants.SOFTWARE_STACK_DB_HASH_KEY: base_os, + software_stacks_constants.SOFTWARE_STACK_DB_RANGE_KEY: stack_id + } + ) + software_stack_db_entry = 
result.get('Item') + except self._ddb_client.exceptions.ResourceNotFoundException as _: + # in this case we simply need to return None since the resource was not found + return None + except Exception as e: + self._logger.exception(e) + raise e + + software_stack_entry = self.convert_db_dict_to_software_stack_object(software_stack_db_entry) + + def _get_project(project_id): + project = self.context.projects_client.get_project(GetProjectRequest(project_id=project_id)).project + return { + software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_ID_KEY: project_id, + software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_NAME_KEY: project.name, + software_stacks_constants.SOFTWARE_STACK_DB_PROJECT_TITLE_KEY: project.title + } + + software_stack_projects = [_get_project(project_entry.project_id) for project_entry in software_stack_entry.projects] + software_stack_entry.projects = software_stack_projects + return software_stack_entry def get(self, stack_id: str, base_os: str) -> Optional[VirtualDesktopSoftwareStack]: software_stack_db_entry = None @@ -348,7 +317,7 @@ def get(self, stack_id: str, base_os: str) -> Optional[VirtualDesktopSoftwareSta software_stacks_constants.SOFTWARE_STACK_DB_RANGE_KEY: stack_id } ) - software_stack_db_entry = Utils.get_value_as_dict('Item', result) + software_stack_db_entry = result.get('Item') except self._ddb_client.exceptions.ResourceNotFoundException as _: # in this case we simply need to return None since the resource was not found return None @@ -396,63 +365,23 @@ def delete(self, software_stack: VirtualDesktopSoftwareStack): software_stacks_constants.SOFTWARE_STACK_DB_HASH_KEY: software_stack.base_os, software_stacks_constants.SOFTWARE_STACK_DB_RANGE_KEY: software_stack.stack_id }, - ReturnValue='ALL_OLD' + ReturnValues='ALL_OLD' ) old_db_entry = result['Attributes'] self.trigger_delete_event(old_db_entry[software_stacks_constants.SOFTWARE_STACK_DB_HASH_KEY], old_db_entry[software_stacks_constants.SOFTWARE_STACK_DB_RANGE_KEY], 
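Two details in the software-stack hunk above are easy to miss: rows are addressed by a composite key (hash key = `base_os`, range key = `stack_id`), and the `ReturnValue` → `ReturnValues` fix matters because boto3 only accepts the plural spelling — the old kwarg would fail parameter validation rather than silently return nothing. A hedged sketch of both calls (helper names and the literal key attribute names are assumptions for illustration; RES reads them from constants):

```python
def get_stack(table, base_os, stack_id):
    """Point-read by composite key; returns None when the item is absent."""
    result = table.get_item(Key={'base_os': base_os, 'stack_id': stack_id})
    return result.get('Item')


def delete_stack(table, base_os, stack_id):
    """Delete a row and return what was stored, via ReturnValues='ALL_OLD'."""
    result = table.delete_item(
        Key={'base_os': base_os, 'stack_id': stack_id},
        ReturnValues='ALL_OLD',  # 'ReturnValue' (singular) is not a valid kwarg
    )
    return result.get('Attributes')
```

Using `result.get('Item')` (as the diff now does) instead of a wrapper lookup keeps the "missing item" case as a plain `None`, which the callers above already handle.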
deleted_entry=old_db_entry) - def list_from_index(self, options: ListSoftwareStackRequest) -> ListSoftwareStackResponse: - if Utils.is_not_empty(options.filters): - new_filters: List[SocaFilter] = [] - for listing_filter in options.filters: - if listing_filter.key == software_stacks_constants.SOFTWARE_STACK_DB_FILTER_BASE_OS_KEY and listing_filter.value == '$all': - # needs to see all OS's no point adding a filter for OS at this point - continue - new_filters.append(listing_filter) - - options.filters = new_filters - - response = self.list_from_opensearch(options) - software_stacks: List[VirtualDesktopSoftwareStack] = [] - software_stacks_responses = Utils.get_value_as_list('hits', Utils.get_value_as_dict('hits', response, default={}), default=[]) - for software_stacks_response in software_stacks_responses: - index_object = Utils.get_value_as_dict('_source', software_stacks_response, default={}) - software_stacks.append(self.convert_index_dict_to_software_stack_object(index_object)) - - return ListSoftwareStackResponse( - listing=software_stacks, - paginator={}, - filters=options.filters, - date_range=options.date_range, - sort_by=options.sort_by - ) - - def list_all_from_db(self, request: ListSoftwareStackRequest) -> SocaListingPayload: - list_request = {} - - exclusive_start_key = None - if Utils.is_not_empty(request.cursor): - exclusive_start_key = Utils.from_json(Utils.base64_decode(request.cursor)) - - if exclusive_start_key is not None: - list_request['ExclusiveStartKey'] = exclusive_start_key - - list_result = self._table.scan(**list_request) + def list_all_from_db(self, request: ListSoftwareStackRequest) -> ListSoftwareStackResponse: + list_result = scan_db_records(request, self._table) - session_entries = Utils.get_value_as_list('Items', list_result, []) - result = [] - for session in session_entries: - result.append(self.convert_db_dict_to_software_stack_object(session)) + session_entries = list_result.get('Items', []) + result = 
[self.convert_db_dict_to_software_stack_object(session) for session in session_entries] - response_cursor = None - exclusive_start_key = Utils.get_any_value('LastEvaluatedKey', list_result) - if exclusive_start_key is not None: - response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) + exclusive_start_key = list_result.get("LastEvaluatedKey") + response_cursor = Utils.base64_encode(Utils.to_json(exclusive_start_key)) if exclusive_start_key else None return SocaListingPayload( listing=result, paginator=SocaPaginator( - page_size=request.page_size, cursor=response_cursor ) ) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_utils.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_utils.py index b5516b7..beb57fe 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_utils.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/software_stacks/virtual_desktop_software_stack_utils.py @@ -11,7 +11,6 @@ import ideavirtualdesktopcontroller from ideadatamodel import VirtualDesktopSoftwareStack, VirtualDesktopSession -from ideasdk.analytics.analytics_service import AnalyticsEntry, EntryAction, EntryContent from ideasdk.utils import Utils from ideavirtualdesktopcontroller.app.events.events_utils import EventsUtils from ideavirtualdesktopcontroller.app.software_stacks.virtual_desktop_software_stack_db import VirtualDesktopSoftwareStackDB @@ -33,38 +32,8 @@ def create_software_stack(self, software_stack: VirtualDesktopSoftwareStack) -> software_stack.stack_id = Utils.uuid() return self._software_stack_db.create(software_stack) + def delete_software_stack(self, software_stack: VirtualDesktopSoftwareStack): + self._software_stack_db.delete(software_stack) + def 
create_software_stack_from_session_when_ready(self, session: VirtualDesktopSession, new_software_stack: VirtualDesktopSoftwareStack) -> VirtualDesktopSoftwareStack: pass - - def delete_software_stack_entry_from_opensearch(self, software_stack_id: str): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias', required=True)}-{self.context.software_stack_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=software_stack_id, - entry_action=EntryAction.DELETE_ENTRY, - entry_content=EntryContent( - index_id=index_name - ) - )) - - def update_software_stack_entry_to_opensearch(self, software_stack: VirtualDesktopSoftwareStack): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias', required=True)}-{self.context.software_stack_template_version}" - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=software_stack.stack_id, - entry_action=EntryAction.UPDATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=self._software_stack_db.convert_software_stack_object_to_index_dict(software_stack) - ) - )) - - def index_software_stack_entry_to_opensearch(self, software_stack: VirtualDesktopSoftwareStack): - index_name = f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias', required=True)}-{self.context.software_stack_template_version}" - index_dict = self._software_stack_db.convert_software_stack_object_to_index_dict(software_stack) - self.context.analytics_service().post_entry(AnalyticsEntry( - entry_id=software_stack.stack_id, - entry_action=EntryAction.CREATE_ENTRY, - entry_content=EntryContent( - index_id=index_name, - entry_record=index_dict - ) - )) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_app.py 
b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_app.py index 7ccde5d..dc192b4 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_app.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_app.py @@ -30,6 +30,7 @@ from ideavirtualdesktopcontroller.app.sessions.virtual_desktop_session_db import VirtualDesktopSessionDB from ideavirtualdesktopcontroller.app.software_stacks.virtual_desktop_software_stack_db import VirtualDesktopSoftwareStackDB from ideavirtualdesktopcontroller.app.ssm_commands.virtual_desktop_ssm_commands_db import VirtualDesktopSSMCommandsDB +from ideavirtualdesktopcontroller.app.auth.api_authorization_service import VdcApiAuthorizationService import os import yaml @@ -64,7 +65,6 @@ def __init__(self, context: ideavirtualdesktopcontroller.AppContext, self.context = context def app_initialize(self): - self._initialize_templates() self._initialize_clients() self._initialize_dbs() self._initialize_services() @@ -84,68 +84,6 @@ def _initialize_dbs(self): self._permission_profile_db = VirtualDesktopPermissionProfileDB(self.context).initialize() self._session_permissions_db = VirtualDesktopSessionPermissionDB(self.context).initialize() - def _initialize_session_template(self): - session_template_file = os.path.join(self.context.get_resources_dir(), 'opensearch', 'session_entry_template.yml') - with open(session_template_file, 'r') as f: - sessions_index_template = yaml.safe_load(f) - - if Utils.is_empty(sessions_index_template): - return - - sessions_index_template["index_patterns"] = [ - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias')}-*" - ] - sessions_index_template["aliases"] = { - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias')}": {} - } - 
self.context.sessions_template_version = self.context.analytics_service().initialize_template( - template_name=f'{self.context.cluster_name()}_{self.context.module_id()}_user_sessions_template', - template_body=sessions_index_template - ) - - def _initialize_software_stack_template(self): - software_stack_template_file = os.path.join(self.context.get_resources_dir(), 'opensearch', 'software_stack_entry_template.yml') - with open(software_stack_template_file, 'r') as f: - software_stack_index_template = yaml.safe_load(f) - - if Utils.is_empty(software_stack_index_template): - return - - software_stack_index_template["index_patterns"] = [ - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias')}-*" - ] - software_stack_index_template["aliases"] = { - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias')}": {} - } - self.context.software_stack_template_version = self.context.analytics_service().initialize_template( - template_name=f'{self.context.cluster_name()}_{self.context.module_id()}_software_stack_template', - template_body=software_stack_index_template - ) - - def _initialize_session_permission_template(self): - session_permission_template_file = os.path.join(self.context.get_resources_dir(), 'opensearch', 'session_permission_entry_template.yml') - with open(session_permission_template_file, 'r') as f: - session_permission_index_template = yaml.safe_load(f) - - if Utils.is_empty(session_permission_index_template): - return - - session_permission_index_template["index_patterns"] = [ - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias')}-*" - ] - session_permission_index_template["aliases"] = { - f"{self.context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias')}": {} - } - self.context.session_permission_template_version = self.context.analytics_service().initialize_template( - 
template_name=f'{self.context.cluster_name()}_{self.context.module_id()}_session_permission_template', - template_body=session_permission_index_template - ) - - def _initialize_templates(self): - self._initialize_session_template() - self._initialize_software_stack_template() - self._initialize_session_permission_template() - def _initialize_clients(self): group_name_helper = GroupNameHelper(self.context) provider_url = self.context.config().get_string('identity-provider.cognito.provider_url', required=True) @@ -192,6 +130,11 @@ def _initialize_clients(self): ), token_service=self.context.token_service ) + + self.context.api_authorization_service = VdcApiAuthorizationService( + accounts_client=self.context.accounts_client, + token_service= self.context.token_service + ) self.context.notification_async_client = NotificationsAsyncClient(context=self.context) self.context.events_client = EventsClient(context=self.context) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_utils.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_utils.py index 4febbc6..29f7853 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_utils.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/app/virtual_desktop_controller_utils.py @@ -160,9 +160,11 @@ def provision_dcv_host_for_session(self, session: VirtualDesktopSession) -> dict custom_tags = self.context.config().get_list('global-settings.custom_tags', []) custom_tags_dict = Utils.convert_custom_tags_to_key_value_pairs(custom_tags) + session_tags = Utils.convert_tags_list_of_dict_to_tags_dict(session.tags) tags = { **custom_tags_dict, - **tags + **tags, + **session_tags } aws_tags = [] @@ -541,8 +543,3 @@ def change_instance_type(self, instance_id: str, instance_type_name: str) -> (st return repr(e), False 
return '', True - def get_virtual_desktop_users_group(self) -> str: - return self.group_name_helper.get_module_users_group(self.context.module_id()) - - def get_virtual_desktop_admin_group(self) -> str: - return self.group_name_helper.get_module_administrators_group(module_id=self.context.module_id()) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/cli_main.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/cli_main.py index 2dc7f3f..6c0036d 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/cli_main.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/cli_main.py @@ -19,13 +19,10 @@ import click from ideavirtualdesktopcontroller.cli.sessions import ( - reindex_user_sessions, batch_create_sessions, create_session, delete_session ) -from ideavirtualdesktopcontroller.cli.software_stacks import reindex_software_stacks -from ideavirtualdesktopcontroller.cli.module import app_module_clean_up @click.group(CLICK_SETTINGS) @@ -38,12 +35,9 @@ def main(): main.add_command(logs) -main.add_command(reindex_user_sessions) main.add_command(batch_create_sessions) main.add_command(create_session) main.add_command(delete_session) -main.add_command(reindex_software_stacks) -main.add_command(app_module_clean_up) # used only for local testing if __name__ == '__main__': diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/module.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/module.py deleted file mode 100644 index ff4cdc6..0000000 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/module.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance -# with the License. A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. - -from ideadatamodel import constants -from ideasdk.aws.opensearch.aws_opensearch_client import AwsOpenSearchClient - -from ideavirtualdesktopcontroller.cli import build_cli_context - -import click - - -@click.command(context_settings=constants.CLICK_SETTINGS, short_help='Execute commands to clean up before deleting module') -@click.option('--delete-databases', is_flag=True) -def app_module_clean_up(delete_databases: bool): - """ - Utility hook to do any clean-up before the module is being deleted - """ - if not delete_databases: - return - - context = build_cli_context() - os_client = AwsOpenSearchClient(context) - os_client.delete_template(f'{context.cluster_name()}-{context.module_id()}_session_permission_template') - os_client.delete_alias_and_index( - name=context.config().get_string('virtual-desktop-controller.opensearch.session_permission.alias') - ) - - os_client.delete_template(f'{context.cluster_name()}-{context.module_id()}_user_sessions_template') - os_client.delete_alias_and_index( - name=context.config().get_string('virtual-desktop-controller.opensearch.dcv_session.alias') - ) - os_client.delete_template(f'{context.cluster_name()}-{context.module_id()}_software_stack_template') - os_client.delete_alias_and_index( - name=context.config().get_string('virtual-desktop-controller.opensearch.software_stack.alias') - ) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/sessions.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/sessions.py index cb46c3c..68b2531 
100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/sessions.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/sessions.py @@ -10,8 +10,6 @@ # and limitations under the License. from ideadatamodel import ( - ReIndexUserSessionsRequest, - ReIndexUserSessionsResponse, VirtualDesktopSession, VirtualDesktopSoftwareStack, VirtualDesktopBaseOS, @@ -39,23 +37,6 @@ from typing import List from rich.table import Table - -@click.command(context_settings=constants.CLICK_SETTINGS, short_help='Re Index all user-sessions to Open Search') -@click.argument('tokens', nargs=-1) -def reindex_user_sessions(tokens, **kwargs): - context = build_cli_context(unix_socket_timeout=360000) - - request = ReIndexUserSessionsRequest() - response = context.unix_socket_client.invoke_alt( - namespace='VirtualDesktopAdmin.ReIndexUserSessions', - payload=request, - result_as=ReIndexUserSessionsResponse - ) - # TODO: PrettyPrint response. - # TODO: handle flag --destroy-and-recreate-index - print(response) - - @click.command(context_settings=constants.CLICK_SETTINGS, short_help='Creates a session') @click.option('--name', required=True, help='virtual session name') @click.option('--owner', required=True, help='session owner name') diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/software_stacks.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/software_stacks.py deleted file mode 100644 index 353fd26..0000000 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller/cli/software_stacks.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance -# with the License. 
A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES -# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions -# and limitations under the License. - -from ideadatamodel import ( - ReIndexSoftwareStacksRequest, - ReIndexSoftwareStacksResponse -) -from ideadatamodel import constants - -from ideavirtualdesktopcontroller.cli import build_cli_context - -import click - - -@click.command(context_settings=constants.CLICK_SETTINGS, short_help='Re Index all software stacks to Open Search') -@click.argument('tokens', nargs=-1) -def reindex_software_stacks(tokens, **kwargs): - - context = build_cli_context() - - request = ReIndexSoftwareStacksRequest() - response = context.unix_socket_client.invoke_alt( - namespace='VirtualDesktopAdmin.ReIndexSoftwareStacks', - payload=request, - result_as=ReIndexSoftwareStacksResponse - ) - # TODO: PrettyPrint response. - # TODO: handle flag --destroy-and-recreate-index - print(response) diff --git a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller_meta/__init__.py b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller_meta/__init__.py index 18f2787..3dbabda 100644 --- a/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller_meta/__init__.py +++ b/source/idea/idea-virtual-desktop-controller/src/ideavirtualdesktopcontroller_meta/__init__.py @@ -10,4 +10,4 @@ # and limitations under the License. 
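With the analytics stack removed, the listing change at the top of this diff derives its pagination cursor directly from DynamoDB's `LastEvaluatedKey`, base64-encoding it as an opaque cursor string. A minimal sketch of that round-trip (helper names here are hypothetical; the real code goes through `Utils.base64_encode`/`Utils.to_json`):

```python
import base64
import json

def encode_cursor(last_evaluated_key):
    """Turn DynamoDB's LastEvaluatedKey into an opaque paging cursor, or None on the last page."""
    if not last_evaluated_key:
        return None
    return base64.b64encode(json.dumps(last_evaluated_key).encode()).decode()

def decode_cursor(cursor):
    """Recover the ExclusiveStartKey to pass to the next Query/Scan call."""
    return json.loads(base64.b64decode(cursor))
```

The encoded value is what lands in `SocaPaginator`; a `None` cursor signals to the client that there are no further pages.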
__name__ = 'idea-virtual-desktop-controller' -__version__ = '2023.11' +__version__ = '2024.01' diff --git a/source/idea/infrastructure/install/commands/create.py b/source/idea/infrastructure/install/commands/create.py index a9779fc..58abc1a 100644 --- a/source/idea/infrastructure/install/commands/create.py +++ b/source/idea/infrastructure/install/commands/create.py @@ -1,14 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from typing import Any, Union import aws_cdk +from idea.batteries_included.parameters.parameters import BIParameters from idea.infrastructure.install.parameters.common import CommonKey from idea.infrastructure.install.parameters.customdomain import CustomDomainKey from idea.infrastructure.install.parameters.directoryservice import DirectoryServiceKey -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters from idea.infrastructure.install.parameters.shared_storage import SharedStorageKey VALUES_FILE_PATH = "values.yml" @@ -21,7 +22,7 @@ class Create: automagically """ - def __init__(self, params: Parameters): + def __init__(self, params: Union[RESParameters, BIParameters]): self.params = params def get_commands(self) -> list[str]: @@ -40,12 +41,13 @@ def _config(self) -> list[str]: 1. Generate values.yml using parameters 2. Render settings files from generated values.yml 3. Write rendered settings to dynamo - 4. Update the dynamo settings table for AD + 4. Update the dynamo settings table for infrastructure host AMIs, AD, and custom domain 5. 
Remove the local copy of config, and replace with what's in dynamo """ return [ *self._create_values_file(), *self._render_settings_from_values_and_write_to_dynamo(), + *self._update_infrastructure_host_ami_config(), *self._update_directory_service_config(), *self._update_custom_domain_config(), f"rm -rf /root/.idea/clusters/{self.params.get_str(CommonKey.CLUSTER_NAME)}/{aws_cdk.Aws.REGION}/config", @@ -78,6 +80,28 @@ def _render_settings_from_values_and_write_to_dynamo(self) -> list[str]: f"{EXE} config update --overwrite --force {self._get_suffix()}", ] + + def _update_infrastructure_host_ami_config(self) -> list[str]: + """ + Generate commands to update each of the settings for the custom infrastructure host AMIs + """ + + # Concatenate the custom AMI values so we can check that none of them are null + vals = "" + nvals = len(self._get_infrastructure_host_ami_settings().keys()) + for key, value in self._get_infrastructure_host_ami_settings().items(): + vals += f" {value} " + quote = r'"' + vals = f"{quote} {vals} {quote}" + + # Update only if none of the custom values are null + config_update_commands: list[str] = [] + for key, value in self._get_infrastructure_host_ami_settings().items(): + cmd = f" if [ $(echo {vals}|wc -w) -eq {nvals} ] ; then " + cmd += f" {EXE} config set Key={key},Type=str,Value={value} --force {self._get_suffix()}; fi" + config_update_commands.append(cmd) + + return config_update_commands + def _update_custom_domain_config(self) -> list[str]: """ Generate commands to update each of the settings for custom dmain @@ -133,6 +157,28 @@ def _update_directory_service_config(self) -> list[str]: ) return config_update_commands + def _get_infrastructure_host_ami_settings(self) -> dict[str, Any]: + """ + Returns a mapping of RES settings key to parameters for updating the infrastructure host AMI config + """ + return { + "vdc.dcv_broker.autoscaling.instance_ami": self.params.get_str( + CommonKey.INFRASTRUCTURE_HOST_AMI + ),
"vdc.dcv_connection_gateway.autoscaling.instance_ami": self.params.get_str( + CommonKey.INFRASTRUCTURE_HOST_AMI + ), + "bastion-host.instance_ami": self.params.get_str( + CommonKey.INFRASTRUCTURE_HOST_AMI + ), + "vdc.controller.autoscaling.instance_ami": self.params.get_str( + CommonKey.INFRASTRUCTURE_HOST_AMI + ), + "cluster-manager.ec2.autoscaling.instance_ami": self.params.get_str( + CommonKey.INFRASTRUCTURE_HOST_AMI + ), + } + def _get_directory_service_settings(self) -> dict[str, Any]: """ Returns a mapping of RES settings key to parameters for updating the directoryservice config @@ -168,6 +214,9 @@ def _get_directory_service_settings(self) -> dict[str, Any]: "directoryservice.tls_certificate_secret_arn": self.params.get_str( DirectoryServiceKey.DOMAIN_TLS_CERTIFICATE_SECRET_ARN ), + "directoryservice.sssd.ldap_id_mapping": self.params.get_str( + DirectoryServiceKey.ENABLE_LDAP_ID_MAPPING + ), } def _get_custom_domain_settings(self) -> dict[str, Any]: @@ -210,8 +259,8 @@ def _create_values_file(self) -> list[str]: - {self.params.get_str(CommonKey.CLIENT_IP)} prefix_list_ids: - {self.params.get_str(CommonKey.CLIENT_PREFIX_LIST)} -alb_public: true -use_vpc_endpoints: false +alb_public: {self.params.get_str(CommonKey.IS_LOAD_BALANCER_INTERNET_FACING)} +use_vpc_endpoints: true directory_service_provider: activedirectory enable_aws_backup: false kms_key_type: aws-managed @@ -223,22 +272,24 @@ def _create_values_file(self) -> list[str]: existing_resources: - subnets:public - subnets:private +- subnets:external_load_balancer +- subnets:infrastructure_hosts +- subnets:dcv_session - shared-storage:home -public_subnet_ids: +load_balancer_subnet_ids: $(echo "{ aws_cdk.Stack.of( - self.params.get(CommonKey.PUBLIC_SUBNETS).stack + self.params.get(CommonKey.LOAD_BALANCER_SUBNETS).stack ).to_json_string( - self.params.get(CommonKey.PUBLIC_SUBNETS).value_as_list - ) + self.params.get(CommonKey.LOAD_BALANCER_SUBNETS).value_as_list) }" | tr -d '[]" ' | tr ',' '\n' | sed 's/^/- 
/') -private_subnet_ids: +infrastructure_host_subnet_ids: $(echo "{ - aws_cdk.Stack.of( - self.params.get(CommonKey.PRIVATE_SUBNETS).stack - ).to_json_string( - self.params.get(CommonKey.PRIVATE_SUBNETS).value_as_list - ) + aws_cdk.Stack.of(self.params.get(CommonKey.INFRASTRUCTURE_HOST_SUBNETS).stack).to_json_string(self.params.get(CommonKey.INFRASTRUCTURE_HOST_SUBNETS).value_as_list) +}" | tr -d '[]" ' | tr ',' '\n' | sed 's/^/- /') +dcv_session_private_subnet_ids: +$(echo "{ + aws_cdk.Stack.of(self.params.get(CommonKey.VDI_SUBNETS).stack).to_json_string(self.params.get(CommonKey.VDI_SUBNETS).value_as_list) }" | tr -d '[]" ' | tr ',' '\n' | sed 's/^/- /') enabled_modules: - metrics diff --git a/source/idea/infrastructure/install/handlers.py b/source/idea/infrastructure/install/handlers.py index 1a80cfa..e757763 100644 --- a/source/idea/infrastructure/install/handlers.py +++ b/source/idea/infrastructure/install/handlers.py @@ -81,13 +81,13 @@ def handle_custom_resource_lifecycle_event(event: Dict[str, Any], _: Any) -> Non def send_wait_condition_response(event: Dict[str, Any], _: Any) -> Any: is_wait_condition_response = event.get("RequestType") != RequestType.DELETE - response: Union[ - WaitConditionResponse, CustomResourceResponse - ] = WaitConditionResponse( - Status=WaitConditionResponseStatus.SUCCESS, - UniqueId=str(uuid.uuid4()), - Reason=WaitConditionResponseStatus.SUCCESS, - Data="", + response: Union[WaitConditionResponse, CustomResourceResponse] = ( + WaitConditionResponse( + Status=WaitConditionResponseStatus.SUCCESS, + UniqueId=str(uuid.uuid4()), + Reason=WaitConditionResponseStatus.SUCCESS, + Data="", + ) ) if not is_wait_condition_response: response = CustomResourceResponse( diff --git a/source/idea/infrastructure/install/installer.py b/source/idea/infrastructure/install/installer.py index 3247a60..8284ab2 100644 --- a/source/idea/infrastructure/install/installer.py +++ b/source/idea/infrastructure/install/installer.py @@ -3,7 +3,7 @@ import inspect 
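The `values.yml` heredoc above renders each subnet-list parameter into YAML list items via the shell pipeline `tr -d '[]" ' | tr ',' '\n' | sed 's/^/- /'`. A minimal Python sketch of what that pipeline does to the JSON array CloudFormation resolves at deploy time (the subnet IDs below are made-up samples, and the function name is hypothetical):

```python
def json_list_to_yaml_items(json_array: str) -> str:
    # Mimic the shell pipeline: tr -d '[]" ' | tr ',' '\n' | sed 's/^/- /'
    # 1) delete brackets, quotes, and spaces; 2) split on commas; 3) prefix "- "
    stripped = json_array.translate(str.maketrans("", "", '[]" '))
    return "\n".join(f"- {item}" for item in stripped.split(","))

print(json_list_to_yaml_items('["subnet-0a1", "subnet-0b2"]'))
# - subnet-0a1
# - subnet-0b2
```

Note the pipeline assumes subnet IDs never contain spaces, quotes, or commas, which holds for EC2 subnet identifiers.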
import pathlib from datetime import datetime -from typing import Any, Callable, Dict, TypedDict +from typing import Any, Callable, Dict, TypedDict, Union import aws_cdk from aws_cdk import aws_lambda as lambda_ @@ -11,8 +11,9 @@ from aws_cdk import aws_stepfunctions_tasks as sfn_tasks from constructs import Construct, DependencyGroup +from idea.batteries_included.parameters.parameters import BIParameters from idea.infrastructure.install import handlers, tasks -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters class LambdaCodeParams(TypedDict): @@ -26,7 +27,7 @@ def __init__( scope: Construct, id: str, registry_name: str, - params: Parameters, + params: Union[RESParameters, BIParameters], dependency_group: DependencyGroup, ): super().__init__(scope, id) diff --git a/source/idea/infrastructure/install/parameters/base.py b/source/idea/infrastructure/install/parameters/base.py index e088fff..0d84c41 100644 --- a/source/idea/infrastructure/install/parameters/base.py +++ b/source/idea/infrastructure/install/parameters/base.py @@ -5,7 +5,7 @@ import typing from dataclasses import asdict, dataclass, field, fields from enum import Enum -from typing import Any, ClassVar, Generator, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Generator, List, Optional, Type, TypeVar, Union import aws_cdk from constructs import Construct @@ -33,6 +33,7 @@ class Attributes: description: Optional[str] = None type: Optional[str] = None allowed_pattern: Optional[str] = None + allowed_values: Optional[list[str]] = None constraint_description: Optional[str] = None no_echo: Optional[bool] = None diff --git a/source/idea/infrastructure/install/parameters/common.py b/source/idea/infrastructure/install/parameters/common.py index ecf53df..0f9cc95 100644 --- a/source/idea/infrastructure/install/parameters/common.py +++ b/source/idea/infrastructure/install/parameters/common.py @@ -9,12 
+9,15 @@ class CommonKey(Key): CLUSTER_NAME = "EnvironmentName" ADMIN_EMAIL = "AdministratorEmail" + INFRASTRUCTURE_HOST_AMI = "InfrastructureHostAMI" SSH_KEY_PAIR = "SSHKeyPair" CLIENT_IP = "ClientIp" CLIENT_PREFIX_LIST = "ClientPrefixList" VPC_ID = "VpcId" - PUBLIC_SUBNETS = "PublicSubnets" - PRIVATE_SUBNETS = "PrivateSubnets" + LOAD_BALANCER_SUBNETS = "LoadBalancerSubnets" + INFRASTRUCTURE_HOST_SUBNETS = "InfrastructureHostSubnets" + VDI_SUBNETS = "VdiSubnets" + IS_LOAD_BALANCER_INTERNET_FACING = "IsLoadBalancerInternetFacing" @dataclass @@ -55,7 +58,7 @@ class CommonParameters(Base): id=CommonKey.CLIENT_PREFIX_LIST, type="String", description=( - "A prefix list that covers IPs allowed to directly access the Web UI and SSH " + "(Optional) A prefix list that covers IPs allowed to directly access the Web UI and SSH " "into the bastion host." ), allowed_pattern="^(pl-[a-z0-9]{8,20})?$", @@ -84,6 +87,16 @@ class CommonParameters(Base): ) ) + infrastructure_host_ami: str = Base.parameter( + Attributes( + id=CommonKey.INFRASTRUCTURE_HOST_AMI, + type="String", + allowed_pattern="^(ami-[0-9a-f]{8,17})?$", + description="(Optional) You may provide a custom AMI id to use for all the infrastructure hosts. The current supported base OS is Amazon Linux 2.", + constraint_description="The AMI id must begin with 'ami-' followed by only letters (a-f) or numbers(0-9).", + ) + ) + vpc_id: str = Base.parameter( Attributes( id=CommonKey.VPC_ID, @@ -94,24 +107,42 @@ class CommonParameters(Base): ) ) - public_subnets: list[str] = Base.parameter( + load_balancer_subnets: list[str] = Base.parameter( + Attributes( + id=CommonKey.LOAD_BALANCER_SUBNETS, + type="List", + description="Select at least 2 subnets from different Availability Zones. For deployments that need restricted internet access, select private subnets. 
For deployments that need internet access, select public subnets.", + allowed_pattern=".+", + ) ) + + infrastructure_host_subnets: list[str] = Base.parameter( Attributes( - id=CommonKey.PUBLIC_SUBNETS, + id=CommonKey.INFRASTRUCTURE_HOST_SUBNETS, type="List", - description="Pick at least 2 public subnets from 2 different Availability Zones", + description="Select at least 2 private subnets from different Availability Zones.", allowed_pattern=".+", ) ) - private_subnets: list[str] = Base.parameter( + vdi_subnets: list[str] = Base.parameter( Attributes( - id=CommonKey.PRIVATE_SUBNETS, + id=CommonKey.VDI_SUBNETS, type="List", - description="Pick at least 2 private subnets from 2 different Availability Zones", + description="Select at least 2 subnets from different Availability Zones. For deployments that need restricted internet access, select private subnets. For deployments that need internet access, select public subnets.", allowed_pattern=".+", ) ) + is_load_balancer_internet_facing: str = Base.parameter( + Attributes( + id=CommonKey.IS_LOAD_BALANCER_INTERNET_FACING, + type="String", + description="Select true to deploy an internet-facing load balancer (requires public subnets for the load balancer).
For deployments that need restricted internet access, select false.", + allowed_values=["true", "false"], + ) + ) + class CommonParameterGroups: parameter_group_for_environment_and_installer_details: dict[str, Any] = { @@ -119,6 +150,7 @@ class CommonParameterGroups: "Parameters": [ CommonKey.CLUSTER_NAME, CommonKey.ADMIN_EMAIL, + CommonKey.INFRASTRUCTURE_HOST_AMI, CommonKey.SSH_KEY_PAIR, CommonKey.CLIENT_IP, CommonKey.CLIENT_PREFIX_LIST, @@ -129,7 +161,9 @@ class CommonParameterGroups: "Label": {"default": "Network configuration for the RES environment"}, "Parameters": [ CommonKey.VPC_ID, - CommonKey.PRIVATE_SUBNETS, - CommonKey.PUBLIC_SUBNETS, + CommonKey.IS_LOAD_BALANCER_INTERNET_FACING, + CommonKey.LOAD_BALANCER_SUBNETS, + CommonKey.INFRASTRUCTURE_HOST_SUBNETS, + CommonKey.VDI_SUBNETS, ], } diff --git a/source/idea/infrastructure/install/parameters/directoryservice.py b/source/idea/infrastructure/install/parameters/directoryservice.py index d0b9195..b5c79d0 100644 --- a/source/idea/infrastructure/install/parameters/directoryservice.py +++ b/source/idea/infrastructure/install/parameters/directoryservice.py @@ -1,6 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
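`_update_infrastructure_host_ami_config` (earlier in this diff) only emits `config set` commands when every custom AMI value is non-empty, using a shell word count as the guard. The intent can be sketched in plain Python (the function and `exe` names are hypothetical; the real code generates shell command strings rather than running them):

```python
def build_ami_update_commands(settings: dict, exe: str = "res-admin") -> list:
    # Skip the update entirely if any custom AMI value is unset or empty,
    # mirroring the `wc -w` word-count check in the generated shell commands.
    if any(not value for value in settings.values()):
        return []
    return [
        f"{exe} config set Key={key},Type=str,Value={value} --force"
        for key, value in settings.items()
    ]
```

This keeps the five per-host AMI settings all-or-nothing: a partially specified override never reaches the settings table.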
# SPDX-License-Identifier: Apache-2.0 +# Code changes made to this file must be replicated in 'source/idea/batteries_included/parameters' too + from dataclasses import dataclass from typing import Any, Optional @@ -20,6 +22,7 @@ class DirectoryServiceKey(Key): ROOT_USERNAME = "ServiceAccountUsername" ROOT_PASSWORD = "ServiceAccountPassword" DOMAIN_TLS_CERTIFICATE_SECRET_ARN = "DomainTLSCertificateSecretArn" + ENABLE_LDAP_ID_MAPPING = "EnableLdapIDMapping" @dataclass @@ -119,7 +122,15 @@ class DirectoryServiceParameters(Base): Attributes( id=DirectoryServiceKey.DOMAIN_TLS_CERTIFICATE_SECRET_ARN, type="String", - description="Domain TLS Certificate Secret ARN", + description="(Optional) Domain TLS Certificate Secret ARN", ) ) + enable_ldap_id_mapping: str = Base.parameter( + Attributes( + id=DirectoryServiceKey.ENABLE_LDAP_ID_MAPPING, + type="String", + description="Set to False to use the uidNumbers and gidNumbers for users and groups from the provided AD. Otherwise set to True.", + allowed_values=["True", "False"], ) ) @@ -144,5 +155,6 @@ class DirectoryServiceParameterGroups: DirectoryServiceKey.SUDOERS_GROUP_NAME, DirectoryServiceKey.COMPUTERS_OU, DirectoryServiceKey.DOMAIN_TLS_CERTIFICATE_SECRET_ARN, + DirectoryServiceKey.ENABLE_LDAP_ID_MAPPING, ], }
a/source/idea/infrastructure/install/permissions.py b/source/idea/infrastructure/install/permissions.py index 1c61b6d..47233f3 100644 --- a/source/idea/infrastructure/install/permissions.py +++ b/source/idea/infrastructure/install/permissions.py @@ -15,33 +15,46 @@ def __init__( environment_name: str, ): super().__init__(scope, id) - - self.install_role = iam.Role( + # TODO: Split role into separate Install/Delete/Update roles to allow for finer grained permissions + self.pipeline_role = iam.Role( self, - "InstallRole", + "PipelineRole", assumed_by=self.get_principal(), - role_name=f"Admin-{environment_name}-InstallRole", + role_name=f"Admin-{environment_name}-PipelineRole", ) - self.install_role.add_to_policy(statement=self.get_install_policy_statement()) - dependency_group.add(self.install_role) - self.update_role = iam.Role( - self, - "UpdateRole", - assumed_by=self.get_principal(), - role_name=f"Admin-{environment_name}-UpdateRole", + self.pipeline_role.add_to_policy(statement=self.get_cloudformation_access()) + self.pipeline_role.add_to_policy(statement=self.get_directoryservice_access()) + self.pipeline_role.add_to_policy(statement=self.get_dynamodb_access()) + self.pipeline_role.add_to_policy(statement=self.get_ecr_access()) + self.pipeline_role.add_to_policy( + statement=self.get_ecr_authorizationtoken_access() ) - self.update_role.add_to_policy(statement=self.get_install_policy_statement()) - dependency_group.add(self.update_role) + self.pipeline_role.add_to_policy(statement=self.get_ec2_access()) + self.pipeline_role.add_to_policy(statement=self.get_ec2_describe_access()) + self.pipeline_role.add_to_policy(statement=self.get_efs_access()) + self.pipeline_role.add_to_policy(statement=self.get_elb_access()) + self.pipeline_role.add_to_policy(statement=self.get_elb_readonly_access()) + self.pipeline_role.add_to_policy(statement=self.get_es_access()) + self.pipeline_role.add_to_policy(statement=self.get_cloudtrail_access()) + 
self.pipeline_role.add_to_policy(statement=self.get_fsx_access()) + self.pipeline_role.add_to_policy(statement=self.get_iam_access()) + self.pipeline_role.add_to_policy(statement=self.get_kms_access()) + self.pipeline_role.add_to_policy(statement=self.get_lambda_access()) + self.pipeline_role.add_to_policy(statement=self.get_cloudwatch_logs_access()) + self.pipeline_role.add_to_policy(statement=self.get_cloudwatch_access()) + self.pipeline_role.add_to_policy(statement=self.get_route53_access()) + self.pipeline_role.add_to_policy(statement=self.get_s3_access()) + self.pipeline_role.add_to_policy(statement=self.get_secretsmanager_access()) + self.pipeline_role.add_to_policy(statement=self.get_ssm_access()) + self.pipeline_role.add_to_policy(statement=self.get_sns_access()) + self.pipeline_role.add_to_policy(statement=self.get_sqs_access()) + self.pipeline_role.add_to_policy(statement=self.get_sts_access()) + self.pipeline_role.add_to_policy(statement=self.get_tag_access()) + self.pipeline_role.add_to_policy(statement=self.get_cognito_idp_access()) + self.pipeline_role.add_to_policy(statement=self.get_cognito_idp_list_access()) - self.delete_role = iam.Role( - self, - "DeleteRole", - assumed_by=self.get_principal(), - role_name=f"Admin-{environment_name}-DeleteRole", - ) - self.delete_role.add_to_policy(statement=self.get_install_policy_statement()) - dependency_group.add(self.delete_role) + dependency_group.add(self.pipeline_role) def get_principal(self) -> iam.ServicePrincipal: return iam.ServicePrincipal( @@ -54,149 +67,313 @@ def get_principal(self) -> iam.ServicePrincipal: }, ) - def get_install_policy_statement(self) -> iam.PolicyStatement: + def get_cloudformation_access(self) -> iam.PolicyStatement: return iam.PolicyStatement( effect=iam.Effect.ALLOW, - resources=["*"], + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:cloudformation:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:stack/res-*" + ], actions=[ - "*", # TODO: narrow the scope of these permissions - 
"backup-storage:MountCapsule", - "backup:CreateBackupPlan", - "backup:CreateBackupSelection", - "backup:CreateBackupVault", - "backup:DescribeBackupVault", - "backup:GetBackupPlan", - "backup:GetBackupSelection", - "backup:TagResource", "cloudformation:CreateChangeSet", "cloudformation:CreateStack", "cloudformation:DeleteChangeSet", + "cloudformation:DeleteStack", "cloudformation:DescribeChangeSet", "cloudformation:DescribeStackEvents", "cloudformation:DescribeStacks", "cloudformation:ExecuteChangeSet", "cloudformation:GetTemplate", "cloudformation:UpdateTerminationProtection", - "cloudwatch:PutMetricAlarm", - "ds:CreateMicrosoftAD", - "ds:DescribeDirectories", - "dynamodb:*", - "ec2:AllocateAddress", - "ec2:AssociateAddress", - "ec2:AssociateRouteTable", - "ec2:AttachInternetGateway", - "ec2:AuthorizeSecurityGroupEgress", - "ec2:AuthorizeSecurityGroupIngress", - "ec2:CreateFlowLogs", - "ec2:CreateInternetGateway", - "ec2:CreateNatGateway", - "ec2:CreateNetworkInterface", - "ec2:CreateNetworkInterfacePermission", - "ec2:CreateRoute", - "ec2:CreateRouteTable", - "ec2:CreateSecurityGroup", - "ec2:CreateSubnet", - "ec2:CreateTags", - "ec2:CreateVpc", - "ec2:CreateVpcEndpoint", - "ec2:DescribeAddresses", - "ec2:DescribeAvailabilityZones", - "ec2:DescribeFlowLogs", - "ec2:DescribeImages", + ], + ) + + def get_directoryservice_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:ds:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["ds:CreateMicrosoftAD", "ds:DescribeDirectories"], + ) + + def get_dynamodb_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:dynamodb:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["dynamodb:*"], + ) + + def get_ecr_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + 
f"arn:{aws_cdk.Aws.PARTITION}:ecr:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["ecr:*"], + ) + + def get_ecr_authorizationtoken_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=["*"], + actions=["ecr:GetAuthorizationToken"], + ) + + def get_ec2_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:ec2:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["ec2:*"], + ) + + def get_ec2_describe_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=["*"], + actions=[ + "ec2:DescribeTags", "ec2:DescribeInstances", - "ec2:DescribeInternetGateways", - "ec2:DescribeKeyPairs", - "ec2:DescribeNatGateways", - "ec2:DescribeNetwork*", - "ec2:DescribeNetworkInterfaces", "ec2:DescribeRegions", - "ec2:DescribeRouteTables", - "ec2:DescribeSecurityGroups", - "ec2:DescribeSubnets", - "ec2:DescribeTags", - "ec2:DescribeVpcAttribute", - "ec2:DescribeVpcEndpointServices", - "ec2:DescribeVpcEndpoints", - "ec2:DescribeVpcs", - "ec2:ModifySubnetAttribute", - "ec2:ModifyVpcAttribute", - "ec2:RevokeSecurityGroupEgress", - "ec2:RunInstances", + ], + ) + + def get_efs_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:elasticfilesystem:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "elasticfilesystem:CreateFileSystem", "elasticfilesystem:CreateMountTarget", "elasticfilesystem:DescribeFileSystems", "elasticfilesystem:DescribeMountTargets", "elasticfilesystem:PutFileSystemPolicy", "elasticfilesystem:PutLifecycleConfiguration", - "elasticloadbalancing:AddTags", - "elasticloadbalancing:CreateListener", - "elasticloadbalancing:CreateLoadBalancer", - "elasticloadbalancing:CreateRule", - "elasticloadbalancing:CreateTargetGroup", - "elasticloadbalancing:DeleteTargetGroup", 
- "elasticloadbalancing:DescribeListeners", - "elasticloadbalancing:DescribeLoadBalancers", - "elasticloadbalancing:DescribeRules", - "elasticloadbalancing:DescribeTargetGroups", - "elasticloadbalancing:DescribeTargetHealth", - "elasticloadbalancing:ModifyLoadBalancerAttributes", - "elasticloadbalancing:ModifyRule", - "elasticloadbalancing:RegisterTargets", + ], + ) + + def get_elb_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:elasticloadbalancing:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["elasticloadbalancing:*"], + ) + + def get_elb_readonly_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=["*"], + actions=["elasticloadbalancing:Describe*"], + ) + + def get_es_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:es:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "es:AddTags", "es:CreateElasticsearchDomain", "es:DescribeElasticsearchDomain", "es:ListDomainNames", - "events:*", + ], + ) + + def get_cloudtrail_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:events:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["events:*"], + ) + + def get_fsx_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:fsx:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "fsx:CreateFileSystem", "fsx:DescribeFileSystems", "fsx:TagResource", + ], + ) + + def get_iam_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[f"arn:{aws_cdk.Aws.PARTITION}:iam::{aws_cdk.Aws.ACCOUNT_ID}:*"], + actions=[ "iam:AddRoleToInstanceProfile", "iam:AttachRolePolicy", 
"iam:CreateInstanceProfile", "iam:CreateRole", + "iam:DeleteRolePolicy", + "iam:DetachRolePolicy", "iam:GetRole", "iam:GetRolePolicy", "iam:ListRoles", "iam:PassRole", "iam:PutRolePolicy", "iam:TagRole", + "iam:DeleteRole", + ], + ) + + def get_kms_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:kms:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "kms:CreateGrant", "kms:Decrypt", "kms:DescribeKey", "kms:GenerateDataKey", "kms:RetireGrant", + ], + ) + + def get_lambda_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:lambda:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "lambda:AddPermission", "lambda:CreateFunction", "lambda:GetFunction", "lambda:InvokeFunction", + ], + ) + + def get_cloudwatch_logs_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:logs:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "logs:CreateLogGroup", "logs:CreateLogStream", "logs:DescribeLogGroups", "logs:PutRetentionPolicy", - "route53:CreateHostedZone", - "route53:CreateVPCAssociationAuthorization", - "route53:GetHostedZone", - "route53resolver:AssociateResolverEndpointIpAddress", - "route53resolver:AssociateResolverRule", - "route53resolver:CreateResolverEndpoint", - "route53resolver:CreateResolverRule", - "route53resolver:GetResolverEndpoint", - "route53resolver:GetResolverRule", - "route53resolver:GetResolverRuleAssociation", - "route53resolver:PutResolverRulePolicy", - "route53resolver:TagResource", - "s3:*Object", - "s3:GetBucketLocation", - "s3:ListBucket", - "secretsmanager:CreateSecret", - "secretsmanager:TagResource", + ], + ) + + def get_cloudwatch_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + 
f"arn:{aws_cdk.Aws.PARTITION}:cloudwatch:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["cloudwatch:*"], + ) + + def get_route53_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[f"arn:{aws_cdk.Aws.PARTITION}:route53:::*"], + actions=["route53:*"], + ) + + def get_s3_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[f"arn:{aws_cdk.Aws.PARTITION}:s3:::*"], + actions=[ + "s3:*", + ], + ) + + def get_secretsmanager_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:secretsmanager:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["secretsmanager:CreateSecret", "secretsmanager:TagResource"], + ) + + def get_sns_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:sns:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=[ "sns:CreateTopic", "sns:GetTopicAttributes", "sns:ListSubscriptionsByTopic", "sns:SetTopicAttributes", "sns:Subscribe", "sns:TagResource", - "sqs:*", - "sts:DecodeAuthorizationMessage", ], ) + + def get_sqs_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:sqs:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["sqs:*"], + ) + + def get_ssm_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:ssm:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*", + f"arn:{aws_cdk.Aws.PARTITION}:ec2:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*", + f"arn:{aws_cdk.Aws.PARTITION}:ssm:{aws_cdk.Aws.REGION}::document/AWS-RunShellScript", + ], + actions=[ + "ssm:*", + ], + ) + + def get_sts_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + 
effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:sts::{aws_cdk.Aws.ACCOUNT_ID}:*", + f"arn:{aws_cdk.Aws.PARTITION}:iam::{aws_cdk.Aws.ACCOUNT_ID}:role/*", + ], + actions=["sts:*"], + ) + + def get_tag_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=["*"], + actions=["tag:GetResources"], + ) + + def get_cognito_idp_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=[ + f"arn:{aws_cdk.Aws.PARTITION}:cognito-idp:{aws_cdk.Aws.REGION}:{aws_cdk.Aws.ACCOUNT_ID}:*" + ], + actions=["cognito-idp:*"], + ) + + def get_cognito_idp_list_access(self) -> iam.PolicyStatement: + return iam.PolicyStatement( + effect=iam.Effect.ALLOW, + resources=["*"], + actions=["cognito-idp:ListUserPools"], + ) diff --git a/source/idea/infrastructure/install/stack.py b/source/idea/infrastructure/install/stack.py index e9af203..6f5b74a 100644 --- a/source/idea/infrastructure/install/stack.py +++ b/source/idea/infrastructure/install/stack.py @@ -10,12 +10,13 @@ from constructs import Construct, DependencyGroup import idea +from idea.batteries_included.parameters.parameters import BIParameters from idea.infrastructure.install import installer from idea.infrastructure.install.parameters.common import CommonKey from idea.infrastructure.install.parameters.directoryservice import DirectoryServiceKey from idea.infrastructure.install.parameters.parameters import ( - AllParameterGroups, - Parameters, + AllRESParameterGroups, + RESParameters, ) from ideadatamodel import constants # type: ignore @@ -38,7 +39,7 @@ def __init__( self, scope: Construct, stack_id: str, - parameters: Parameters = Parameters(), + parameters: Union[RESParameters, BIParameters] = RESParameters(), registry_name: Optional[str] = None, dynamodb_kms_key_alias: Optional[str] = None, env: Union[Environment, dict[str, Any], None] = None, @@ -54,7 +55,7 @@ def __init__( self.parameters = parameters 
self.parameters.generate(self) - self.template_options.metadata = AllParameterGroups.template_metadata() + self.template_options.metadata = AllRESParameterGroups.template_metadata() self.registry_name = ( registry_name if registry_name is not None else PUBLIC_REGISTRY_NAME ) diff --git a/source/idea/infrastructure/install/tasks.py b/source/idea/infrastructure/install/tasks.py index dcc6522..5f8fab1 100644 --- a/source/idea/infrastructure/install/tasks.py +++ b/source/idea/infrastructure/install/tasks.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, TypedDict +from typing import Optional, TypedDict, Union import aws_cdk from aws_cdk import aws_ecs as ecs @@ -9,10 +9,11 @@ from aws_cdk import aws_stepfunctions_tasks as sfn_tasks from constructs import Construct, DependencyGroup +from idea.batteries_included.parameters.parameters import BIParameters from idea.infrastructure.install import handlers from idea.infrastructure.install.commands import create from idea.infrastructure.install.parameters.common import CommonKey -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters from idea.infrastructure.install.permissions import Permissions @@ -28,7 +29,7 @@ def __init__( scope: Construct, id: str, registry_name: str, - params: Parameters, + params: Union[RESParameters, BIParameters], dependency_group: DependencyGroup, ): super().__init__(scope, id) @@ -77,7 +78,7 @@ def get_create_task(self) -> sfn_tasks.EcsRunTask: return self.get_task( name="Create", command=create.Create(params=self.params).get_commands(), - task_role=self.permissions.install_role, + task_role=self.permissions.pipeline_role, ) def get_update_task(self) -> sfn_tasks.EcsRunTask: @@ -87,7 +88,7 @@ def get_update_task(self) -> sfn_tasks.EcsRunTask: "res-admin --version", f"res-admin deploy all --upgrade 
--cluster-name {self.params.get_str(CommonKey.CLUSTER_NAME)} --aws-region {aws_cdk.Aws.REGION}", ], - task_role=self.permissions.update_role, + task_role=self.permissions.pipeline_role, ) def get_delete_task(self) -> sfn_tasks.EcsRunTask: @@ -100,7 +101,7 @@ def get_delete_task(self) -> sfn_tasks.EcsRunTask: f"--cluster-name {self.params.get_str(CommonKey.CLUSTER_NAME)} --aws-region {aws_cdk.Aws.REGION}" ), ], - task_role=self.permissions.delete_role, + task_role=self.permissions.pipeline_role, ) def get_task( diff --git a/source/idea/infrastructure/scripts/regional_pipeline_deployment.sh b/source/idea/infrastructure/scripts/regional_pipeline_deployment.sh new file mode 100755 index 0000000..e8579e2 --- /dev/null +++ b/source/idea/infrastructure/scripts/regional_pipeline_deployment.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -e + +# Check if jq is installed +if ! command -v jq &> /dev/null then + echo "jq could not be found. Please install jq to run this script." + exit 1 +fi + +# Check that both file paths are provided +if [ "$#" -ne 2 ]; then + echo "Usage: $0 <regional_stacks_json_path> <cdk_json_path>" + exit 1 +fi + +FILE_PATH=$1 +CDK_JSON_PATH=$2 + +# Check if the file exists +if [ ! -f "$FILE_PATH" ]; then + echo "File not found: $FILE_PATH" + exit 1 +fi + +if [ ! -f "$CDK_JSON_PATH" ]; then + echo "File not found: $CDK_JSON_PATH" + exit 1 +fi + +# Iterate over the list of regional stacks in the JSON file +jq -c '.parent_stack.regional_stacks[]' $FILE_PATH | while read regional_stack; do + # You can process each account here. For now, just printing it.
+ account=$(jq -r '.account_id' <<< "$regional_stack") + region=$(jq -r '.region' <<< "$regional_stack") + echo "Account: $account" + echo "Region: $region" + + tmp=$(mktemp) + echo "tmp file: $tmp" + # update the context file with PortalDomain name for this account + jq ".context.PortalDomainName = \"\\\"$region.integtest.res.hpc.aws.dev\\\"\"" $CDK_JSON_PATH > "$tmp" && mv "$tmp" $CDK_JSON_PATH + jq ".context.CustomDomainNameforWebApp = \"\\\"web.$region.integtest.res.hpc.aws.dev\\\"\"" $CDK_JSON_PATH > "$tmp" && mv "$tmp" $CDK_JSON_PATH + jq ".context.CustomDomainNameforVDI = \"\\\"vdi.$region.integtest.res.hpc.aws.dev\\\"\"" $CDK_JSON_PATH > "$tmp" && mv "$tmp" $CDK_JSON_PATH + echo "Getting credentials for account $account" + ada credentials update --once --provider isengard --role Admin --account $account + export AWS_DEFAULT_REGION=$region + npx cdk bootstrap + npx cdk synth + npx cdk deploy --require-approval never RESBuildPipelineStack --context repository_name=DigitalEngineeringPlatform --context branch_name=develop -c deploy=true -c batteries_included=true +done diff --git a/source/idea/pipeline/integ_tests/__init__.py b/source/idea/pipeline/integ_tests/__init__.py new file mode 100644 index 0000000..33cbe86 --- /dev/null +++ b/source/idea/pipeline/integ_tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/source/idea/pipeline/integ_tests/integ_test_step_builder.py b/source/idea/pipeline/integ_tests/integ_test_step_builder.py new file mode 100644 index 0000000..b71dd61 --- /dev/null +++ b/source/idea/pipeline/integ_tests/integ_test_step_builder.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from aws_cdk import Duration +from aws_cdk import aws_codebuild as codebuild +from aws_cdk import aws_iam as iam +from aws_cdk import pipelines + +from idea.pipeline.utils import get_commands_for_scripts + + +class IntegTestStepBuilder: + def __init__( + self, tox_env: str, environment_name: str, region: str, is_legacy: bool = False + ): + self._tox_env = tox_env + self._tox_command_arguments = [f"aws-region={region}"] + if is_legacy: + self._tox_command_arguments.append(f"cluster-name={environment_name}") + else: + self._tox_command_arguments.append(f"environment-name={environment_name}") + self._env = dict( + CLUSTER_NAME=environment_name, + AWS_REGION=region, + ) + self._install_commands = get_commands_for_scripts( + [ + "source/idea/pipeline/scripts/common/install_commands.sh", + "source/idea/pipeline/scripts/integ_tests/install_commands.sh", + "source/idea/pipeline/scripts/tox/install_commands.sh", + ] + ) + self._role_policy_statements = [ + # Default permissions to update security groups so that the integ test can get HTTP access to the RES environment + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "ec2:DescribeSecurityGroups", + "ec2:DescribeSecurityGroupRules", + "ec2:AuthorizeSecurityGroupIngress", + "ec2:RevokeSecurityGroupIngress", + ], + "Resource": "*", + } + ), + ] + + def test_specific_tox_command_argument( + self, *arguments: str + ) -> IntegTestStepBuilder: + for argument in arguments: + self._tox_command_arguments.append(argument) + + return self + + def test_specific_install_command(self, *commands: str) -> IntegTestStepBuilder: + for command in commands: + self._install_commands.append(command) + + return self + + def test_specific_env(self, **env_variables: str) -> IntegTestStepBuilder: + self._env.update(env_variables) + + return self + + def test_specific_role_policy_statement( + self, *statements: iam.PolicyStatement + ) -> 
IntegTestStepBuilder: + for statement in statements: + self._role_policy_statements.append(statement) + + return self + + def build(self) -> pipelines.CodeBuildStep: + # Setting up commands necessary to run integ tests + commands = get_commands_for_scripts( + ["source/idea/pipeline/scripts/integ_tests/setup_commands.sh"] + ) + + tox_command = f"tox -e {self._tox_env} --" + for tox_command_argument in self._tox_command_arguments: + tox_command = tox_command + f" -p {tox_command_argument}" + commands.append(tox_command) + + commands += get_commands_for_scripts( + ["source/idea/pipeline/scripts/integ_tests/teardown_commands.sh"] + ) + + return pipelines.CodeBuildStep( + self._tox_env, + build_environment=codebuild.BuildEnvironment( + build_image=codebuild.LinuxBuildImage.STANDARD_5_0, + compute_type=codebuild.ComputeType.SMALL, + privileged=True, + ), + env=self._env, + install_commands=self._install_commands, + commands=commands, + role_policy_statements=self._role_policy_statements, + timeout=Duration.hours(4), + ) diff --git a/source/idea/pipeline/scripts/chrome/install_commands.sh b/source/idea/pipeline/scripts/chrome/install_commands.sh new file mode 100644 index 0000000..9afd288 --- /dev/null +++ b/source/idea/pipeline/scripts/chrome/install_commands.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +set -ex + +sudo apt-get -y install unzip + +wget -qP /tmp/ "https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/120.0.6099.109/linux64/chromedriver-linux64.zip" +sudo unzip -oj /tmp/chromedriver-linux64.zip -d /usr/bin +sudo chmod 755 /usr/bin/chromedriver + +echo 'debconf debconf/frontend select Noninteractive' | sudo debconf-set-selections +wget -qO- https://dl-ssl.google.com/linux/linux_signing_key.pub | sudo apt-key add - +sudo sh -c 'echo "deb https://dl.google.com/linux/chrome/deb/ stable main" >> /etc/apt/sources.list.d/google.list' +sudo apt-get update +sudo apt-get -y install google-chrome-stable diff --git a/source/idea/pipeline/scripts/destroy/commands.sh b/source/idea/pipeline/scripts/destroy/commands.sh index 3c5cac5..7ca2598 100644 --- a/source/idea/pipeline/scripts/destroy/commands.sh +++ b/source/idea/pipeline/scripts/destroy/commands.sh @@ -8,7 +8,6 @@ resStacks=("Deploy-$INSTALL_STACK_NAME" "$CLUSTER_NAME-metrics" "$CLUSTER_NAME-directoryservice" "$CLUSTER_NAME-identity-provider" - "$CLUSTER_NAME-analytics" "$CLUSTER_NAME-shared-storage" "$CLUSTER_NAME-cluster-manager" "$CLUSTER_NAME-vdc" @@ -33,7 +32,7 @@ aws cloudformation delete-stack --stack-name Deploy-$INSTALL_STACK_NAME --region #Review Deletion of RES CFN Stacks Loop waitMinutes=0; failedStackId=""; -removedStackCount=0 +removedStackCount=0; while [[ ${#resStackIds[@]} -ne $removedStackCount && $failedStackId == "" ]] do echo "$waitMinutes minutes have past..."; @@ -56,11 +55,100 @@ do let waitMinutes++; done -#Verification of RES deletion +#Verification of RES CloudFormation stacks deletion if [[ $failedStackId != "" ]] ; then echo "RES deployment deletion FAILED"; echo "$failedStackId: DELETE_FAILED"; exit 1; else echo "All RES CFN stacks have been deleted"; + echo; +fi + +echo "Cleaning up all RES EFS file systems in VPC using $CLUSTER_NAME-shared-storage-security-group..."; + +#Pulling VPC ID from provided SSM parameter to VPC ID of BI 
+if [[ $BATTERIES_INCLUDED == "true" ]]; then + VPC_ID_INFO=$(aws ssm get-parameter --name $VPC_ID) + VPC_ID=$(echo $VPC_ID_INFO | jq -r '.Parameter.Value') fi + +#Collect all pertinent EFS file systems +EFS_FILE_SYSTEMS=$(aws efs describe-file-systems --region $AWS_REGION --no-paginate --query "FileSystems[?Tags[?Key == 'res:EnvironmentName' && Value == '$CLUSTER_NAME']][].FileSystemId"); + +if [[ $EFS_FILE_SYSTEMS != "[]" ]] ; then + #Loop to delete all EFS file systems + echo $EFS_FILE_SYSTEMS | jq -r '.[]' | while read FileSystemId; do + FILE_SYSTEM_MOUNT_TARGETS=$(aws efs describe-mount-targets --region $AWS_REGION --no-paginate --file-system-id $FileSystemId); + EFS_VPC_ID=$(echo $FILE_SYSTEM_MOUNT_TARGETS | jq -r '.MountTargets[0].VpcId'); + if [[ $EFS_VPC_ID == $VPC_ID ]] ; then + #Deleting all MountTargets of EFS file system + echo $FILE_SYSTEM_MOUNT_TARGETS | jq -r '.MountTargets[].MountTargetId' | while read MountTargetId; do + echo "Deleting MountTarget $MountTargetId of $FileSystemId..."; + aws efs delete-mount-target --region $AWS_REGION --mount-target-id $MountTargetId; + done + sleep 90; + #Deleting EFS file system + echo "Deleting EFS file system $FileSystemId..."; + echo; + aws efs delete-file-system --region $AWS_REGION --file-system-id $FileSystemId; + else + echo "EFS file system not in RES VPC, skipping..."; + fi + done + echo "Waiting 5 minutes for EFS file systems in VPC to finish deleting..."; + sleep 300; +else + echo "No RES EFS file systems for $CLUSTER_NAME detected in VPC to delete!"; +fi +echo; + +echo "Cleaning up all FSx OnTAP file systems in VPC using $CLUSTER_NAME-shared-storage-security-group..."; + +#Collect all pertinent FSx ONTAP file systems +FSX_ONTAP_FILE_SYSTEMS=$(aws fsx describe-file-systems --region $AWS_REGION --no-paginate --query "FileSystems[?Tags[?Key == 'res:EnvironmentName' && Value == '$CLUSTER_NAME'] && VpcId == '$VPC_ID'][].FileSystemId"); + +if [[ $FSX_ONTAP_FILE_SYSTEMS != "[]" ]] ; then + #Loop to delete 
all FSx ONTAP file systems + echo $FSX_ONTAP_FILE_SYSTEMS | jq -r '.[]' | while read FileSystemId; do + #Deleting all non-root volumes of FSx ONTAP file system + echo "Deleting all non-root volumes of $FileSystemId..."; + FSX_VOLUMES=$(aws fsx describe-volumes --region $AWS_REGION --no-paginate --filters Name=file-system-id,Values=$FileSystemId --query "Volumes[?!(OntapConfiguration.StorageVirtualMachineRoot)][].VolumeId"); + echo $FSX_VOLUMES | jq -r '.[]' | while read VolumeId; do + echo "Deleting Volume $VolumeId of $FileSystemId..."; + aws fsx delete-volume --region $AWS_REGION --volume-id $VolumeId --ontap-configuration SkipFinalBackup=true; + done + sleep 90; + #Deleting all SVMs of FSx ONTAP file system + echo; + echo "Deleting all storage virtual machines of $FileSystemId..."; + FSX_SVMS=$(aws fsx describe-storage-virtual-machines --region $AWS_REGION --no-paginate --filters Name=file-system-id,Values=$FileSystemId --query "StorageVirtualMachines[].StorageVirtualMachineId"); + echo $FSX_SVMS | jq -r '.[]' | while read StorageVirtualMachineId; do + echo "Deleting SVM $StorageVirtualMachineId of $FileSystemId..."; + aws fsx delete-storage-virtual-machine --region $AWS_REGION --storage-virtual-machine-id $StorageVirtualMachineId; + done + sleep 120; + #Deleting FSx ONTAP file system + echo; + echo "Deleting FSx ONTAP file system $FileSystemId..."; + aws fsx delete-file-system --region $AWS_REGION --file-system-id $FileSystemId; + echo; + done + echo "Waiting 15 minutes for FSx ONTAP file systems in VPC to finish deleting..."; + sleep 900 +else + echo "No RES FSx ONTAP file systems for $CLUSTER_NAME detected in VPC to delete!"; +fi +echo; + +echo "All RES shared-storage file systems in VPC have been deleted!"; +echo; +echo "Deleting $CLUSTER_NAME-shared-storage-security-group..."; + +SG_SHARED_STORAGE_INFO=$(aws ec2 describe-security-groups --region $AWS_REGION --filters Name=group-name,Values=$CLUSTER_NAME-shared-storage-security-group 
Name=vpc-id,Values=$VPC_ID); + +SG_SHARED_STORAGE_ID=$(echo $SG_SHARED_STORAGE_INFO | jq -r '.SecurityGroups[0].GroupId'); + +aws ec2 delete-security-group --group-id $SG_SHARED_STORAGE_ID; + +echo "$CLUSTER_NAME-shared-storage-security-group has been deleted!"; \ No newline at end of file diff --git a/source/idea/pipeline/scripts/destroy/install_commands.sh b/source/idea/pipeline/scripts/destroy/install_commands.sh new file mode 100644 index 0000000..d6b5b9a --- /dev/null +++ b/source/idea/pipeline/scripts/destroy/install_commands.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -ex + +pip install --upgrade pip +pip uninstall -y pyOpenSSL +pip install -r requirements/dev.txt +pip install source/idea/idea-data-model/src +pip install ".[dev]" diff --git a/source/idea/pipeline/scripts/helpers/create_s3_release_buckets.sh b/source/idea/pipeline/scripts/helpers/create_s3_release_buckets.sh index 3787b5b..3682ef9 100755 --- a/source/idea/pipeline/scripts/helpers/create_s3_release_buckets.sh +++ b/source/idea/pipeline/scripts/helpers/create_s3_release_buckets.sh @@ -30,13 +30,13 @@ function run(){ done echo -e "${CYAN}Attaching role with permission to buckets" aws iam put-role-policy --role-name $codeBuildRoleName --policy-name $BUCKET_PREFIX-policy \ - --policy-document "{ \"Statement\": + --policy-document "{ \"Statement\": [{ - \"Effect\": \"Allow\", - \"Action\": [\"s3:PutObject\", \"s3:getBucketLocation\", \"s3:ListBucket\", \"s3:GetObject\", \"s3:DeleteObject\"], - \"Resource\": [\"arn:aws:s3:::$BUCKET_PREFIX-*\"] - } + \"Effect\": \"Allow\", + \"Action\": [\"s3:PutObject\", \"s3:getBucketLocation\", \"s3:ListBucket\", \"s3:GetObject\", \"s3:DeleteObject\"], + \"Resource\": [\"arn:aws:s3:::$BUCKET_PREFIX-*\"] + } ]}" } -run \ No newline at end of file +run diff --git a/source/idea/pipeline/scripts/integ_tests/setup_commands.sh 
b/source/idea/pipeline/scripts/integ_tests/setup_commands.sh index 7c2c2bd..7c14845 100755 --- a/source/idea/pipeline/scripts/integ_tests/setup_commands.sh +++ b/source/idea/pipeline/scripts/integ_tests/setup_commands.sh @@ -4,9 +4,13 @@ PUBLIC_IP=$(curl https://checkip.amazonaws.com/) -USERPOOLID=`aws cognito-idp list-user-pools --region $AWS_REGION --max-results 60 --query 'UserPools[?Name==\`'$CLUSTER_NAME-user-pool'\`].Id' --output text` +if [[ -z $CLUSTERADMIN_USERNAME || -z $CLUSTERADMIN_PASSWORD ]]; then + echo 'skip clusteradmin credentials setup' +else + USERPOOLID=`aws cognito-idp list-user-pools --region $AWS_REGION --max-results 60 --query 'UserPools[?Name==\`'$CLUSTER_NAME-user-pool'\`].Id' --output text` -aws cognito-idp admin-set-user-password --user-pool-id $USERPOOLID --region $AWS_REGION --username $CLUSTERADMIN_USERNAME --password $CLUSTERADMIN_PASSWORD --permanent + aws cognito-idp admin-set-user-password --user-pool-id $USERPOOLID --region $AWS_REGION --username $CLUSTERADMIN_USERNAME --password $CLUSTERADMIN_PASSWORD --permanent +fi SG_EXTERNAL_ALB_INFO=$(aws ec2 describe-security-groups --region $AWS_REGION --filters Name=group-name,Values=$CLUSTER_NAME-external-load-balancer-security-group) SG_BASTION_HOST_INFO=$(aws ec2 describe-security-groups --region $AWS_REGION --filters Name=group-name,Values=$CLUSTER_NAME-bastion-host-security-group) diff --git a/source/idea/pipeline/scripts/publish/commands.sh b/source/idea/pipeline/scripts/publish/commands.sh index c0e7e2b..95e80a0 100644 --- a/source/idea/pipeline/scripts/publish/commands.sh +++ b/source/idea/pipeline/scripts/publish/commands.sh @@ -8,7 +8,7 @@ COMMIT_ID=$(echo $CODEBUILD_RESOLVED_SOURCE_VERSION | cut -b -8) RELEASE_VERSION=$(echo "$( pipelines.CodePipelineSource: ) def get_synth_step( - self, ecr_repository_name: str, ecr_public_repository_name: str + self, + ecr_repository_name: str, + ecr_public_repository_name: str, + bi_stack_template_url: str, ) -> pipelines.CodeBuildStep: 
return pipelines.CodeBuildStep( "Synth", @@ -193,19 +232,22 @@ def get_synth_step( REPOSITORY_NAME=self._repository_name, BRANCH=self._branch_name, DEPLOY="true" if self._deploy else "false", + INTEGRATION_TESTS="true" if self._integ_tests else "false", DESTROY="true" if self._destroy else "false", + BATTERIES_INCLUDED="true" if self._bi else "false", + BIStackTemplateURL=bi_stack_template_url, ECR_REPOSITORY=ecr_repository_name, ECR_PUBLIC_REPOSITORY_NAME=ecr_public_repository_name, PUBLISH_TEMPLATES="true" if self._publish_templates else "false", **self.params.to_context(), ), - install_commands=self.get_commands_for_scripts( + install_commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/common/install_commands.sh", "source/idea/pipeline/scripts/synth/install_commands.sh", ] ), - commands=self.get_commands_for_scripts( + commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/synth/commands.sh", ] @@ -213,7 +255,7 @@ def get_synth_step( partial_build_spec=self.get_reports_partial_build_spec("pytest-report.xml"), ) - def get_integ_test_step( + def get_component_integ_test_steps( self, integ_test_envs: list[str] ) -> list[pipelines.CodeBuildStep]: steps: list[pipelines.CodeBuildStep] = [] @@ -221,41 +263,17 @@ def get_integ_test_step( clusteradmin_password = "RESPassword1." 
# fixed password for running tests for _env in integ_test_envs: - # Setting up commands necessary to set clusteradmin password and then run integ tests - commands = self.get_commands_for_scripts( - ["source/idea/pipeline/scripts/integ_tests/setup_commands.sh"] - ) - # The following command uses tox to invoke the integ-test; _env is in format like integ-tests.cluseter-manager - commands += [ - f"tox -e {_env} -- -p cluster-name={self.params.cluster_name} -p aws-region={self.region} " - f"-p admin-username={clusteradmin_username} " - f"-p admin-password={clusteradmin_password}", - ] - commands += self.get_commands_for_scripts( - ["source/idea/pipeline/scripts/integ_tests/teardown_commands.sh"] - ) - _step = pipelines.CodeBuildStep( - _env, - build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.STANDARD_5_0, - compute_type=codebuild.ComputeType.SMALL, - privileged=True, - ), - env=dict( - CLUSTER_NAME=self.params.cluster_name, - AWS_REGION=self.region, + _step = ( + IntegTestStepBuilder(_env, self.params.cluster_name, self.region, True) + .test_specific_tox_command_argument( + f"admin-username={clusteradmin_username}", + f"admin-password={clusteradmin_password}", + ) + .test_specific_env( CLUSTERADMIN_USERNAME=clusteradmin_username, CLUSTERADMIN_PASSWORD=clusteradmin_password, - ), - install_commands=self.get_commands_for_scripts( - [ - "source/idea/pipeline/scripts/common/install_commands.sh", - "source/idea/pipeline/scripts/integ_tests/install_commands.sh", - "source/idea/pipeline/scripts/tox/install_commands.sh", - ] - ), - commands=commands, - role_policy_statements=[ + ) + .test_specific_role_policy_statement( iam.PolicyStatement.from_json( { "Effect": "Allow", @@ -280,31 +298,102 @@ def get_integ_test_step( ], } ), - iam.PolicyStatement.from_json( - { - "Effect": "Allow", - "Action": [ - "ec2:DescribeSecurityGroups", - "ec2:DescribeSecurityGroupRules", - "ec2:AuthorizeSecurityGroupIngress", - "ec2:RevokeSecurityGroupIngress", - 
], - "Resource": "*", - } - ), - ], - timeout=Duration.hours(4), + ) + .build() ) steps.append(_step) return steps + def get_smoke_test_step(self) -> pipelines.CodeBuildStep: + step = ( + IntegTestStepBuilder( + "integ-tests.smoke", self.params.cluster_name, self.region + ) + .test_specific_install_command( + *get_commands_for_scripts( + ["source/idea/pipeline/scripts/chrome/install_commands.sh"] + ) + ) + .test_specific_role_policy_statement( + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "ssm:SendCommand", + ], + "Resource": [ + f"arn:{self.partition}:ssm:{self.region}:*:document/*", + ], + } + ), + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "ssm:SendCommand", + ], + "Resource": [ + f"arn:{self.partition}:ec2:{self.region}:{self.account}:instance/*" + ], + "Condition": { + "StringLike": { + "ssm:resourceTag/res:EnvironmentName": [ + self.params.cluster_name + ] + } + }, + } + ), + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "ssm:GetCommandInvocation", + ], + "Resource": "*", + } + ), + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + ], + # TODO: Specify the bucket to which SSM writes command outputs + "Resource": "*", + } + ), + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "elasticloadbalancing:DescribeLoadBalancers", + ], + "Resource": "*", + } + ), + iam.PolicyStatement.from_json( + { + "Effect": "Allow", + "Action": [ + "autoscaling:DescribeAutoScalingGroups", + ], + "Resource": "*", + } + ), + ) + .build() + ) + + return step + @staticmethod def get_steps_from_tox(tox_env: list[str]) -> list[pipelines.CodeBuildStep]: steps: list[pipelines.CodeBuildStep] = [] for _env in tox_env: _step = pipelines.CodeBuildStep( _env, - install_commands=PipelineStack.get_commands_for_scripts( + install_commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/common/install_commands.sh", 
"source/idea/pipeline/scripts/tox/install_commands.sh", @@ -331,7 +420,6 @@ def get_destroy_step(self) -> pipelines.CodeBuildStep: f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-metrics/*", f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-directoryservice/*", f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-identity-provider/*", - f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-analytics/*", f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-shared-storage/*", f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-cluster-manager/*", f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/{self.params.cluster_name}-vdc/*", @@ -347,6 +435,89 @@ def get_destroy_step(self) -> pipelines.CodeBuildStep: f"arn:{self.partition}:cloudformation:{self.region}:{self.account}:stack/Deploy-{INSTALL_STACK_NAME}/*" ], ) + codebuild_read_ssm_parameter_vpc_id_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ssm:GetParameter", + ], + resources=[ + f"arn:{self.partition}:ssm:{self.region}:{self.account}:parameter{self.params.vpc_id}" + ], + ) + codebuild_read_file_systems_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "elasticfilesystem:DescribeFileSystems", + "elasticfilesystem:DescribeMountTargets", + "fsx:DescribeFileSystems", + "fsx:DescribeStorageVirtualMachines", + "fsx:DescribeVolumes", + ], + resources=["*"], + ) + codebuild_efs_delete_file_systems_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "elasticfilesystem:DeleteMountTarget", + "elasticfilesystem:DeleteFileSystem", + ], + resources=["*"], + conditions={ + "StringEquals": { + "aws:ResourceTag/res:EnvironmentName": 
[self.params.cluster_name], + }, + }, + ) + codebuild_efs_filesystem_ec2_delete_eni_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ec2:DeleteNetworkInterface", + ], + resources=["*"], + ) + codebuild_fsx_delete_file_systems_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "fsx:DeleteFileSystem", + ], + resources=["*"], + conditions={ + "StringEquals": { + "aws:ResourceTag/res:EnvironmentName": [self.params.cluster_name], + }, + }, + ) + codebuild_fsx_delete_svms_volumes_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "fsx:DeleteVolume", + "fsx:DeleteStorageVirtualMachine", + "fsx:CreateBackup", + "fsx:TagResource", + ], + resources=["*"], + ) + codebuild_shared_storage_security_group_read_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ec2:DescribeSecurityGroups", + ], + resources=["*"], + ) + codebuild_shared_storage_security_group_delete_policy = iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ec2:DeleteSecurityGroup", + ], + resources=["*"], + conditions={ + "StringEquals": { + "aws:ResourceTag/Name": [ + f"{self.params.cluster_name}-shared-storage-security-group" + ], + }, + }, + ) return pipelines.CodeBuildStep( "Destroy", build_environment=codebuild.BuildEnvironment( @@ -357,9 +528,17 @@ def get_destroy_step(self) -> pipelines.CodeBuildStep: env=dict( CLUSTER_NAME=self.params.cluster_name, AWS_REGION=self.region, + BATTERIES_INCLUDED="true" if self._bi else "false", + VPC_ID=self.params.vpc_id, INSTALL_STACK_NAME=INSTALL_STACK_NAME, ), - commands=self.get_commands_for_scripts( + install_commands=get_commands_for_scripts( + [ + "source/idea/pipeline/scripts/common/install_commands.sh", + "source/idea/pipeline/scripts/destroy/install_commands.sh", + ] + ), + commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/destroy/commands.sh", ] @@ -367,6 +546,14 @@ def get_destroy_step(self) -> pipelines.CodeBuildStep: role_policy_statements=[ 
codebuild_cloudformation_read_policy, codebuild_cloudformation_delete_stack_policy, + codebuild_read_ssm_parameter_vpc_id_policy, + codebuild_read_file_systems_policy, + codebuild_efs_delete_file_systems_policy, + codebuild_efs_filesystem_ec2_delete_eni_policy, + codebuild_fsx_delete_file_systems_policy, + codebuild_fsx_delete_svms_volumes_policy, + codebuild_shared_storage_security_group_read_policy, + codebuild_shared_storage_security_group_delete_policy, ], timeout=Duration.hours(2), ) @@ -420,13 +607,13 @@ def get_publish_steps( PUBLISH_TEMPLATES="true" if self._publish_templates else "false", ECR_REPOSITORY_URI_PARAMETER=ecr_public_repository_uri, ), - install_commands=self.get_commands_for_scripts( + install_commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/common/install_commands.sh", "source/idea/pipeline/scripts/publish/install_commands.sh", ] ), - commands=self.get_commands_for_scripts( + commands=get_commands_for_scripts( [ "source/idea/pipeline/scripts/publish/commands.sh", ] @@ -468,9 +655,33 @@ def get_reports_partial_build_spec(self, filename: str) -> codebuild.BuildSpec: class DeployStage(Stage): - def __init__(self, scope: Construct, construct_id: str, parameters: Parameters): + def __init__( + self, + scope: Construct, + construct_id: str, + parameters: Union[RESParameters, BIParameters], + ): super().__init__(scope, construct_id) registry_name = self.node.try_get_context("registry_name") - self.install_stack = InstallStack( - self, INSTALL_STACK_NAME, parameters=parameters, registry_name=registry_name - ) + if isinstance(parameters, BIParameters): + bi_stack_template_url = self.node.try_get_context("BIStackTemplateURL") + self.batteries_included_stack = BiStack( + self, + BATTERIES_INCLUDED_STACK_NAME, + template_url=bi_stack_template_url, + parameters=parameters, + ) + self.install_stack = InstallStack( + self, + INSTALL_STACK_NAME, + parameters=parameters, + registry_name=registry_name, + ) + 
self.install_stack.add_dependency(target=self.batteries_included_stack) + else: + self.install_stack = InstallStack( + self, + INSTALL_STACK_NAME, + parameters=parameters, + registry_name=registry_name, + ) diff --git a/source/idea/pipeline/utils.py b/source/idea/pipeline/utils.py new file mode 100644 index 0000000..f628d78 --- /dev/null +++ b/source/idea/pipeline/utils.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pathlib + + +def get_commands_for_scripts(paths: list[str]) -> list[str]: + commands = [] + root = pathlib.Path("source").parent + scripts = root / "source/idea/pipeline/scripts" + for raw_path in paths: + path = pathlib.Path(raw_path) + if not path.exists(): + raise ValueError(f"script path doesn't exist: {path}") + if not path.is_relative_to(scripts): + raise ValueError(f"script path isn't in {scripts}: {path}") + relative = path.relative_to(root) + commands.append(f"chmod +x {relative}") + commands.append(str(relative)) + return commands diff --git a/source/tests/integration/__init__.py b/source/tests/integration/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
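The new `get_commands_for_scripts` helper in `source/idea/pipeline/utils.py` turns each script path into a chmod-then-execute pair of shell commands. A minimal sketch of the command shapes it emits (path-existence and directory validation omitted; `commands_for` here is an illustrative stand-in, not the repo function):

```python
from pathlib import PurePosixPath


def commands_for(paths: list[str]) -> list[str]:
    # Sketch of get_commands_for_scripts: mark each script executable, then run it.
    # The real helper also rejects paths that don't exist or that live outside
    # source/idea/pipeline/scripts.
    commands: list[str] = []
    for raw_path in paths:
        relative = PurePosixPath(raw_path)
        commands.append(f"chmod +x {relative}")
        commands.append(str(relative))
    return commands


print(commands_for(["source/idea/pipeline/scripts/synth/commands.sh"]))
```

Centralizing this in a module-level function (rather than the old `PipelineStack` static method) lets both the stack and the new `IntegTestStepBuilder` share it.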
diff --git a/source/tests/integration/conftest.py b/source/tests/integration/conftest.py new file mode 100644 index 0000000..031cc2b --- /dev/null +++ b/source/tests/integration/conftest.py @@ -0,0 +1,65 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--environment-name", + action="store", + required=True, + help="Name of the RES environment", + ) + parser.addoption( + "--aws-region", + action="store", + default="us-east-1", + help="Region of the RES environment", + ) + parser.addoption( + "--custom-vdi-domain-name", action="store", help="Custom VDI domain name" + ) + parser.addoption( + "--custom-web-app-domain-name", + action="store", + help="Custom web app domain name", + ) + parser.addoption( + "--api-invoker-type", + action="store", + default="http", + help="Type of the API invoker", + choices=["http"], + ) + parser.addoption( + "--ssm-output-bucket", + action="store", + help="S3 bucket for storing SSM command output", + ) + + +@pytest.fixture +def environment_name(request: pytest.FixtureRequest) -> str: + environment_name: str = request.config.getoption("--environment-name") + return environment_name + + +@pytest.fixture +def region(request: pytest.FixtureRequest) -> str: + region: str = request.config.getoption("--aws-region") + return region + + +@pytest.fixture +def ssm_output_bucket(request: pytest.FixtureRequest) -> str: + ssm_output_bucket: str = 
request.config.getoption("--ssm-output-bucket") + return ssm_output_bucket diff --git a/source/tests/integration/framework/__init__.py b/source/tests/integration/framework/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/integration/framework/api_invoker/__init__.py b/source/tests/integration/framework/api_invoker/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/api_invoker/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
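Each `pytest_addoption` flag in `conftest.py` is surfaced to tests through a same-named fixture that reads `request.config.getoption`. A rough sketch of that wiring, with a hypothetical `FakeConfig` standing in for pytest's config object:

```python
class FakeConfig:
    # Hypothetical stand-in for pytest's request.config.
    def __init__(self, options: dict):
        self._options = options

    def getoption(self, name: str) -> str:
        return self._options[name]


def region(config: FakeConfig) -> str:
    # Mirrors the `region` fixture: return whatever --aws-region was passed
    # (the real option defaults to us-east-1).
    value: str = config.getoption("--aws-region")
    return value


print(region(FakeConfig({"--aws-region": "us-west-2"})))
```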
diff --git a/source/tests/integration/framework/api_invoker/api_invoker_base.py b/source/tests/integration/framework/api_invoker/api_invoker_base.py new file mode 100644 index 0000000..4f6de8c --- /dev/null +++ b/source/tests/integration/framework/api_invoker/api_invoker_base.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from abc import abstractmethod + +from tests.integration.framework.model.api_invocation_context import ( + ApiInvocationContext, +) + + +class ResApiInvokerBase: + """ + Interface for the API invoker. + """ + + @abstractmethod + def invoke(self, context: ApiInvocationContext) -> None: + """ + Invoke the specific service API. + :param context: API invocation context. + """ + ... diff --git a/source/tests/integration/framework/api_invoker/http_api_invoker.py b/source/tests/integration/framework/api_invoker/http_api_invoker.py new file mode 100644 index 0000000..4bf2cd2 --- /dev/null +++ b/source/tests/integration/framework/api_invoker/http_api_invoker.py @@ -0,0 +1,39 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. 
This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import json + +import requests + +from tests.integration.framework.api_invoker.api_invoker_base import ResApiInvokerBase +from tests.integration.framework.model.api_invocation_context import ( + ApiInvocationContext, +) + + +class HttpApiInvoker(ResApiInvokerBase): + """ + API invoker implementation for invoking service APIs via HTTP requests + """ + + def invoke(self, context: ApiInvocationContext) -> None: + http_headers = { + "Content-Type": "application/json;charset=UTF-8", + "X_RES_TEST_USERNAME": context.auth.username, + } + response = requests.post( + f"{context.endpoint}/api/v{context.request.header.version}/{context.request.namespace}", + json=json.loads(context.request.json(exclude_none=True, by_alias=True)), + headers=http_headers, + verify=False, + ).json() + + context.response = response diff --git a/source/tests/integration/framework/client/__init__.py b/source/tests/integration/framework/client/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/client/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
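`HttpApiInvoker` posts every request to `{endpoint}/api/v{version}/{namespace}` with the test username in an `X_RES_TEST_USERNAME` header. A small sketch of the URL assembly (the `api_url` helper is illustrative, not part of the change):

```python
def api_url(endpoint: str, version: int, namespace: str) -> str:
    # Mirrors the URL HttpApiInvoker posts to, e.g. the cluster-manager
    # endpoint plus /api/v1/Projects.CreateProject.
    return f"{endpoint}/api/v{version}/{namespace}"


print(api_url("https://alb.example.com/cluster-manager", 1, "Projects.CreateProject"))
# → https://alb.example.com/cluster-manager/api/v1/Projects.CreateProject
```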
diff --git a/source/tests/integration/framework/client/res_client.py b/source/tests/integration/framework/client/res_client.py new file mode 100644 index 0000000..9233d50 --- /dev/null +++ b/source/tests/integration/framework/client/res_client.py @@ -0,0 +1,291 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import logging +from typing import Type + +import pytest +from selenium import webdriver +from selenium.webdriver.chrome.webdriver import WebDriver + +from ideadatamodel import ( # type: ignore + CreateProjectRequest, + CreateProjectResult, + CreateSessionRequest, + CreateSessionResponse, + CreateSoftwareStackRequest, + CreateSoftwareStackResponse, + DeleteProjectRequest, + DeleteProjectResult, + DeleteSessionRequest, + DeleteSessionResponse, + DeleteSoftwareStackRequest, + DeleteSoftwareStackResponse, + GetSessionConnectionInfoRequest, + GetSessionConnectionInfoResponse, + GetSessionInfoRequest, + GetSessionInfoResponse, + GetUserRequest, + GetUserResult, + ListSessionsRequest, + ListSessionsResponse, + ListSoftwareStackRequest, + ListSoftwareStackResponse, + ModifyUserRequest, + ModifyUserResult, + SocaEnvelope, + SocaHeader, + SocaPayload, + SocaPayloadType, + UpdateSessionRequest, + UpdateSessionResponse, + VirtualDesktopSession, + VirtualDesktopSessionConnectionInfo, +) +from tests.integration.framework.api_invoker.api_invoker_base import ResApiInvokerBase +from tests.integration.framework.api_invoker.http_api_invoker import HttpApiInvoker 
+from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.api_invocation_context import ( + ApiInvocationContext, +) +from tests.integration.framework.model.client_auth import ClientAuth + +logger = logging.getLogger(__name__) + + +class ResClient: + """ + RES client for invoking service APIs + """ + + def __init__( + self, + request: pytest.FixtureRequest, + res_environment: ResEnvironment, + client_auth: ClientAuth, + ): + self._client_auth = client_auth + self._api_invoker = self._get_api_invoker( + request.config.getoption("--api-invoker-type") + ) + + custom_web_app_domain_name = request.config.getoption( + "--custom-web-app-domain-name" + ) + self._endpoint = ( + f"https://{custom_web_app_domain_name}" + if custom_web_app_domain_name + else f"https://{res_environment.default_web_app_domain_name}" + ) + + def create_project(self, request: CreateProjectRequest) -> CreateProjectResult: + logger.info(f"creating project {request.project.name}...") + + return self._invoke( + "Projects.CreateProject", + "cluster-manager", + request, + CreateProjectResult, + ) + + def delete_project(self, request: DeleteProjectRequest) -> DeleteProjectResult: + project = request.project_name if request.project_name else request.project_id + logger.info(f"deleting project {project}...") + + return self._invoke( + "Projects.DeleteProject", + "cluster-manager", + request, + DeleteProjectResult, + ) + + def get_user(self, request: GetUserRequest) -> GetUserResult: + logger.info(f"getting user {request.username}...") + + return self._invoke( + "Accounts.GetUser", + "cluster-manager", + request, + GetUserResult, + ) + + def modify_user(self, request: ModifyUserRequest) -> ModifyUserResult: + logger.info(f"modifying user {request.user.username}...") + + return self._invoke( + "Accounts.ModifyUser", + "cluster-manager", + request, + ModifyUserResult, + ) + + def create_software_stack( + self, request: CreateSoftwareStackRequest + 
) -> CreateSoftwareStackResponse: + logger.info(f"creating software stack {request.software_stack.name}...") + + return self._invoke( + "VirtualDesktopAdmin.CreateSoftwareStack", + "vdc", + request, + CreateSoftwareStackResponse, + ) + + def delete_software_stack( + self, request: DeleteSoftwareStackRequest + ) -> DeleteSoftwareStackResponse: + logger.info(f"deleting software stack {request.software_stack.name}...") + + return self._invoke( + "VirtualDesktopAdmin.DeleteSoftwareStack", + "vdc", + request, + DeleteSoftwareStackResponse, + ) + + def list_software_stacks( + self, request: ListSoftwareStackRequest + ) -> ListSoftwareStackResponse: + logger.info(f"listing software stacks...") + + return self._invoke( + "VirtualDesktopAdmin.ListSoftwareStacks", + "vdc", + request, + ListSoftwareStackResponse, + ) + + def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse: + logger.info(f"creating session {request.session.name}...") + + return self._invoke( + "VirtualDesktop.CreateSession", + "vdc", + request, + CreateSessionResponse, + ) + + def get_session_info( + self, request: GetSessionInfoRequest + ) -> GetSessionInfoResponse: + logger.info(f"getting session info for {request.session.name}...") + + return self._invoke( + "VirtualDesktop.GetSessionInfo", + "vdc", + request, + GetSessionInfoResponse, + ) + + def get_session_connection_info( + self, request: GetSessionConnectionInfoRequest + ) -> GetSessionConnectionInfoResponse: + logger.info( + f"getting connection info for session {request.connection_info.dcv_session_id}..." 
+ ) + + return self._invoke( + "VirtualDesktop.GetSessionConnectionInfo", + "vdc", + request, + GetSessionConnectionInfoResponse, + ) + + def list_sessions(self, request: ListSessionsRequest) -> ListSessionsResponse: + logger.info("listing sessions...") + + return self._invoke( + "VirtualDesktop.ListSessions", + "vdc", + request, + ListSessionsResponse, + ) + + def update_session(self, request: UpdateSessionRequest) -> UpdateSessionResponse: + logger.info(f"updating session {request.session.dcv_session_id}...") + + return self._invoke( + "VirtualDesktop.UpdateSession", + "vdc", + request, + UpdateSessionResponse, + ) + + def join_session(self, session: VirtualDesktopSession) -> WebDriver: + get_session_connection_info_request = GetSessionConnectionInfoRequest( + connection_info=VirtualDesktopSessionConnectionInfo( + dcv_session_id=session.dcv_session_id, + ) + ) + get_session_connection_info_response = self.get_session_connection_info( + get_session_connection_info_request + ) + connection_info = get_session_connection_info_response.connection_info + logger.info(f"joining session {connection_info.dcv_session_id}...") + + # Open the session connection URL from a Chrome browser and keep the connection active. 
+ options = webdriver.ChromeOptions() + options.add_argument("--headless") # type: ignore + options.add_argument("--ignore-certificate-errors") # type: ignore + options.add_argument("--no-sandbox") # type: ignore + driver = webdriver.Chrome(options=options) + + connection_url = f"{connection_info.endpoint}{connection_info.web_url_path}?authToken={connection_info.access_token}#{connection_info.dcv_session_id}" + driver.get(connection_url) + + return driver + + def delete_sessions(self, request: DeleteSessionRequest) -> DeleteSessionResponse: + session_names = [session.name for session in request.sessions] + logger.info(f"deleting sessions {session_names}...") + + return self._invoke( + "VirtualDesktop.DeleteSessions", + "vdc", + request, + DeleteSessionResponse, + ) + + def _invoke( + self, + namespace: str, + component: str, + request: SocaPayload, + response_type: Type[SocaPayloadType], + ) -> SocaPayloadType: + header = SocaHeader() + header.namespace = namespace + header.version = 1 + + envelope = SocaEnvelope() + envelope.header = header + envelope.payload = request + + context = ApiInvocationContext( + endpoint=f"{self._endpoint}/{component}", + request=envelope, + auth=self._client_auth, + ) + + self._api_invoker.invoke(context) + + assert ( + context.response_is_success() + ), f'error code: {context.response.get("error_code")} message: {context.response.get("message")}' + + return context.get_response_payload_as(response_type) + + @staticmethod + def _get_api_invoker(api_invoker_type: str) -> ResApiInvokerBase: + if api_invoker_type == "http": + return HttpApiInvoker() + else: + raise Exception(f"Invalid API invoker type {api_invoker_type}") diff --git a/source/tests/integration/framework/fixtures/__init__.py b/source/tests/integration/framework/fixtures/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/fixtures/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_model.py b/source/tests/integration/framework/fixtures/fixture_request.py similarity index 66% rename from source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_model.py rename to source/tests/integration/framework/fixtures/fixture_request.py index edc2b39..548ee52 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_model.py +++ b/source/tests/integration/framework/fixtures/fixture_request.py @@ -9,19 +9,10 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. -__all__ = ( - 'OpenSearchQueryRequest', - 'OpenSearchQueryResult' -) +from typing import Any -from ideadatamodel import SocaPayload +from pytest import FixtureRequest as _FixtureRequest -from typing import Optional, Dict - -class OpenSearchQueryRequest(SocaPayload): - data: Optional[Dict] - - -class OpenSearchQueryResult(SocaPayload): - data: Optional[Dict] +class FixtureRequest(_FixtureRequest): + param: list[Any] diff --git a/source/tests/integration/framework/fixtures/project.py b/source/tests/integration/framework/fixtures/project.py new file mode 100644 index 0000000..19d7fcb --- /dev/null +++ b/source/tests/integration/framework/fixtures/project.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest + +from ideadatamodel import ( # type: ignore + CreateProjectRequest, + DeleteProjectRequest, + Project, +) +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.client_auth import ClientAuth + + +@pytest.fixture +def project( + request: FixtureRequest, res_environment: ResEnvironment, admin: ClientAuth +) -> Project: + """ + Fixture for setting up/tearing down the test project + """ + project = request.param[0] + filesystem_names = request.param[1] + create_project_request = CreateProjectRequest( + project=project, filesystem_names=filesystem_names + ) + + client = ResClient(request, res_environment, admin) + project = client.create_project(create_project_request).project + + def tear_down() -> None: + delete_project_request = DeleteProjectRequest(project_name=project.name) + client.delete_project(delete_project_request) + + request.addfinalizer(tear_down) + + return project diff --git a/source/tests/integration/framework/fixtures/res_environment.py b/source/tests/integration/framework/fixtures/res_environment.py new file mode 100644 index 0000000..c0e6893 --- /dev/null +++ b/source/tests/integration/framework/fixtures/res_environment.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import logging +from functools import cached_property + +import boto3 +import pytest + +from tests.integration.framework.utils.ad_sync import ad_sync +from tests.integration.framework.utils.ec2_utils import ( + all_in_service_instances_from_asgs, +) +from tests.integration.framework.utils.test_mode import set_test_mode_for_all_servers + +logger = logging.getLogger(__name__) + + +class ResEnvironment: + def __init__(self, environment_name: str, region: str): + self._environment_name = environment_name + self._region = region + + @property + def region(self) -> str: + return self._region + + @property + def cluster_manager_asg(self) -> str: + return f"{self._environment_name}-cluster-manager-asg" + + @property + def vdc_controller_asg(self) -> str: + return f"{self._environment_name}-vdc-controller-asg" + + @cached_property + def default_web_app_domain_name(self) -> str: + session = boto3.session.Session(region_name=self._region) + client = session.client("elbv2") + + describe_load_balancers_response = client.describe_load_balancers( + Names=[f"{self._environment_name}-external-alb"] + ) + dns_name: str = ( + describe_load_balancers_response["LoadBalancers"][0].get("DNSName", "") + if len(describe_load_balancers_response.get("LoadBalancers", [])) == 1 + else "" + ) + + return dns_name + + +@pytest.fixture +def res_environment( + request: pytest.FixtureRequest, + environment_name: str, + region: str, +) -> ResEnvironment: + """ + Fixture for setting up the RES test environment + 
""" + # Initialize the Res Environment + res_environment = ResEnvironment(region=region, environment_name=environment_name) + + cluster_manager_instances = all_in_service_instances_from_asgs( + [res_environment.cluster_manager_asg], + region, + ) + assert len(cluster_manager_instances) > 0, "Cluster Manager doesn't exist" + vdc_instances = all_in_service_instances_from_asgs( + [res_environment.vdc_controller_asg], + region, + ) + assert len(vdc_instances) > 0, "Virtual Desktop Controller doesn't exist" + + logger.info("relaunch servers in test mode") + server_instances = cluster_manager_instances + vdc_instances + set_test_mode_for_all_servers(region, server_instances, True) + + # Sync users and groups from AD so that integ tests can run with these users and groups + logger.info("sync users and groups from AD") + ad_sync(region, cluster_manager_instances[0]) + + def tear_down() -> None: + logger.info("disable server test mode") + set_test_mode_for_all_servers(region, server_instances, False) + + request.addfinalizer(tear_down) + + return res_environment diff --git a/source/tests/integration/framework/fixtures/session.py b/source/tests/integration/framework/fixtures/session.py new file mode 100644 index 0000000..457aee7 --- /dev/null +++ b/source/tests/integration/framework/fixtures/session.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +import pytest + +from ideadatamodel import ( # type: ignore + CreateSessionRequest, + DeleteProjectRequest, + DeleteSessionRequest, + UpdateSessionRequest, + VirtualDesktopSchedule, + VirtualDesktopScheduleType, + VirtualDesktopSession, + VirtualDesktopWeekSchedule, +) +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.client_auth import ClientAuth +from tests.integration.framework.utils.session_utils import ( + wait_for_deleting_session, + wait_for_launching_session, +) + + +@pytest.fixture +def session( + request: FixtureRequest, res_environment: ResEnvironment, non_admin: ClientAuth +) -> VirtualDesktopSession: + """ + Fixture for setting up/tearing down the test project + """ + session = request.param[0] + project = request.getfixturevalue(request.param[1]) + software_stack = request.getfixturevalue(request.param[2]) + + session.project = project + session.software_stack = software_stack + create_session_request = CreateSessionRequest(session=session) + + client = ResClient(request, res_environment, non_admin) + create_session_response = client.create_session(create_session_request) + + session = wait_for_launching_session(client, create_session_response.session) + + # Update the schedule to make sure that the virtual desktop session can be active every day. 
+ session.schedule = VirtualDesktopWeekSchedule( + monday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + tuesday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + wednesday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + thursday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + friday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + saturday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + sunday=VirtualDesktopSchedule( + schedule_type=VirtualDesktopScheduleType.NO_SCHEDULE + ), + ) + update_session_request = UpdateSessionRequest(session=session) + client.update_session(update_session_request) + + def tear_down() -> None: + session.force = True + client.delete_sessions(DeleteSessionRequest(sessions=[session])) + wait_for_deleting_session(client, session) + + request.addfinalizer(tear_down) + + return session diff --git a/source/tests/integration/framework/fixtures/software_stack.py b/source/tests/integration/framework/fixtures/software_stack.py new file mode 100644 index 0000000..b13d875 --- /dev/null +++ b/source/tests/integration/framework/fixtures/software_stack.py @@ -0,0 +1,96 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +import pytest + +from ideadatamodel import ( # type: ignore + CreateSoftwareStackRequest, + DeleteSoftwareStackRequest, + ListSoftwareStackRequest, + Project, + SocaFilter, + VirtualDesktopSoftwareStack, +) +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.client_auth import ClientAuth + + +@pytest.fixture +def software_stack( + request: FixtureRequest, + res_environment: ResEnvironment, + admin: ClientAuth, +) -> VirtualDesktopSoftwareStack: + """ + Fixture for setting up/tearing down the test software stack + """ + software_stack = request.param[0] + project = request.getfixturevalue(request.param[1]) + software_stack.projects = [project] + + client = ResClient(request, res_environment, admin) + + if not software_stack.ami_id: + base_os = software_stack.base_os + architecture = software_stack.architecture + gpu = software_stack.gpu + assert ( + base_os and architecture and gpu + ), f"Either provide an AMI ID or (base OS + architecture + GPU) of the software stack" + + list_software_stacks_request = ListSoftwareStackRequest( + filters=[ + SocaFilter( + key="base_os", + eq=base_os, + ), + SocaFilter( + key="architecture", + eq=architecture, + ), + SocaFilter( + key="gpu", + eq=gpu, + ), + ] + ) + list_software_stacks_response = client.list_software_stacks( + list_software_stacks_request + ) + existing_software_stacks = ( + list_software_stacks_response.listing + if list_software_stacks_response.listing + else [] + ) + assert ( + len(existing_software_stacks) > 0 + ), f"Failed to find existing software stacks with base OS {base_os}, architecture {architecture} and GPU {gpu}" + # If no AMI ID is provided, use the AMI ID of an existing software stack that has the same base OS, architecture and GPU. 
+ software_stack.ami_id = existing_software_stacks[0].ami_id + + create_software_stack_request = CreateSoftwareStackRequest( + software_stack=software_stack + ) + software_stack = client.create_software_stack( + create_software_stack_request + ).software_stack + + def tear_down() -> None: + delete_software_stack_request = DeleteSoftwareStackRequest( + software_stack=software_stack + ) + client.delete_software_stack(delete_software_stack_request) + + request.addfinalizer(tear_down) + + return software_stack diff --git a/source/tests/integration/framework/fixtures/users/__init__.py b/source/tests/integration/framework/fixtures/users/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/fixtures/users/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/integration/framework/fixtures/users/admin.py b/source/tests/integration/framework/fixtures/users/admin.py new file mode 100644 index 0000000..3a579ba --- /dev/null +++ b/source/tests/integration/framework/fixtures/users/admin.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. 
This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest + +from ideadatamodel import GetUserRequest, ModifyUserRequest # type: ignore +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.client_auth import ClientAuth + + +@pytest.fixture +def admin( + request: FixtureRequest, res_environment: ResEnvironment, admin_username: str +) -> ClientAuth: + """ + Fixture for the admin user + """ + + client = ResClient(request, res_environment, ClientAuth(username="clusteradmin")) + admin = client.get_user(GetUserRequest(username=admin_username)).user + + is_active = admin.is_active + if not is_active: + # Activate the admin user for running integ tests + admin.is_active = True + admin = client.modify_user(ModifyUserRequest(user=admin)).user + + def tear_down() -> None: + if not is_active: + # Revert the admin user activation status + admin.is_active = False + client.modify_user(ModifyUserRequest(user=admin)) + + request.addfinalizer(tear_down) + + return ClientAuth(username=admin_username) diff --git a/source/tests/integration/framework/fixtures/users/non_admin.py b/source/tests/integration/framework/fixtures/users/non_admin.py new file mode 100644 index 0000000..9526a59 --- /dev/null +++ b/source/tests/integration/framework/fixtures/users/non_admin.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. 
A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest + +from ideadatamodel import GetUserRequest, ModifyUserRequest # type: ignore +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.res_environment import ResEnvironment +from tests.integration.framework.model.client_auth import ClientAuth + + +@pytest.fixture +def non_admin( + request: FixtureRequest, res_environment: ResEnvironment, non_admin_username: str +) -> ClientAuth: + """ + Fixture for the non admin user + """ + + client = ResClient(request, res_environment, ClientAuth(username="clusteradmin")) + non_admin = client.get_user(GetUserRequest(username=non_admin_username)).user + + is_active = non_admin.is_active + if not is_active: + # Activate the non admin user for running integ tests + non_admin.is_active = True + non_admin = client.modify_user(ModifyUserRequest(user=non_admin)).user + + def tear_down() -> None: + if not is_active: + # Revert the non admin user activation status + non_admin.is_active = False + client.modify_user(ModifyUserRequest(user=non_admin)) + + request.addfinalizer(tear_down) + + return ClientAuth(username=non_admin_username) diff --git a/source/tests/integration/framework/model/__init__.py b/source/tests/integration/framework/model/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/model/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/integration/framework/model/api_invocation_context.py b/source/tests/integration/framework/model/api_invocation_context.py new file mode 100644 index 0000000..ee3dc31 --- /dev/null +++ b/source/tests/integration/framework/model/api_invocation_context.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Any, Dict, Type + +from pydantic import BaseModel + +from ideadatamodel import SocaEnvelope, SocaPayloadType, get_payload_as # type: ignore +from tests.integration.framework.model.client_auth import ClientAuth + + +class ApiInvocationContext(BaseModel): + """ + API invocation context that wraps endpoint, authorization, request and response. 
+ """ + + endpoint: str + request: SocaEnvelope + auth: ClientAuth + response: Dict[str, Any] = {} + + def get_response_payload_as( + self, payload_type: Type[SocaPayloadType] + ) -> SocaPayloadType: + return get_payload_as( + payload=self.response.get("payload"), payload_type=payload_type + ) + + def response_is_success(self) -> bool: + return self.response.get("success", False) diff --git a/source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_api.py b/source/tests/integration/framework/model/client_auth.py similarity index 66% rename from source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_api.py rename to source/tests/integration/framework/model/client_auth.py index edc2b39..8173701 100644 --- a/source/idea/idea-data-model/src/ideadatamodel/analytics/analytics_api.py +++ b/source/tests/integration/framework/model/client_auth.py @@ -9,19 +9,12 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions # and limitations under the License. -__all__ = ( - 'OpenSearchQueryRequest', - 'OpenSearchQueryResult' -) +from pydantic import BaseModel -from ideadatamodel import SocaPayload -from typing import Optional, Dict +class ClientAuth(BaseModel): + """ + Client auth information including the username and groups. + """ - -class OpenSearchQueryRequest(SocaPayload): - data: Optional[Dict] - - -class OpenSearchQueryResult(SocaPayload): - data: Optional[Dict] + username: str diff --git a/source/tests/integration/framework/utils/__init__.py b/source/tests/integration/framework/utils/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/framework/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. 
A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/integration/framework/utils/ad_sync.py b/source/tests/integration/framework/utils/ad_sync.py new file mode 100644 index 0000000..503a495 --- /dev/null +++ b/source/tests/integration/framework/utils/ad_sync.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import time +from typing import Any, Dict + +from tests.integration.framework.utils.remote_command_runner import RemoteCommandRunner + + +def ad_sync( + region: str, + cluster_manager: Dict[str, Any], +) -> None: + ad_sync_commands = ["sudo /opt/idea/python/latest/bin/resctl ldap sync-from-ad"] + remote_command_runner = RemoteCommandRunner(region) + remote_command_runner.run(cluster_manager.get("InstanceId", ""), ad_sync_commands) + + # Wait for the AD sync to complete as this is an async call. + time.sleep(30) diff --git a/source/tests/integration/framework/utils/ec2_utils.py b/source/tests/integration/framework/utils/ec2_utils.py new file mode 100644 index 0000000..2cd4bd7 --- /dev/null +++ b/source/tests/integration/framework/utils/ec2_utils.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Any, Dict + +import boto3 + + +def all_in_service_instances_from_asgs( + auto_scaling_group_names: list[str], + region: str, +) -> list[Dict[str, Any]]: + session = boto3.session.Session(region_name=region) + auto_scaling_client = session.client("autoscaling") + response = auto_scaling_client.describe_auto_scaling_groups( + AutoScalingGroupNames=auto_scaling_group_names + ) + + instances: list[Dict[str, Any]] = [] + for group in response.get("AutoScalingGroups", []): + instances = instances + [ + instance + for instance in group.get("Instances", []) + if instance.get("LifecycleState") == "InService" + ] + + return instances diff --git a/source/tests/integration/framework/utils/remote_command_runner.py b/source/tests/integration/framework/utils/remote_command_runner.py new file mode 100644 index 0000000..89e2751 --- /dev/null +++ b/source/tests/integration/framework/utils/remote_command_runner.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. + +import logging +from typing import Any, Optional + +import boto3 + +logger = logging.getLogger(__name__) + + +class RemoteCommandRunner: + def __init__(self, region: str, output_bucket: Optional[str] = None): + session = boto3.session.Session(region_name=region) + self._ssm_client = session.client("ssm") + self._s3_resource = session.resource("s3") + self._output_bucket = output_bucket + + def run(self, instance_id: str, commands: list[str]) -> Any: + logger.debug(f"sending commands {commands} to instance {instance_id}") + + if self._output_bucket: + cmd_result = self._ssm_client.send_command( + InstanceIds=[instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": commands}, + OutputS3BucketName=self._output_bucket, + ) + else: + cmd_result = self._ssm_client.send_command( + InstanceIds=[instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": commands}, + ) + + command_id = cmd_result.get("Command", {}).get("CommandId", "") + + waiter = self._ssm_client.get_waiter("command_executed") + waiter.wait( + CommandId=command_id, + InstanceId=instance_id, + ) + + get_output_result = self._ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=instance_id, + ) + + output = ( + self._read_s3_output(get_output_result.get("StandardOutputUrl")) + if self._output_bucket + else get_output_result.get("StandardOutputContent") + ) + logger.debug(f"remote commands output: {output}") + + return output + + def _read_s3_output(self, s3_url: str) -> Any: + sections = s3_url.split("/") + # Verify that the url includes the s3 bucket name and object key + assert len(sections) > 5, f"invalid remote command output URL: {s3_url}" + + bucket_name = sections[3] + ssm_output_object_key = "/".join(sections[4:]) + logger.debug( + f"remote command output bucket: {bucket_name} object key: {ssm_output_object_key}" + ) + + ssm_output_object = 
self._s3_resource.Object( + bucket_name, ssm_output_object_key + ).get() + return ssm_output_object.get("Body").read().decode("utf-8") diff --git a/source/tests/integration/framework/utils/session_utils.py b/source/tests/integration/framework/utils/session_utils.py new file mode 100644 index 0000000..d048dbd --- /dev/null +++ b/source/tests/integration/framework/utils/session_utils.py @@ -0,0 +1,123 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import json +import logging +import time +from typing import Any + +from ideadatamodel import ( # type: ignore + GetSessionConnectionInfoRequest, + GetSessionInfoRequest, + ListSessionsRequest, + VirtualDesktopSession, + VirtualDesktopSessionConnectionInfo, + VirtualDesktopSessionState, +) +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.utils.remote_command_runner import RemoteCommandRunner + +logger = logging.getLogger(__name__) + +SESSION_COMPLETE_STATES = [ + VirtualDesktopSessionState.READY, + VirtualDesktopSessionState.ERROR, + VirtualDesktopSessionState.STOPPED, + VirtualDesktopSessionState.DELETED, +] +MAX_WAITING_TIME_FOR_LAUNCHING_SESSION_IN_SEC = 3600 +MAX_WAITING_TIME_FOR_DELETING_SESSION_IN_SEC = 300 +MAX_WAITING_TIME_FOR_SESSION_CONNECTION_COUNT_IN_SEC = 300 + + +def wait_for_launching_session( + client: ResClient, session: VirtualDesktopSession +) -> VirtualDesktopSession: + start_time = time.time() + while time.time() - start_time < 
MAX_WAITING_TIME_FOR_LAUNCHING_SESSION_IN_SEC: + get_session_info_request = GetSessionInfoRequest(session=session) + get_session_info_response = client.get_session_info(get_session_info_request) + session = get_session_info_response.session + session_state = session.state + + if session_state in SESSION_COMPLETE_STATES: + assert ( + session_state == VirtualDesktopSessionState.READY + ), f"Session {session.name} is in an unexpected state {session_state}: {session.failure_reason}" + + return session + + logger.debug(f"session state: {session_state}") + time.sleep(30) + + assert ( + False + ), f"Failed to launch session {session.name} within {MAX_WAITING_TIME_FOR_LAUNCHING_SESSION_IN_SEC} seconds" + + +def wait_for_deleting_session( + client: ResClient, session: VirtualDesktopSession +) -> None: + start_time = time.time() + while time.time() - start_time < MAX_WAITING_TIME_FOR_DELETING_SESSION_IN_SEC: + list_sessions_response = client.list_sessions(ListSessionsRequest()) + existing_sessions = list_sessions_response.listing + + if existing_sessions and any( + existing_session.dcv_session_id == session.dcv_session_id + for existing_session in existing_sessions + ): + logger.debug(f"session {session.dcv_session_id} is still available") + time.sleep(30) + else: + return + + assert ( + False + ), f"Failed to delete session {session.dcv_session_id} within {MAX_WAITING_TIME_FOR_DELETING_SESSION_IN_SEC} seconds" + + +def wait_for_session_connection_count( + region: str, session: VirtualDesktopSession, count: int +) -> None: + start_time = time.time() + while ( + time.time() - start_time < MAX_WAITING_TIME_FOR_SESSION_CONNECTION_COUNT_IN_SEC + ): + dcv_session_info = describe_dcv_session( + region, session.server.instance_id, session.dcv_session_id + ) + + num_of_connections = dcv_session_info["num-of-connections"] + if num_of_connections is not None and num_of_connections == count: + return + + logger.debug(f"num of connections: {num_of_connections}") + time.sleep(30) + + 
assert ( + False + ), f"Failed to reach session connection count {count} within {MAX_WAITING_TIME_FOR_SESSION_CONNECTION_COUNT_IN_SEC} seconds" + + +def describe_dcv_session( + region: str, server_instance_id: str, dcv_session_id: str +) -> Any: + remote_command_runner = RemoteCommandRunner(region) + commands = [f"sudo dcv describe-session {dcv_session_id} -j"] + + output = remote_command_runner.run(server_instance_id, commands) + try: + dcv_session_info = json.loads(output) + except json.JSONDecodeError as e: + assert False, e + + return dcv_session_info diff --git a/source/tests/integration/framework/utils/test_mode.py b/source/tests/integration/framework/utils/test_mode.py new file mode 100644 index 0000000..432b0fc --- /dev/null +++ b/source/tests/integration/framework/utils/test_mode.py @@ -0,0 +1,95 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+
+import logging
+import threading
+import time
+from typing import Any, Dict, Optional
+
+from tests.integration.framework.utils.remote_command_runner import RemoteCommandRunner
+
+logger = logging.getLogger(__name__)
+
+
+class SetTestModeThread(threading.Thread):
+    """
+    Custom thread class for setting the test mode on a server instance
+    """
+
+    def __init__(
+        self, remote_command_runner: RemoteCommandRunner, instance_id: str, enable: bool
+    ):
+        super().__init__()
+
+        self._remote_command_runner = remote_command_runner
+        self._instance_id = instance_id
+        self._enable = enable
+        self._exc: Optional[BaseException] = None
+
+    def set_test_mode(self) -> None:
+        set_test_mode_commands = [
+            "sudo sed -i '/^RES_TEST_MODE/d' /etc/environment && "
+            f"echo 'RES_TEST_MODE={str(self._enable)}' | sudo tee -a /etc/environment && "
+            "sudo service supervisord restart && "
+            "sudo /opt/idea/python/3.9.16/bin/supervisorctl start all"
+        ]
+        health_check_commands = ["curl https://localhost:8443/healthcheck -k"]
+
+        self._remote_command_runner.run(self._instance_id, set_test_mode_commands)
+
+        # Use a monotonic wall clock here: time.process_time() excludes time
+        # spent sleeping, so a process-time deadline would effectively never expire.
+        start_time = time.monotonic()
+        while time.monotonic() - start_time < 30:
+            try:
+                output = self._remote_command_runner.run(
+                    self._instance_id, health_check_commands
+                )
+                assert output == '{"success":true}'
+                logger.debug(
+                    f"server is relaunched successfully. instance id: {self._instance_id}"
+                )
+                return
+            except Exception:
+                logger.debug(
+                    f"continue waiting for the server to respond. instance id: {self._instance_id}"
+                )
+                time.sleep(1)
+
+        assert (
+            False
+        ), f"failed to relaunch server in 30 seconds. 
instance id: {self._instance_id}" + + def run(self) -> None: + try: + self.set_test_mode() + except BaseException as e: + self._exc = e + + def join(self, timeout: Optional[float] = None) -> None: + threading.Thread.join(self, timeout) + if self._exc: + raise self._exc + + +def set_test_mode_for_all_servers( + region: str, + server_instances: list[Dict[str, Any]], + enable: bool, +) -> None: + remote_command_runner = RemoteCommandRunner(region) + + threads = [ + SetTestModeThread(remote_command_runner, instance.get("InstanceId", ""), enable) + for instance in server_instances + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() diff --git a/source/tests/integration/tests/__init__.py b/source/tests/integration/tests/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/integration/tests/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/integration/tests/smoke.py b/source/tests/integration/tests/smoke.py new file mode 100644 index 0000000..5abc4d5 --- /dev/null +++ b/source/tests/integration/tests/smoke.py @@ -0,0 +1,149 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. 
A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import logging + +import pytest + +from ideadatamodel import ( # type: ignore + CreateSessionRequest, + DeleteSessionRequest, + Project, + SocaMemory, + SocaMemoryUnit, + UpdateSessionRequest, + VirtualDesktopArchitecture, + VirtualDesktopBaseOS, + VirtualDesktopGPU, + VirtualDesktopSchedule, + VirtualDesktopScheduleType, + VirtualDesktopServer, + VirtualDesktopSession, + VirtualDesktopSoftwareStack, + VirtualDesktopWeekSchedule, +) +from tests.integration.framework.client.res_client import ResClient +from tests.integration.framework.fixtures.fixture_request import FixtureRequest +from tests.integration.framework.fixtures.project import project +from tests.integration.framework.fixtures.res_environment import ( + ResEnvironment, + res_environment, +) +from tests.integration.framework.fixtures.session import session +from tests.integration.framework.fixtures.software_stack import software_stack +from tests.integration.framework.fixtures.users.admin import admin +from tests.integration.framework.fixtures.users.non_admin import non_admin +from tests.integration.framework.model.client_auth import ClientAuth +from tests.integration.framework.utils.session_utils import ( + wait_for_session_connection_count, +) + +logger = logging.getLogger(__name__) +MIN_STORAGE = SocaMemory(value=10, unit=SocaMemoryUnit.GB) +MIN_RAM = SocaMemory(value=4, unit=SocaMemoryUnit.GB) + + +@pytest.mark.usefixtures("res_environment") +@pytest.mark.usefixtures("region") +class TestsSmoke(object): + @pytest.mark.usefixtures("admin") + @pytest.mark.parametrize( + "admin_username", + [ + "admin1", + ], + ) + @pytest.mark.usefixtures("non_admin") 
+ @pytest.mark.parametrize( + "non_admin_username", + [ + "user1", + ], + ) + @pytest.mark.parametrize( + "project", + [ + ( + Project( + title="res-integ-test", + name="res-integ-test", + description="RES integ test project", + enable_budgets=False, + ldap_groups=["RESAdministrators", "group_1", "group_2"], + ), + [], + ) + ], + indirect=True, + ) + @pytest.mark.parametrize( + "software_stack", + [ + ( + VirtualDesktopSoftwareStack( + name="res-integ-test-stack", + description="RES integ test software stack", + base_os=VirtualDesktopBaseOS.AMAZON_LINUX2, + architecture=VirtualDesktopArchitecture.X86_64, + min_storage=MIN_STORAGE, + min_ram=MIN_RAM, + gpu=VirtualDesktopGPU.NO_GPU, + ), + "project", + ) + ], + indirect=True, + ) + @pytest.mark.parametrize( + "session", + [ + ( + VirtualDesktopSession( + base_os=VirtualDesktopBaseOS.AMAZON_LINUX2, + name="VirtualDesktop1", + server=VirtualDesktopServer( + instance_type="t3.medium", root_volume_size=MIN_STORAGE + ), + description="VirtualDesktop1", + hibernation_enabled=False, + ), + "project", + "software_stack", + ) + ], + indirect=True, + ) + def test_end_to_end_succeed( + self, + request: FixtureRequest, + region: str, + admin: ClientAuth, + admin_username: str, + non_admin: ClientAuth, + non_admin_username: str, + res_environment: ResEnvironment, + project: Project, + software_stack: VirtualDesktopSoftwareStack, + session: VirtualDesktopSession, + ) -> None: + """ + Test the end to end workflow: + 1. Create a project, software stack and virtual desktop session. + 2. Join the virtual desktop session from a headless web browser and close it. + 3. Clean up the test project, software stack and virtual desktop session. 
+ """ + client = ResClient(request, res_environment, non_admin) + web_driver = client.join_session(session) + wait_for_session_connection_count(region, session, 1) + + logger.info(f"leaving session {session.dcv_session_id}...") + web_driver.quit() + wait_for_session_connection_count(region, session, 0) diff --git a/source/tests/unit/idea-cluster-manager/conftest.py b/source/tests/unit/idea-cluster-manager/conftest.py index ae15a58..71ee91e 100644 --- a/source/tests/unit/idea-cluster-manager/conftest.py +++ b/source/tests/unit/idea-cluster-manager/conftest.py @@ -21,17 +21,24 @@ CognitoUserPool, CognitoUserPoolOptions, ) +from ideaclustermanager.app.auth.api_authorization_service import ( + ClusterManagerApiAuthorizationService, +) from ideaclustermanager.app.projects.projects_service import ProjectsService +from ideaclustermanager.app.shared_filesystem.shared_filesystem_service import ( + SharedFilesystemService, +) from ideaclustermanager.app.snapshots.snapshots_service import SnapshotsService from ideaclustermanager.app.tasks.task_manager import TaskManager from ideasdk.auth import TokenService, TokenServiceOptions from ideasdk.aws import AwsClientProvider, AWSUtil, EC2InstanceTypesDB from ideasdk.client.evdi_client import EvdiClient from ideasdk.context import SocaContextOptions -from ideasdk.utils import Utils +from ideasdk.utils import GroupNameHelper, Utils from ideatestutils import IdeaTestProps, MockConfig, MockInstanceTypes from ideatestutils.dynamodb.dynamodb_local import DynamoDBLocal from mock_ldap_client import MockLdapClient +from mock_vdc_client import MockVirtualDesktopControllerClient from ideadatamodel import SocaAnyPayload @@ -81,6 +88,7 @@ def mock_function(*_, **__): mock_cognito_idp = SocaAnyPayload() mock_cognito_idp.admin_create_user = mock_function + mock_cognito_idp.admin_remove_user_from_group = mock_function mock_cognito_idp.admin_delete_user = mock_function mock_cognito_idp.admin_add_user_to_group = mock_function 
mock_cognito_idp.admin_get_user = mock_function @@ -208,7 +216,7 @@ def create_mock_boto_session(**_): ) context.ldap_client = MockLdapClient(context=context) - user_pool = CognitoUserPool( + context.user_pool = CognitoUserPool( context=context, options=CognitoUserPoolOptions( user_pool_id=context.config().get_string( @@ -222,16 +230,6 @@ def create_mock_boto_session(**_): ), ) - context.accounts = AccountsService( - context=context, - ldap_client=context.ldap_client, - user_pool=user_pool, - evdi_client=EvdiClient(context=context), - task_manager=context.task_manager, - token_service=None, - ) - context.accounts.create_defaults() - context.token_service = TokenService( context=context, options=TokenServiceOptions( @@ -247,14 +245,34 @@ def create_mock_boto_session(**_): ), ) + context.accounts = AccountsService( + context=context, + ldap_client=context.ldap_client, + user_pool=context.user_pool, + evdi_client=EvdiClient(context=context), + task_manager=context.task_manager, + token_service=context.token_service, + ) + context.accounts.create_defaults() + + # api authorization service + context.api_authorization_service = ClusterManagerApiAuthorizationService( + accounts=context.accounts + ) + + context.vdc_client = MockVirtualDesktopControllerClient() + context.projects = ProjectsService( context=context, accounts_service=context.accounts, task_manager=context.task_manager, + vdc_client=context.vdc_client, ) context.snapshots = SnapshotsService(context=context) + context.shared_filesystem = SharedFilesystemService(context=context) + yield context print("cluster manager context clean-up ...") diff --git a/source/tests/unit/idea-cluster-manager/mock_vdc_client.py b/source/tests/unit/idea-cluster-manager/mock_vdc_client.py new file mode 100644 index 0000000..39aa768 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/mock_vdc_client.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Optional + +from ideasdk.client.vdc_client import AbstractVirtualDesktopControllerClient + +from ideadatamodel import ( + VirtualDesktopPermission, + VirtualDesktopPermissionProfile, + VirtualDesktopSession, + VirtualDesktopSoftwareStack, +) + + +class MockVirtualDesktopControllerClient(AbstractVirtualDesktopControllerClient): + sessions: list[VirtualDesktopSession] = [] + software_stacks: list[VirtualDesktopSoftwareStack] = [] + base_permissions: list[VirtualDesktopPermission] = [] + + def list_sessions_by_project_id( + self, project_id: str + ) -> list[VirtualDesktopSession]: + return self.sessions + + def list_software_stacks_by_project_id( + self, project_id: str + ) -> list[VirtualDesktopSoftwareStack]: + return self.software_stacks + + def get_base_permissions(self) -> list[VirtualDesktopPermission]: + return self.base_permissions + + def create_permission_profile( + self, + profile: VirtualDesktopPermissionProfile, + ) -> VirtualDesktopPermissionProfile: + pass + + def delete_permission_profile(self, profile_id: str) -> None: + pass + + def get_permission_profile( + self, profile_id: str + ) -> VirtualDesktopPermissionProfile: + pass + + def get_software_stacks_by_name( + self, stack_name: str + ) -> list[VirtualDesktopSoftwareStack]: + return self.software_stacks + + def create_software_stack( + self, software_stack: VirtualDesktopSoftwareStack + ) -> VirtualDesktopSoftwareStack: + pass + + def delete_software_stack( + self, software_stack: 
VirtualDesktopSoftwareStack + ) -> None: + pass diff --git a/source/tests/unit/idea-cluster-manager/snapshot/__init__.py b/source/tests/unit/idea-cluster-manager/snapshot/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/__init__.py b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/__init__.py new file mode 100644 index 0000000..6d8d18a --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_filesystem_cluster_settings_table_merger.py b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_filesystem_cluster_settings_table_merger.py new file mode 100644 index 0000000..1d48ec1 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_filesystem_cluster_settings_table_merger.py @@ -0,0 +1,620 @@ +import unittest + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.filesystems_cluster_settings_table_merger import ( + FileSystemsClusterSettingTableMerger, +) +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import ( + MergeTable, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) + +from ideadatamodel import ( + AddFileSystemToProjectRequest, + OffboardFileSystemRequest, + OnboardEFSFileSystemRequest, + errorcodes, + exceptions, +) + + +@pytest.fixture(scope="class") +def monkeypatch_for_class(request): + request.cls.monkeypatch = MonkeyPatch() + + +@pytest.fixture(scope="class") +def context_for_class(request, context): + request.cls.context = context + + +def dummy_unique_resource_id_generator(key, dedup_id): + return f"{key}_{dedup_id}" + + +DUMMY_DEDUP_ID = "dedup_id" + +DUMMY_EFS_1 = "dummy_efs_1" +DUMMY_ONTAP_1 = "dummy_ontap_1" + +DUMMY_EFS_1_DEDUP = dummy_unique_resource_id_generator(DUMMY_EFS_1, DUMMY_DEDUP_ID) +DUMMY_ONTAP_1_DEDUP = dummy_unique_resource_id_generator(DUMMY_ONTAP_1, DUMMY_DEDUP_ID) + +DUMMY_EFS_FILESYSTEM_ID = "fs-efs-1-id" +DUMMY_ONTAP_FILESYSTEM_ID = "fs-ontap-1-id" + +DUMMY_PROJECT_NAME = "dummy_project" + +dummy_snapshot_filesystem_details_in_dict = { + 
DUMMY_EFS_1: { + "efs": { + "cloudwatch_monitoring": "false", + "dns": "fs-efs-1-id.efs.us-east-1.amazonaws.com", + "encrypted": "true", + "file_system_id": DUMMY_EFS_FILESYSTEM_ID, + "kms_key_id": "arn:aws:kms:us-east-1:1234", + "performance_mode": "generalPurpose", + "removal_policy": "RETAIN", + "throughput_mode": "elastic", + "transition_to_ia": "AFTER_30_DAYS", + }, + "mount_dir": "/efs-1", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": [DUMMY_PROJECT_NAME], + "provider": "efs", + "scope": ["project"], + "title": "efs-1", + }, + DUMMY_ONTAP_1: { + "fsx_netapp_ontap": { + "file_system_id": DUMMY_ONTAP_FILESYSTEM_ID, + "removal_policy": "RETAIN", + "svm": { + "iscsi_dns": "iscsi.svm-1234.fs-1234.fsx.us-east-1.amazonaws.com", + "management_dns": "svm-1234.fs-1234.fsx.us-east-1.amazonaws.com", + "nfs_dns": "svm-1234.fs-1234.fsx.us-east-1.amazonaws.com", + "smb_dns": "null", + "svm_id": "svm-1234", + }, + "use_existing_fs": "true", + "volume": { + "cifs_share_name": "share_name", + "security_style": "null", + "volume_id": "fsvol-1234", + "volume_path": "/vol1", + }, + }, + "mount_dir": "/ontap-1", + "mount_drive": "X", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": [DUMMY_PROJECT_NAME], + "provider": "fsx_netapp_ontap", + "scope": ["project"], + "title": "ontap_1", + }, +} + + +@pytest.mark.usefixtures("monkeypatch_for_class") +@pytest.mark.usefixtures("context_for_class") +class FileSystemClusterSettingsTableMergerTest(unittest.TestCase): + def setUp(self) -> None: + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "extract_filesystem_details_to_dict", + lambda x, y: dummy_snapshot_filesystem_details_in_dict, + ) + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "_wait_for_onboarded_filesystem_to_sync_to_config_tree", + lambda x, y, z: None, + ) + self.monkeypatch.setattr( + 
FileSystemsClusterSettingTableMerger, + "get_list_of_onboarded_filesystem_ids", + lambda x, y: [], + ) + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "get_list_of_accessible_filesystem_ids", + lambda x, y: [DUMMY_EFS_FILESYSTEM_ID, DUMMY_ONTAP_FILESYSTEM_ID], + ) + self.monkeypatch.setattr( + self.context.shared_filesystem, "onboard_efs_filesystem", lambda x: None + ) + self.monkeypatch.setattr( + self.context.shared_filesystem, "onboard_ontap_filesystem", lambda x: None + ) + self.monkeypatch.setattr( + self.context.shared_filesystem, "add_filesystem_to_project", lambda x: None + ) + self.monkeypatch.setattr( + MergeTable, + "unique_resource_id_generator", + dummy_unique_resource_id_generator, + ) + self.context.projects.projects_dao.create_project( + { + "project_id": "dummy_project_id", + "created_on": 0, + "description": "dummy_project", + "enable_budgets": False, + "enabled": True, + "ldap_groups": ["test_group_1"], + "name": DUMMY_PROJECT_NAME, + "title": "dummy_project", + "updated_on": 0, + "users": [], + } + ) + + def test_filesystems_cluster_settings_table_merger_new_filesystem_succeed(self): + def dummy_get_filesystem(filesystem_name): + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_FOUND, + message=f"could not find filesystem {filesystem_name}", + ) + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + merger = FileSystemsClusterSettingTableMerger() + + record_deltas, success = merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert success + assert len(record_deltas) == 2 + + merged_filesystem_names = [] + for record in record_deltas: + assert record.action_performed == MergedRecordActionType.CREATE + merged_filesystem_names.append(list(record.resolved_record.keys())[0]) + + 
assert DUMMY_EFS_1 in merged_filesystem_names + assert DUMMY_EFS_1_DEDUP not in merged_filesystem_names + + assert DUMMY_ONTAP_1 in merged_filesystem_names + assert DUMMY_ONTAP_1_DEDUP not in merged_filesystem_names + + def test_filesystem_cluster_settings_table_merger_non_project_scope_filesystem_skipped( + self, + ): + efs_details = { + "efs": { + "cloudwatch_monitoring": "false", + "dns": "fs-efs-1-id.efs.us-east-1.amazonaws.com", + "encrypted": "true", + "file_system_id": DUMMY_EFS_FILESYSTEM_ID, + "kms_key_id": "arn:aws:kms:us-east-1:1234", + "performance_mode": "generalPurpose", + "removal_policy": "RETAIN", + "throughput_mode": "elastic", + "transition_to_ia": "AFTER_30_DAYS", + }, + "mount_dir": "/efs-1", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": ["test-p-1", "efs-p", "res-integ-test-1"], + "provider": "efs", + "scope": ["project"], + "title": "efs-1", + } + dummy_non_project_scope_filesystem_details_in_dict = { + "dummy_cluster_efs": {**efs_details, "scope": ["cluster"]}, + "dummy_module_efs": {**efs_details, "scope": ["module"]}, + } + + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "extract_filesystem_details_to_dict", + lambda x, y: dummy_non_project_scope_filesystem_details_in_dict, + ) + + merger = FileSystemsClusterSettingTableMerger() + + record_deltas, success = merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert success + assert len(record_deltas) == 0 + + def test_filesystems_cluster_settings_table_merger_existing_filesystem_succeed( + self, + ): + def dummy_get_filesystem(filesystem_name): + if filesystem_name == DUMMY_EFS_1: + return {} + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_FOUND, + message=f"could not find filesystem {filesystem_name}", + ) + + self.monkeypatch.setattr( + 
self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + merger = FileSystemsClusterSettingTableMerger() + record_deltas, success = merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert success + assert len(record_deltas) == 2 + + merged_filesystem_names = [] + for record in record_deltas: + assert record.action_performed == MergedRecordActionType.CREATE + merged_filesystem_names.append(list(record.resolved_record.keys())[0]) + + assert DUMMY_EFS_1 not in merged_filesystem_names + assert DUMMY_EFS_1_DEDUP in merged_filesystem_names + + assert DUMMY_ONTAP_1 in merged_filesystem_names + assert DUMMY_ONTAP_1_DEDUP not in merged_filesystem_names + + def test_filesystem_cluster_settings_table_merger_adds_existing_project_to_file_system( + self, + ): + def dummy_get_filesystem(filesystem_name): + return {} + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + dummy_filesystem_details_in_dict = { + "dummy_cluster_efs": { + "efs": { + "cloudwatch_monitoring": "false", + "dns": "fs-efs-1-id.efs.us-east-1.amazonaws.com", + "encrypted": "true", + "file_system_id": DUMMY_EFS_FILESYSTEM_ID, + "kms_key_id": "arn:aws:kms:us-east-1:1234", + "performance_mode": "generalPurpose", + "removal_policy": "RETAIN", + "throughput_mode": "elastic", + "transition_to_ia": "AFTER_30_DAYS", + }, + "mount_dir": "/efs-1", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": [DUMMY_PROJECT_NAME], + "provider": "efs", + "scope": ["project"], + "title": "efs-1", + }, + } + + add_file_system_to_project_called = False + + def dummy_add_filesystem_to_project(request: AddFileSystemToProjectRequest): + nonlocal add_file_system_to_project_called 
+ add_file_system_to_project_called = True + + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "extract_filesystem_details_to_dict", + lambda x, y: dummy_filesystem_details_in_dict, + ) + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "add_filesystem_to_project", + dummy_add_filesystem_to_project, + ) + + merger = FileSystemsClusterSettingTableMerger() + merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert add_file_system_to_project_called + + def test_filesystem_cluster_settings_table_merger_skips_adding_nonexisting_project_to_file_system( + self, + ): + def dummy_get_filesystem(filesystem_name): + return {} + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + dummy_filesystem_details_in_dict = { + "dummy_cluster_efs": { + "efs": { + "cloudwatch_monitoring": "false", + "dns": "fs-efs-1-id.efs.us-east-1.amazonaws.com", + "encrypted": "true", + "file_system_id": DUMMY_EFS_FILESYSTEM_ID, + "kms_key_id": "arn:aws:kms:us-east-1:1234", + "performance_mode": "generalPurpose", + "removal_policy": "RETAIN", + "throughput_mode": "elastic", + "transition_to_ia": "AFTER_30_DAYS", + }, + "mount_dir": "/efs-1", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": ["nonexisting-project"], + "provider": "efs", + "scope": ["project"], + "title": "efs-1", + }, + } + + add_file_system_to_project_called = False + + def dummy_add_filesystem_to_project(request: AddFileSystemToProjectRequest): + nonlocal add_file_system_to_project_called + add_file_system_to_project_called = True + + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "extract_filesystem_details_to_dict", + lambda x, y: dummy_filesystem_details_in_dict, + ) + + 
self.monkeypatch.setattr( + self.context.shared_filesystem, + "add_filesystem_to_project", + dummy_add_filesystem_to_project, + ) + + merger = FileSystemsClusterSettingTableMerger() + merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert not add_file_system_to_project_called + + def test_filesystem_cluster_settings_table_merger_adds_dedup_project_to_filesystem_if_exists( + self, + ): + self.context.projects.projects_dao.create_project( + { + "project_id": "dummy_project_id", + "created_on": 0, + "description": "dummy_project", + "enable_budgets": False, + "enabled": True, + "ldap_groups": ["test_group_1"], + "name": dummy_unique_resource_id_generator( + DUMMY_PROJECT_NAME, DUMMY_DEDUP_ID + ), + "title": "dummy_project", + "updated_on": 0, + "users": [], + } + ) + + def dummy_get_filesystem(filesystem_name): + return {} + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + dummy_filesystem_details_in_dict = { + "dummy_cluster_efs": { + "efs": { + "cloudwatch_monitoring": "false", + "dns": "fs-efs-1-id.efs.us-east-1.amazonaws.com", + "encrypted": "true", + "file_system_id": DUMMY_EFS_FILESYSTEM_ID, + "kms_key_id": "arn:aws:kms:us-east-1:1234", + "performance_mode": "generalPurpose", + "removal_policy": "RETAIN", + "throughput_mode": "elastic", + "transition_to_ia": "AFTER_30_DAYS", + }, + "mount_dir": "/efs-1", + "mount_options": "nfs4 nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2,noresvport 0 0", + "projects": [DUMMY_PROJECT_NAME], + "provider": "efs", + "scope": ["project"], + "title": "efs-1", + }, + } + + add_file_system_to_project_called = False + + def dummy_add_filesystem_to_project(request: AddFileSystemToProjectRequest): + nonlocal add_file_system_to_project_called + add_file_system_to_project_called = True + + 
assert request.project_name == dummy_unique_resource_id_generator( + DUMMY_PROJECT_NAME, DUMMY_DEDUP_ID + ) + + self.monkeypatch.setattr( + FileSystemsClusterSettingTableMerger, + "extract_filesystem_details_to_dict", + lambda x, y: dummy_filesystem_details_in_dict, + ) + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "add_filesystem_to_project", + dummy_add_filesystem_to_project, + ) + + merger = FileSystemsClusterSettingTableMerger() + merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert add_file_system_to_project_called + + def test_filesystem_cluster_settings_table_merger_skips_already_onboarded_filesystem( + self, + ): + def dummy_get_filesystem(filesystem_name): + return {} + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: dummy_get_filesystem(filesystem_name), + ) + + def dummy_onboard_efs_filesystem(request: OnboardEFSFileSystemRequest): + if ( + request.filesystem_name == DUMMY_EFS_1 + or request.filesystem_name == DUMMY_EFS_1_DEDUP + ): + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_ALREADY_ONBOARDED, + message="dummy_message", + ) + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "onboard_efs_filesystem", + dummy_onboard_efs_filesystem, + ) + + merger = FileSystemsClusterSettingTableMerger() + record_deltas, success = merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert success + assert len(record_deltas) == 1 + + def test_filesystem_cluster_settings_table_merger_skips_not_accessible_filesystem( + self, + ): + def dummy_get_filesystem(filesystem_name): + return {} + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "get_filesystem", + lambda filesystem_name: 
dummy_get_filesystem(filesystem_name), + ) + + def dummy_onboard_efs_filesystem(request: OnboardEFSFileSystemRequest): + if ( + request.filesystem_name == DUMMY_EFS_1 + or request.filesystem_name == DUMMY_EFS_1_DEDUP + ): + raise exceptions.soca_exception( + error_code=errorcodes.FILESYSTEM_NOT_IN_VPC, + message="dummy_message", + ) + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "onboard_efs_filesystem", + dummy_onboard_efs_filesystem, + ) + + merger = FileSystemsClusterSettingTableMerger() + record_deltas, success = merger.merge( + self.context, + [], + DUMMY_DEDUP_ID, + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert success + assert len(record_deltas) == 1 + + def test_filesystem_cluster_settings_table_merger_rollback_succeed(self): + offboard_filesystem_called = False + + def dummy_offboard_filesystem(request: OffboardFileSystemRequest): + nonlocal offboard_filesystem_called + offboard_filesystem_called = True + + assert request.filesystem_name == DUMMY_EFS_1 + + self.monkeypatch.setattr( + self.context.shared_filesystem, + "offboard_filesystem", + lambda filesystem_name: dummy_offboard_filesystem(filesystem_name), + ) + + delta_records = [ + MergedRecordDelta( + original_record={}, + snapshot_record={}, + resolved_record={DUMMY_EFS_1: {}}, + action_performed=MergedRecordActionType.CREATE, + ), + ] + + merger = FileSystemsClusterSettingTableMerger() + merger.rollback( + self.context, + delta_records, + ApplySnapshotObservabilityHelper( + self.context.logger("filesystem_cluster_settings_table_merger") + ), + ) + + assert offboard_filesystem_called diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_permission_profiles_table_merger.py b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_permission_profiles_table_merger.py new file mode 100644 index 0000000..41dfcb9 --- /dev/null +++ 
b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_permission_profiles_table_merger.py @@ -0,0 +1,263 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import unittest + +import ideaclustermanager.app.snapshots.helpers.db_utils as db_utils +import pytest +from _pytest.monkeypatch import MonkeyPatch +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.permission_profiles_table_merger import ( + PermissionProfilesTableMerger, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) +from ideasdk.utils.utils import Utils + +from ideadatamodel import VirtualDesktopPermissionProfile, errorcodes, exceptions +from ideadatamodel.api.api_model import ApiAuthorization, ApiAuthorizationType + + +@pytest.fixture(scope="class") +def monkeypatch_for_class(request): + request.cls.monkeypatch = MonkeyPatch() + + +@pytest.fixture(scope="class") +def context_for_class(request, context): + request.cls.context = context + + +@pytest.mark.usefixtures("monkeypatch_for_class") +@pytest.mark.usefixtures("context_for_class") +class TestPermissionProfilesTableMerger(unittest.TestCase): + def setUp(self): + self.monkeypatch.setattr( + self.context.token_service, "decode_token", lambda token: {} + ) + self.monkeypatch.setattr( + 
self.context.api_authorization_service, + "get_authorization", + lambda decoded_token: ApiAuthorization(type=ApiAuthorizationType.USER), + ) + + def test_permission_profiles_table_resolver_merge_new_permission_profile_succeed( + self, + ): + create_permission_profile_called = False + + def _create_permission_profile_mock(permission_profile): + nonlocal create_permission_profile_called + create_permission_profile_called = True + assert permission_profile.profile_id == "test_permission_profile" + return permission_profile + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_permission_profile", + _create_permission_profile_mock, + ) + + def _get_permission_profile_mock(_profile_id): + raise exceptions.SocaException(errorcodes.INVALID_PARAMS) + + self.monkeypatch.setattr( + self.context.vdc_client, + "get_permission_profile", + _get_permission_profile_mock, + ) + + table_data_to_merge = [ + {db_utils.PERMISSION_PROFILE_DB_HASH_KEY: "test_permission_profile"}, + ] + + resolver = PermissionProfilesTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("permission_profiles_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert ( + record_deltas[0].snapshot_record.get( + db_utils.PERMISSION_PROFILE_DB_HASH_KEY + ) + == f"test_permission_profile" + ) + assert ( + record_deltas[0].resolved_record.get( + db_utils.PERMISSION_PROFILE_DB_HASH_KEY + ) + == f"test_permission_profile" + ) + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + assert create_permission_profile_called + + def test_permission_profiles_table_resolver_merge_existing_permission_profile_succeed( + self, + ): + create_permission_profile_called = False + + def _create_permission_profile_mock(permission_profile): + nonlocal create_permission_profile_called + 
create_permission_profile_called = True + assert permission_profile.profile_id == "test_permission_profile_dedup_id" + return permission_profile + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_permission_profile", + _create_permission_profile_mock, + ) + self.monkeypatch.setattr( + self.context.vdc_client, + "get_permission_profile", + lambda _profile_id: VirtualDesktopPermissionProfile( + profile_id="test_permission_profile" + ), + ) + + table_data_to_merge = [ + {db_utils.PERMISSION_PROFILE_DB_HASH_KEY: "test_permission_profile"}, + ] + + resolver = PermissionProfilesTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("permission_profiles_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert ( + record_deltas[0].snapshot_record.get( + db_utils.PERMISSION_PROFILE_DB_HASH_KEY + ) + == f"test_permission_profile" + ) + assert ( + record_deltas[0].resolved_record.get( + db_utils.PERMISSION_PROFILE_DB_HASH_KEY + ) + == f"test_permission_profile_dedup_id" + ) + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + assert create_permission_profile_called + + def test_permission_profiles_table_resolver_rollback_original_data_succeed(self): + test_profile_id = f"test_permission_profile_dedup_id" + delete_permission_profile_called = False + + def _delete_permission_profile_mock(profile_id): + nonlocal delete_permission_profile_called + delete_permission_profile_called = True + assert profile_id == test_profile_id + + self.monkeypatch.setattr( + self.context.vdc_client, + "delete_permission_profile", + _delete_permission_profile_mock, + ) + + delta_records = [ + MergedRecordDelta( + snapshot_record={ + db_utils.PERMISSION_PROFILE_DB_HASH_KEY: test_profile_id + }, + resolved_record={ + db_utils.PERMISSION_PROFILE_DB_HASH_KEY: 
test_profile_id + }, + action_performed=MergedRecordActionType.CREATE, + ), + ] + + resolver = PermissionProfilesTableMerger() + resolver.rollback( + self.context, + delta_records, + ApplySnapshotObservabilityHelper( + self.context.logger("permission_profiles_table_resolver") + ), + ) + + assert delete_permission_profile_called + + def test_permission_profiles_table_resolver_ignore_unchanged_permission_profile_succeed( + self, + ): + create_permission_profile_called = False + + def _create_permission_profile_mock(permission_profile): + nonlocal create_permission_profile_called + create_permission_profile_called = True + assert permission_profile == VirtualDesktopPermissionProfile( + profile_id="test_permission_profile", + title="", + description="", + permissions=[], + created_on=Utils.to_datetime(0), + updated_on=Utils.to_datetime(0), + ) + return permission_profile + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_permission_profile", + _create_permission_profile_mock, + ) + self.monkeypatch.setattr( + self.context.vdc_client, + "get_permission_profile", + lambda _profile_id: VirtualDesktopPermissionProfile( + profile_id="test_permission_profile", + title="", + description="", + permissions=[], + created_on=Utils.to_datetime(0), + updated_on=Utils.to_datetime(0), + ), + ) + + table_data_to_merge = [ + {db_utils.PERMISSION_PROFILE_DB_HASH_KEY: "test_permission_profile"}, + ] + + resolver = PermissionProfilesTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("permission_profiles_table_resolver") + ), + ) + + assert success + assert not create_permission_profile_called + assert len(record_deltas) == 0 diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_projects_table_merger.py b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_projects_table_merger.py new file mode 100644 
index 0000000..d7d72a9 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_projects_table_merger.py @@ -0,0 +1,379 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest +from ideaclustermanager import AppContext +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.projects_table_merger import ( + ProjectsTableMerger, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) +from ideasdk.utils.utils import Utils + +from ideadatamodel import ( + GetProjectRequest, + UpdateProjectRequest, + UpdateProjectResult, + errorcodes, + exceptions, +) + + +def test_projects_table_merger_merge_new_project_succeed( + context: AppContext, monkeypatch +): + table_data_to_merge = [ + { + "project_id": "test_project_id", + "created_on": 0, + "description": "test_project", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project", + "title": "test_project", + "updated_on": 0, + "users": [], + } + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert len(record_deltas) == 1 + assert 
record_deltas[0].original_record is None + assert record_deltas[0].snapshot_record.get("name") == "test_project" + assert record_deltas[0].resolved_record.get("name") == "test_project" + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + + imported_project = context.projects.get_project( + GetProjectRequest(project_name="test_project") + ).project + assert imported_project is not None + + +def test_projects_table_merger_merge_existing_project_succeed( + context: AppContext, monkeypatch +): + context.projects.projects_dao.create_project( + { + "project_id": "test_project_id_1", + "created_on": 0, + "description": "", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_1", + "title": "test_project_1", + "updated_on": 0, + "users": [], + } + ) + table_data_to_merge = [ + { + "project_id": "test_project_id_1", + "created_on": 0, + "description": "test_project_1", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_1", + "title": "test_project_1", + "updated_on": 0, + "users": [], + }, + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert record_deltas[0].snapshot_record.get("name") == "test_project_1" + assert record_deltas[0].resolved_record.get("name") == "test_project_1_dedup_id" + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + + project = context.projects.get_project( + GetProjectRequest(project_name="test_project_1_dedup_id") + ).project + assert project is not None + + +def test_projects_table_merger_rollback_original_data_succeed( + context: AppContext, monkeypatch +): + context.projects.projects_dao.create_project( + { + "project_id": "test_project_id_2", + 
"created_on": 0, + "description": "test_project_2", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_2", + "title": "test_project_2", + "updated_on": 0, + "users": [], + } + ) + + resolver = ProjectsTableMerger() + delta_records = [ + MergedRecordDelta( + snapshot_record={ + "project_id": "test_project_id_2", + "created_on": 0, + "description": "test_project_2", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_2", + "title": "test_project_2", + "updated_on": 0, + "users": [], + }, + resolved_record={ + "project_id": "test_project_id_2", + "created_on": 0, + "description": "test_project_2", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_2", + "title": "test_project_2", + "updated_on": 0, + "users": [], + }, + action_performed=MergedRecordActionType.CREATE, + ), + ] + resolver.rollback( + context, + delta_records, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + + with pytest.raises(exceptions.SocaException) as exc_info: + _project = context.projects.get_project( + GetProjectRequest(project_name="test_project_2") + ).project + assert exc_info.value.error_code == errorcodes.PROJECT_NOT_FOUND + + +def test_users_table_merger_ignore_existing_project_succeed( + context: AppContext, monkeypatch +): + context.projects.projects_dao.create_project( + { + "project_id": "test_project_id_3", + "created_on": 0, + "description": "test_project_3", + "enable_budgets": False, + "enabled": False, + "ldap_groups": [], + "name": "test_project_3", + "title": "test_project_3", + "updated_on": 0, + "users": [], + } + ) + project = context.projects.get_project( + GetProjectRequest(project_name=f"test_project_3") + ).project + created_on = Utils.to_milliseconds(project.created_on) + updated_on = Utils.to_milliseconds(project.updated_on) + project = context.projects.projects_dao.convert_to_db(project) + project["created_on"] = 
created_on + project["updated_on"] = updated_on + table_data_to_merge = [project] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert len(record_deltas) == 0 + + +def test_projects_table_merger_ignore_nonexistent_group_succeed( + context: AppContext, monkeypatch +): + table_data_to_merge = [ + { + "project_id": "test_project_id_4", + "created_on": 0, + "description": "test_project_4", + "enable_budgets": False, + "enabled": True, + "ldap_groups": ["group_1"], + "name": "test_project_4", + "title": "test_project_4", + "updated_on": 0, + "users": [], + } + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert len(record_deltas) == 1 + assert not record_deltas[0].resolved_record.get("ldap_groups") + + +def test_projects_table_merger_ignore_nonexistent_user_succeed( + context: AppContext, monkeypatch +): + table_data_to_merge = [ + { + "project_id": "test_project_id_5", + "created_on": 0, + "description": "test_project_5", + "enable_budgets": False, + "enabled": True, + "ldap_groups": [], + "name": "test_project_5", + "title": "test_project_5", + "updated_on": 0, + "users": ["user1"], + } + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert len(record_deltas) == 1 + assert not record_deltas[0].resolved_record.get("users") + + +def test_projects_table_merger_add_groups_and_users_to_project_succeed( + context: AppContext, monkeypatch +): + monkeypatch.setattr(context.accounts, "get_group", lambda 
group_name: None) + + project_groups_updated_called = False + + def _task_manager_send_mock( + task_name, payload, message_group_id=None, message_dedupe_id=None + ): + if task_name == "projects.project-groups-updated": + nonlocal project_groups_updated_called + project_groups_updated_called = True + + assert payload["groups_added"] == ["group_1"] + assert payload["users_added"] == ["user1"] + + monkeypatch.setattr( + context.projects.task_manager, + "send", + _task_manager_send_mock, + ) + table_data_to_merge = [ + { + "project_id": "test_project_id_6", + "created_on": 0, + "description": "test_project_6", + "enable_budgets": False, + "enabled": True, + "ldap_groups": ["group_1"], + "name": "test_project_6", + "title": "test_project_6", + "updated_on": 0, + "users": ["user1"], + } + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + assert project_groups_updated_called + + +def test_projects_table_merger_ignore_nonexistent_budget_succeed( + context: AppContext, monkeypatch +): + def _budgets_get_budget(_budget_name): + raise Exception() + + monkeypatch.setattr( + context.aws_util(), + "budgets_get_budget", + _budgets_get_budget, + ) + table_data_to_merge = [ + { + "project_id": "test_project_id_7", + "created_on": 0, + "description": "test_project_7", + "enable_budgets": True, + "budget_name": "budget_name", + "enabled": True, + "ldap_groups": [], + "name": "test_project_7", + "title": "test_project_7", + "updated_on": 0, + "users": [], + } + ] + + resolver = ProjectsTableMerger() + record_deltas, success = resolver.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("projects_table_resolver")), + ) + assert success + + project = context.projects.get_project( + GetProjectRequest(project_name=f"test_project_7") + ).project + 
+ assert not project.enable_budgets diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_software_stacks_table_merger.py b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_software_stacks_table_merger.py new file mode 100644 index 0000000..5639998 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_software_stacks_table_merger.py @@ -0,0 +1,395 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +import unittest + +import ideaclustermanager.app.snapshots.helpers.db_utils as db_utils +import pytest +from _pytest.monkeypatch import MonkeyPatch +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.software_stacks_table_merger import ( + SoftwareStacksTableMerger, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) + +from ideadatamodel import ( + Project, + SocaMemory, + SocaMemoryUnit, + VirtualDesktopArchitecture, + VirtualDesktopBaseOS, + VirtualDesktopGPU, + VirtualDesktopSoftwareStack, + errorcodes, + exceptions, +) +from ideadatamodel.api.api_model import ApiAuthorization, ApiAuthorizationType +from ideadatamodel.snapshots.snapshot_model import TableName + + +@pytest.fixture(scope="class") +def monkeypatch_for_class(request): + request.cls.monkeypatch = MonkeyPatch() + + +@pytest.fixture(scope="class") +def context_for_class(request, context): + request.cls.context = context + + +@pytest.mark.usefixtures("monkeypatch_for_class") +@pytest.mark.usefixtures("context_for_class") +class TestSoftwareStacksTableMerger(unittest.TestCase): + def setUp(self): + self.monkeypatch.setattr( + self.context.token_service, "decode_token", lambda token: {} + ) + self.monkeypatch.setattr( + self.context.api_authorization_service, + "get_authorization", + lambda decoded_token: ApiAuthorization(type=ApiAuthorizationType.USER), + ) + + def test_software_stacks_table_resolver_merge_new_software_stack_succeed( + self, + ): + create_software_stack_called = False + + def _create_software_stack_mock(software_stack): + nonlocal create_software_stack_called + create_software_stack_called = True + assert software_stack.name == "test_software_stack" + return software_stack + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_software_stack", + 
_create_software_stack_mock, + ) + self.monkeypatch.setattr( + self.context.vdc_client, + "get_software_stacks_by_name", + lambda _stack_name: [], + ) + + table_data_to_merge = [ + { + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: "amazonlinux2", + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + }, + ] + + resolver = SoftwareStacksTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert ( + record_deltas[0].snapshot_record.get(db_utils.SOFTWARE_STACK_DB_NAME_KEY) + == "test_software_stack" + ) + assert ( + record_deltas[0].resolved_record.get(db_utils.SOFTWARE_STACK_DB_NAME_KEY) + == "test_software_stack" + ) + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + assert create_software_stack_called + + def test_software_stacks_table_resolver_merge_existing_software_stack_succeed( + self, + ): + create_software_stack_called = False + + def _create_software_stack_mock(software_stack): + nonlocal create_software_stack_called + create_software_stack_called = True + assert software_stack.name == "test_software_stack_dedup_id" + return software_stack + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_software_stack", + _create_software_stack_mock, + ) + self.monkeypatch.setattr( + self.context.vdc_client, + "get_software_stacks_by_name", + lambda _stack_name: [ + VirtualDesktopSoftwareStack(name="test_software_stack"), + ], + ) + + 
table_data_to_merge = [ + { + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: "amazonlinux2", + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + }, + ] + + resolver = SoftwareStacksTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert ( + record_deltas[0].snapshot_record.get(db_utils.SOFTWARE_STACK_DB_NAME_KEY) + == "test_software_stack" + ) + assert ( + record_deltas[0].resolved_record.get(db_utils.SOFTWARE_STACK_DB_NAME_KEY) + == "test_software_stack_dedup_id" + ) + assert record_deltas[0].action_performed == MergedRecordActionType.CREATE + assert create_software_stack_called + + def test_software_stacks_table_resolver_rollback_original_data_succeed(self): + test_stack_id = "test_stack_id" + test_base_os = "amazonlinux2" + delete_software_stack_called = False + + def _delete_software_stack_mock(software_stack: VirtualDesktopSoftwareStack): + nonlocal delete_software_stack_called + delete_software_stack_called = True + assert software_stack.stack_id == test_stack_id + assert software_stack.base_os == test_base_os + + self.monkeypatch.setattr( + self.context.vdc_client, + "delete_software_stack", + _delete_software_stack_mock, + ) + + delta_records = [ + MergedRecordDelta( + snapshot_record={ + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_STACK_ID_KEY: test_stack_id, + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: test_base_os, + 
db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + }, + resolved_record={ + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_STACK_ID_KEY: test_stack_id, + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: test_base_os, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + }, + action_performed=MergedRecordActionType.CREATE, + ), + ] + + resolver = SoftwareStacksTableMerger() + resolver.rollback( + self.context, + delta_records, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert delete_software_stack_called + + def test_software_stacks_table_resolver_resolve_project_id_succeed(self): + def _create_software_stack_mock(software_stack): + return software_stack + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_software_stack", + _create_software_stack_mock, + ) + table_data_to_merge = [ + { + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: "amazonlinux2", + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + db_utils.SOFTWARE_STACK_DB_PROJECTS_KEY: ["snapshot_project_id"], + }, + ] + + resolver = 
SoftwareStacksTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + { + TableName.PROJECTS_TABLE_NAME: [ + MergedRecordDelta( + snapshot_record={"project_id": "snapshot_project_id"}, + resolved_record={"project_id": "resolved_project_id"}, + ) + ] + }, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 1 + assert record_deltas[0].original_record is None + assert len(record_deltas[0].resolved_record.get("projects")) == 1 + assert record_deltas[0].resolved_record["projects"][0] == "resolved_project_id" + + def test_software_stacks_table_resolver_ignore_unchanged_software_stack_succeed( + self, + ): + create_software_stack_called = False + + def _create_software_stack_mock(software_stack): + nonlocal create_software_stack_called + create_software_stack_called = True + return software_stack + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_software_stack", + _create_software_stack_mock, + ) + self.monkeypatch.setattr( + self.context.vdc_client, + "get_software_stacks_by_name", + lambda _stack_name: [ + VirtualDesktopSoftwareStack( + name="test_software_stack", + base_os=VirtualDesktopBaseOS.AMAZON_LINUX2, + min_storage=SocaMemory(value=10, unit=SocaMemoryUnit.GB), + min_ram=SocaMemory(value=10, unit=SocaMemoryUnit.GB), + architecture=VirtualDesktopArchitecture.X86_64, + gpu=VirtualDesktopGPU.NO_GPU, + projects=[Project(project_id="project_id")], + ), + ], + ) + + table_data_to_merge = [ + { + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: "amazonlinux2", + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + 
db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + db_utils.SOFTWARE_STACK_DB_PROJECTS_KEY: ["project_id"], + }, + ] + + resolver = SoftwareStacksTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 0 + assert not create_software_stack_called + + def test_software_stacks_table_resolver_ignore_software_stack_with_invalid_ami_id_succeed( + self, + ): + create_software_stack_called = False + + def _create_software_stack_mock(_software_stack): + raise exceptions.SocaException( + error_code=errorcodes.INVALID_PARAMS, + message="Invalid software_stack.ami_id", + ) + + self.monkeypatch.setattr( + self.context.vdc_client, + "create_software_stack", + _create_software_stack_mock, + ) + + table_data_to_merge = [ + { + db_utils.SOFTWARE_STACK_DB_NAME_KEY: "test_software_stack", + db_utils.SOFTWARE_STACK_DB_BASE_OS_KEY: "amazonlinux2", + db_utils.SOFTWARE_STACK_DB_MIN_RAM_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_VALUE_KEY: 10.0, + db_utils.SOFTWARE_STACK_DB_MIN_RAM_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_MIN_STORAGE_UNIT_KEY: "gb", + db_utils.SOFTWARE_STACK_DB_ARCHITECTURE_KEY: "x86_64", + db_utils.SOFTWARE_STACK_DB_GPU_KEY: "NO_GPU", + db_utils.SOFTWARE_STACK_DB_PROJECTS_KEY: ["project_id"], + }, + ] + + resolver = SoftwareStacksTableMerger() + record_deltas, success = resolver.merge( + self.context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper( + self.context.logger("software_stacks_table_resolver") + ), + ) + + assert success + assert len(record_deltas) == 0 + assert not create_software_stack_called diff --git a/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_users_table_merger.py 
b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_users_table_merger.py new file mode 100644 index 0000000..074faf1 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/apply_snapshot_merge_table/test_users_table_merger.py @@ -0,0 +1,227 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import pytest +from ideaclustermanager import AppContext +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.users_table_merger import ( + UsersTableMerger, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_observability_helper import ( + ApplySnapshotObservabilityHelper, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordActionType, + MergedRecordDelta, +) + +from ideadatamodel import constants, errorcodes, exceptions + + +def test_users_table_merger_merge_valid_data_succeed(context: AppContext): + table_data_to_merge = [ + {"username": "test_user_1", "role": constants.ADMIN_ROLE, "sudo": True}, + {"username": "test_user_2", "role": constants.USER_ROLE, "sudo": False}, + ] + context.accounts.user_dao.create_user( + { + "username": "test_user_1", + "email": "test_user_1@res.test", + "uid": 0, + "gid": 0, + "additional_groups": [], + "login_shell": "", + "home_dir": "", + "sudo": False, + "enabled": True, + "role": constants.USER_ROLE, + "is_active": True, + } + ) + context.accounts.user_dao.create_user( + { + "username": "test_user_2", + "email": 
"test_user_2@res.test", + "uid": 1, + "gid": 1, + "additional_groups": [], + "login_shell": "", + "home_dir": "", + "sudo": True, + "enabled": True, + "role": constants.ADMIN_ROLE, + "is_active": True, + } + ) + + merger = UsersTableMerger() + record_deltas, success = merger.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("users_table_merger")), + ) + assert success + assert len(record_deltas) == 2 + assert record_deltas[0].original_record.get("role") == constants.USER_ROLE + assert record_deltas[0].snapshot_record.get("role") == constants.ADMIN_ROLE + assert record_deltas[0].resolved_record is None + assert record_deltas[0].action_performed == MergedRecordActionType.UPDATE + assert record_deltas[1].original_record.get("role") == constants.ADMIN_ROLE + assert record_deltas[1].snapshot_record.get("role") == constants.USER_ROLE + assert record_deltas[1].resolved_record is None + assert record_deltas[1].action_performed == MergedRecordActionType.UPDATE + + imported_test_user_1 = context.accounts.get_user("test_user_1") + assert imported_test_user_1.role == constants.ADMIN_ROLE + assert imported_test_user_1.sudo + + imported_test_user_2 = context.accounts.get_user("test_user_2") + assert imported_test_user_2.role == constants.USER_ROLE + assert not imported_test_user_2.sudo + + +def test_users_table_merger_ignore_nonexistent_user_succeed(context: AppContext): + table_data_to_merge = [ + {"username": "test_user_3", "role": constants.ADMIN_ROLE, "sudo": True}, + ] + + merger = UsersTableMerger() + record_deltas, success = merger.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("users_table_merger")), + ) + assert success + assert len(record_deltas) == 0 + + with pytest.raises(exceptions.SocaException) as exc_info: + context.accounts.get_user("test_user_3") + assert exc_info.value.error_code == errorcodes.AUTH_USER_NOT_FOUND + + +def 
test_users_table_merger_ignore_user_without_permission_change_succeed( + context: AppContext, +): + table_data_to_merge = [ + {"username": "test_user_3", "role": constants.ADMIN_ROLE, "sudo": True}, + ] + + context.accounts.user_dao.create_user( + { + "username": "test_user_3", + "email": "test_user_1@res.test", + "uid": 2, + "gid": 2, + "additional_groups": [], + "login_shell": "", + "home_dir": "", + "sudo": True, + "enabled": True, + "role": constants.ADMIN_ROLE, + "is_active": True, + } + ) + + merger = UsersTableMerger() + record_deltas, success = merger.merge( + context, + table_data_to_merge, + "dedup_id", + {}, + ApplySnapshotObservabilityHelper(context.logger("users_table_merger")), + ) + assert success + assert len(record_deltas) == 0 + + imported_test_user_3 = context.accounts.get_user("test_user_3") + assert imported_test_user_3.role == constants.ADMIN_ROLE + assert imported_test_user_3.sudo + + +def test_users_table_merger_roll_back_original_data_succeed( + context: AppContext, monkeypatch +): + context.accounts.user_dao.create_user( + { + "username": "test_user_4", + "email": "test_user_1@res.test", + "uid": 3, + "gid": 3, + "additional_groups": [], + "login_shell": "", + "home_dir": "", + "sudo": True, + "enabled": True, + "role": constants.ADMIN_ROLE, + "is_active": True, + } + ) + context.accounts.user_dao.create_user( + { + "username": "test_user_5", + "email": "test_user_2@res.test", + "uid": 4, + "gid": 4, + "additional_groups": [], + "login_shell": "", + "home_dir": "", + "sudo": False, + "enabled": True, + "role": constants.USER_ROLE, + "is_active": True, + } + ) + + merger = UsersTableMerger() + delta_records = [ + MergedRecordDelta( + original_record={ + "username": "test_user_4", + "sudo": False, + "role": constants.USER_ROLE, + }, + snapshot_record={ + "username": "test_user_4", + "sudo": False, + "role": constants.ADMIN_ROLE, + }, + action_performed=MergedRecordActionType.UPDATE, + ), + MergedRecordDelta( + original_record={ + "username": 
"test_user_5", + "sudo": False, + "role": constants.ADMIN_ROLE, + }, + snapshot_record={ + "username": "test_user_5", + "sudo": False, + "role": constants.USER_ROLE, + }, + action_performed=MergedRecordActionType.UPDATE, + ), + ] + merger.rollback( + context, + delta_records, + ApplySnapshotObservabilityHelper(context.logger("users_table_merger")), + ) + + imported_test_user_4 = context.accounts.get_user("test_user_4") + assert imported_test_user_4.role == constants.USER_ROLE + assert not imported_test_user_4.sudo + + imported_test_user_5 = context.accounts.get_user("test_user_5") + assert imported_test_user_5.role == constants.ADMIN_ROLE + assert imported_test_user_5.sudo diff --git a/source/tests/unit/idea-cluster-manager/snapshot/test_apply_snapshot.py b/source/tests/unit/idea-cluster-manager/snapshot/test_apply_snapshot.py new file mode 100644 index 0000000..7d1f587 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/snapshot/test_apply_snapshot.py @@ -0,0 +1,538 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +import unittest +from enum import Enum +from logging import Logger +from typing import Dict, List, Tuple +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from ideaclustermanager.app.snapshots.apply_snapshot import ApplySnapshot +from ideaclustermanager.app.snapshots.apply_snapshot_data_transformation_from_version.abstract_transformation_from_res_version import ( + TransformationFromRESVersion, +) +from ideaclustermanager.app.snapshots.apply_snapshot_merge_table.merge_table import ( + MergeTable, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper import ( + get_table_keys_by_res_version, +) +from ideaclustermanager.app.snapshots.helpers.apply_snapshots_config import ( + RES_VERSION_TO_DATA_TRANSFORMATION_CLASS, + TABLE_TO_TABLE_KEYS_BY_VERSION, +) +from ideaclustermanager.app.snapshots.helpers.merged_record_utils import ( + MergedRecordDelta, +) +from ideaclustermanager.app.snapshots.snapshot_constants import ( + TABLE_EXPORT_DESCRIPTION_KEY, + VERSION_KEY, +) +from ideasdk.context import SocaContext + +from ideadatamodel import errorcodes, exceptions +from ideadatamodel.snapshots import RESVersion, Snapshot, TableKeys, TableName + + +@pytest.fixture(scope="class") +def monkeypatch_for_class(request): + request.cls.monkeypatch = MonkeyPatch() + + +@pytest.fixture(scope="class") +def context_for_class(request, context): + request.cls.context = context + + +class DummyTableName(str, Enum): + TABLE1 = "table1" + TABLE2 = "table2" + TABLE3 = "table3" + + +class DummyResVersion(str, Enum): + VERSION1 = "version1" + VERSION2 = "version2" + VERSION3 = "version3" + + +DUMMY_TABLE_TO_TABLE_KEYS_BY_VERSION = { + DummyTableName.TABLE1: { + DummyResVersion.VERSION1: TableKeys(partition_key="table1-version1") + }, + DummyTableName.TABLE2: { + DummyResVersion.VERSION1: TableKeys(partition_key="table2-version1"), + DummyResVersion.VERSION2: TableKeys(partition_key="table2-version2"), + }, + 
DummyTableName.TABLE3: { + DummyResVersion.VERSION1: TableKeys(partition_key="table3-version1"), + DummyResVersion.VERSION2: TableKeys(partition_key="table3-version2"), + DummyResVersion.VERSION3: TableKeys(partition_key="table3-version3"), + }, +} + +DUMMY_RES_VERSION_IN_TOPOLOGICAL_ORDER = [ + DummyResVersion.VERSION1, + DummyResVersion.VERSION2, + DummyResVersion.VERSION3, +] + +dummy_table_export_descriptions = { + "table1": "table1", + "table2": "table2", + "table4": "table4", +} + + +class DummyTransformationFromVersion1(TransformationFromRESVersion): + def transform_data( + self, env_data_by_table: Dict[TableName, List], logger=Logger + ) -> Dict[TableName, List]: + env_data_by_table[DummyTableName.TABLE1].append( + "Transformation from V1 applied" + ) + return env_data_by_table + + +class DummyTransformationFromVersion3(TransformationFromRESVersion): + def transform_data( + self, env_data_by_table: Dict[TableName, List], logger=Logger + ) -> Dict[TableName, List]: + env_data_by_table[DummyTableName.TABLE1].append( + "Transformation from V3 applied" + ) + return env_data_by_table + + +DUMMY_RES_VERSION_TO_DATA_TRANSFORMATION_CLASS = { + DummyResVersion.VERSION1: DummyTransformationFromVersion1, + DummyResVersion.VERSION3: DummyTransformationFromVersion3, +} + + +ROLLEDBACK = "rolledback" +MERGED = "merged" + +MERGE_TABLE_RESULT_TRACKER = { + DummyTableName.TABLE1: "", + DummyTableName.TABLE2: "", + DummyTableName.TABLE3: "", +} + +DUMMY_TABLES_IN_MERGE_DEPENDENCY_ORDER = [ + DummyTableName.TABLE1, + DummyTableName.TABLE2, + DummyTableName.TABLE3, +] + + +class MergeTableTable1(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> Tuple[Dict, bool]: + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] = MERGED + return {}, True + + def rollback(self, context: SocaContext, merge_delta: 
Dict, logger: Logger): + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] = ROLLEDBACK + + +class MergeTableTable2(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> Tuple[Dict, bool]: + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] = MERGED + return {}, True + + def rollback(self, context: SocaContext, merge_delta: Dict, logger: Logger): + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] = ROLLEDBACK + + +class MergeTableTable2MergeFail(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> Tuple[Dict, bool]: + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] = MERGED + return {}, False + + def rollback(self, context: SocaContext, merge_delta: Dict, logger: Logger): + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] = ROLLEDBACK + + +class MergeTableTable2RollbackFail(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> Tuple[Dict, bool]: + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] = MERGED + return {}, True + + def rollback(self, context: SocaContext, merge_delta: Dict, logger: Logger): + raise exceptions.SocaException( + error_code=errorcodes.GENERAL_ERROR, message="Fake error" + ) + + +class MergeTableTable3(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> 
Tuple[Dict, bool]: + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] = MERGED + return {}, True + + def rollback(self, context: SocaContext, merge_delta: Dict, logger: Logger): + global MERGE_TABLE_RESULT_TRACKER + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] = ROLLEDBACK + + +class MergeTableTable3MergeFail(MergeTable): + def merge( + self, + context: SocaContext, + table_data_to_merge: List, + dedup_id: str, + merged_record_deltas: Dict[TableName, List[MergedRecordDelta]], + logger=Logger, + ) -> Tuple[Dict, bool]: + return {}, False + + def rollback(self, context: SocaContext, merge_delta: Dict, logger: Logger): + global MERGE_TABLE_RESULT_TRACKER + + MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] = ROLLEDBACK + + +@pytest.mark.usefixtures("monkeypatch_for_class") +@pytest.mark.usefixtures("context_for_class") +class ApplySnapshotsTest(unittest.TestCase): + def setUp(self) -> None: + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.apply_snapshot.TableName", DummyTableName + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper.TableName", + DummyTableName, + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper.RESVersion", + DummyResVersion, + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.apply_snapshot.RES_VERSION_IN_TOPOLOGICAL_ORDER", + DUMMY_RES_VERSION_IN_TOPOLOGICAL_ORDER, + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper.RES_VERSION_IN_TOPOLOGICAL_ORDER", + DUMMY_RES_VERSION_IN_TOPOLOGICAL_ORDER, + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper.TABLE_TO_TABLE_KEYS_BY_VERSION", + DUMMY_TABLE_TO_TABLE_KEYS_BY_VERSION, + ) + self.monkeypatch.setattr( + 
"ideaclustermanager.app.snapshots.apply_snapshot.RES_VERSION_TO_DATA_TRANSFORMATION_CLASS", + DUMMY_RES_VERSION_TO_DATA_TRANSFORMATION_CLASS, + ) + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.apply_snapshot.TABLES_IN_MERGE_DEPENDENCY_ORDER", + DUMMY_TABLES_IN_MERGE_DEPENDENCY_ORDER, + ) + self.monkeypatch.setattr( + ApplySnapshot, + "fetch_snapshot_metadata", + lambda x: { + VERSION_KEY: DummyResVersion.VERSION1, + TABLE_EXPORT_DESCRIPTION_KEY: dummy_table_export_descriptions, + }, + ) + + self.monkeypatch.setattr(ApplySnapshot, "apply_snapshot_main", lambda x: None) + + self.monkeypatch.setattr( + self.context.snapshots.apply_snapshot_dao, + "create", + lambda x, created_on: x, + ) + self.monkeypatch.setattr( + self.context.snapshots.apply_snapshot_dao, + "update_status", + MagicMock(), + ) + + global MERGE_TABLE_RESULT_TRACKER + MERGE_TABLE_RESULT_TRACKER = { + DummyTableName.TABLE1: "", + DummyTableName.TABLE2: "", + DummyTableName.TABLE3: "", + } + + def test_get_table_keys_by_res_version_function(self): + response = get_table_keys_by_res_version( + [DummyTableName.TABLE1, DummyTableName.TABLE2, DummyTableName.TABLE3], + DummyResVersion.VERSION3, + ) + + assert ( + response[DummyTableName.TABLE1] + == DUMMY_TABLE_TO_TABLE_KEYS_BY_VERSION[DummyTableName.TABLE1][ + DummyResVersion.VERSION1 + ] + ) + assert ( + response[DummyTableName.TABLE2] + == DUMMY_TABLE_TO_TABLE_KEYS_BY_VERSION[DummyTableName.TABLE2][ + DummyResVersion.VERSION2 + ] + ) + assert ( + response[DummyTableName.TABLE3] + == DUMMY_TABLE_TO_TABLE_KEYS_BY_VERSION[DummyTableName.TABLE3][ + DummyResVersion.VERSION3 + ] + ) + + def test_get_list_of_tables_to_be_imported_function(self): + obj = ApplySnapshot( + snapshot=Snapshot( + s3_bucket_name="some_bucket_name", snapshot_path="valid/path" + ), + context=self.context, + apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao, + ) + obj.initialize() + + response = obj.get_list_of_tables_to_be_imported() + + assert len(response) 
== 2 + + assert DummyTableName.TABLE1 in response + + assert DummyTableName.TABLE2 in response + + assert DummyTableName.TABLE3 not in response + + # table4 is not identified by the ApplySnapshot process thus is not a key in DummyTableName + assert "table4" not in response + + def test_key_error_exception_handling_apply_snapshot(self): + self.monkeypatch.setattr( + ApplySnapshot, + "fetch_snapshot_metadata", + lambda x: { + "wrong_version_key": DummyResVersion.VERSION2, + "table_export_description": {}, + }, + ) + + with pytest.raises(KeyError) as exc_info: + obj = ApplySnapshot( + snapshot=Snapshot( + s3_bucket_name="some_bucket_name", snapshot_path="some_path" + ), + apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao, + context=self.context, + ) + obj.initialize() + assert "version" in exc_info.value.args[0] + + def test_get_table_keys_by_res_version_exception_handling_apply_snapshot(self): + dummy_table_to_table_keys_by_version = { + DummyTableName.TABLE1: {}, + DummyTableName.TABLE2: {}, + DummyTableName.TABLE3: {}, + } + + self.monkeypatch.setattr( + "ideaclustermanager.app.snapshots.helpers.apply_snapshot_version_control_helper.TABLE_TO_TABLE_KEYS_BY_VERSION", + dummy_table_to_table_keys_by_version, + ) + + with pytest.raises(Exception) as exc_info: + obj = ApplySnapshot( + snapshot=Snapshot( + s3_bucket_name="some_bucket_name", snapshot_path="some_path" + ), + apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao, + context=self.context, + ) + obj.initialize() + + assert ( + "Could not fetch partition_key and sort_key for" in exc_info.value.args[0] + ) + + def test_apply_data_transformation_function(self): + obj = ApplySnapshot( + snapshot=Snapshot( + s3_bucket_name="some_bucket_name", snapshot_path="valid/path" + ), + context=self.context, + apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao, + ) + obj.initialize() + + data = {DummyTableName.TABLE1: [], DummyTableName.TABLE2: []} + + response = 
obj.apply_data_transformations(data=data)
+
+        table_1_response = response["table1"]
+
+        assert len(table_1_response) == 2
+
+        assert "Transformation from V1 applied" in table_1_response
+
+        assert "Transformation from V3 applied" in table_1_response
+
+        assert "Transformation from V2 applied" not in table_1_response
+
+    def test_apply_snapshot_merge_function_succeeds(self):
+        global MERGE_TABLE_RESULT_TRACKER
+
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] == ""
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] == ""
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] == ""
+
+        DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_SUCCESS_TEST = {
+            DummyTableName.TABLE1: MergeTableTable1,
+            DummyTableName.TABLE2: MergeTableTable2,
+            DummyTableName.TABLE3: MergeTableTable3,
+        }
+
+        self.monkeypatch.setattr(
+            "ideaclustermanager.app.snapshots.apply_snapshot.TABLE_TO_MERGE_LOGIC_CLASS",
+            DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_SUCCESS_TEST,
+        )
+
+        obj = ApplySnapshot(
+            snapshot=Snapshot(
+                s3_bucket_name="some_bucket_name", snapshot_path="valid/path"
+            ),
+            context=self.context,
+            apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao,
+        )
+        obj.initialize()
+
+        obj.merge_transformed_data_for_all_tables({})
+
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] == MERGED
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] == MERGED
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] == MERGED
+
+    def test_apply_snapshot_merge_function_handles_merge_failure(self):
+        DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_FAIL_TEST = {
+            DummyTableName.TABLE1: MergeTableTable1,
+            DummyTableName.TABLE2: MergeTableTable2MergeFail,
+            DummyTableName.TABLE3: MergeTableTable3,
+        }
+
+        self.monkeypatch.setattr(
+            "ideaclustermanager.app.snapshots.apply_snapshot.TABLE_TO_MERGE_LOGIC_CLASS",
+            DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_FAIL_TEST,
+        )
+
+        obj = ApplySnapshot(
+            snapshot=Snapshot(
+                s3_bucket_name="some_bucket_name",
snapshot_path="valid/path"
+            ),
+            context=self.context,
+            apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao,
+        )
+        obj.initialize()
+
+        obj.merge_transformed_data_for_all_tables({})
+
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] == ROLLEDBACK
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] == ROLLEDBACK
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] == ""
+
+    def test_apply_snapshot_merge_function_handles_rollback_failure(self):
+        DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_FAIL_TEST = {
+            DummyTableName.TABLE1: MergeTableTable1,
+            DummyTableName.TABLE2: MergeTableTable2RollbackFail,
+            DummyTableName.TABLE3: MergeTableTable3MergeFail,
+        }
+
+        self.monkeypatch.setattr(
+            "ideaclustermanager.app.snapshots.apply_snapshot.TABLE_TO_MERGE_LOGIC_CLASS",
+            DUMMY_TABLE_TO_MERGE_LOGIC_CLASS_MERGE_FAIL_TEST,
+        )
+
+        obj = ApplySnapshot(
+            snapshot=Snapshot(
+                s3_bucket_name="some_bucket_name", snapshot_path="valid/path"
+            ),
+            context=self.context,
+            apply_snapshot_dao=self.context.snapshots.apply_snapshot_dao,
+        )
+        obj.initialize()
+
+        obj.merge_transformed_data_for_all_tables({})
+
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE1] == MERGED
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE2] == MERGED
+        assert MERGE_TABLE_RESULT_TRACKER[DummyTableName.TABLE3] == ROLLEDBACK
+
+
+def test_every_table_has_a_corresponding_entry_in_table_to_table_keys_by_version():
+    assert len(TableName) == len(TABLE_TO_TABLE_KEYS_BY_VERSION)
+
+
+def test_every_data_transformation_class_inherits_the_abstract_class():
+    for key in RES_VERSION_TO_DATA_TRANSFORMATION_CLASS:
+        assert key in RESVersion
+
+        value = RES_VERSION_TO_DATA_TRANSFORMATION_CLASS[key]
+
+        assert value is not None
+
+        assert TransformationFromRESVersion in value.__bases__
diff --git a/source/tests/unit/idea-cluster-manager/test_snapshots.py b/source/tests/unit/idea-cluster-manager/snapshot/test_snapshots.py
similarity index 84%
rename from
source/tests/unit/idea-cluster-manager/test_snapshots.py
rename to source/tests/unit/idea-cluster-manager/snapshot/test_snapshots.py
index 14dc717..dbafe74 100644
--- a/source/tests/unit/idea-cluster-manager/test_snapshots.py
+++ b/source/tests/unit/idea-cluster-manager/snapshot/test_snapshots.py
@@ -298,3 +298,51 @@ def test_list_snapshots_updates_status_to_completed_when_pending_export_complete
     assert result.listing is not None
     assert len(result.listing) == 1
     assert result.listing[0].status == SnapshotStatus.COMPLETED
+
+
+def test_apply_snapshot_with_missing_bucket_name_should_fail(context: AppContext):
+    """
+    apply snapshot with missing bucket name should fail
+    """
+    with pytest.raises(exceptions.SocaException) as exc_info:
+        context.snapshots.apply_snapshot(snapshot=Snapshot(s3_bucket_name=""))
+    assert exc_info.value.error_code == errorcodes.INVALID_PARAMS
+    assert "s3_bucket_name is required" in exc_info.value.message
+
+
+def test_apply_snapshot_with_invalid_bucket_name_should_fail(context: AppContext):
+    """
+    apply snapshot with an invalid bucket name that doesn't match the required regex pattern should fail
+    """
+    with pytest.raises(exceptions.SocaException) as exc_info:
+        context.snapshots.apply_snapshot(
+            snapshot=Snapshot(s3_bucket_name="invalid@bucket_name")
+        )
+    assert exc_info.value.error_code == errorcodes.INVALID_PARAMS
+    assert "s3_bucket_name must match regex" in exc_info.value.message
+
+
+def test_apply_snapshot_with_missing_snapshot_path_should_fail(context: AppContext):
+    """
+    apply snapshot with missing snapshot path should fail
+    """
+    with pytest.raises(exceptions.SocaException) as exc_info:
+        context.snapshots.apply_snapshot(
+            snapshot=Snapshot(s3_bucket_name="some_bucket_name", snapshot_path="")
+        )
+    assert exc_info.value.error_code == errorcodes.INVALID_PARAMS
+    assert "snapshot_path is required" in exc_info.value.message
+
+
+def test_apply_snapshot_with_invalid_snapshot_path_should_fail(context: AppContext):
+    """
+    apply snapshot with an invalid snapshot path that doesn't match the required regex pattern should fail
+    """
+    with pytest.raises(exceptions.SocaException) as exc_info:
+        context.snapshots.apply_snapshot(
+            snapshot=Snapshot(
+                s3_bucket_name="some_bucket_name", snapshot_path="invalid\\path"
+            )
+        )
+    assert exc_info.value.error_code == errorcodes.INVALID_PARAMS
+    assert "snapshot_path must match regex" in exc_info.value.message
diff --git a/source/tests/unit/idea-cluster-manager/test_accounts.py b/source/tests/unit/idea-cluster-manager/test_accounts.py
index f2a146c..6143067 100644
--- a/source/tests/unit/idea-cluster-manager/test_accounts.py
+++ b/source/tests/unit/idea-cluster-manager/test_accounts.py
@@ -36,7 +36,9 @@
     exceptions,
 )
 from ideadatamodel.auth import (
+    AuthResult,
     InitiateAuthRequest,
+    InitiateAuthResult,
     RespondToAuthChallengeRequest,
     RespondToAuthChallengeResult,
 )
@@ -117,56 +119,49 @@ def test_accounts_create_user_invalid_email_should_fail(context: AppContext):
     assert "invalid email:" in exc_info.value.message
-def test_accounts_create_user_with_verified_email_missing_password_should_fail(
+def
test_accounts_create_user_with_verified_email_missing_uid_should_fail( context: AppContext, ): """ - create valid account with email verified and no password + create valid account with email verified and no uid """ with pytest.raises(exceptions.SocaException) as exc_info: context.accounts.create_user( user=User(username="mockuser1", email="mockuser1@example.com"), email_verified=True, ) - assert exc_info.value.error_code == errorcodes.INVALID_PARAMS - assert "Password is required" in exc_info.value.message + assert exc_info.value.error_code == errorcodes.UID_AND_GID_NOT_FOUND + assert ( + "Unable to retrieve UID and GID for User: mockuser1" in exc_info.value.message + ) -def test_accounts_create_user_with_verified_email_invalid_password_should_fail( +def test_accounts_create_user_with_verified_email_missing_gid_should_fail( context: AppContext, ): """ - Create a valid account with email verified and invalid password - """ - - # Cover all the ways a password will be considered invalid. - too_long_password = Utils.generate_password(257, 2, 2, 2, 2) - invalid_passwords = [ - "2short", - too_long_password, - "No_Numbers", - "no_upper_case_0", - "NO_LOWER_CASE_0", - "NoSpecialCharacters0", - ] - - for invalid_password in invalid_passwords: - with pytest.raises(exceptions.SocaException) as exc_info: - context.accounts.create_user( - user=User( - username="mockuser1", - email="mockuser1@example.com", - password=invalid_password, - ), - email_verified=True, - ) - assert exc_info.value.error_code == errorcodes.INVALID_PARAMS + Create valid account with email verified and no gid + """ + + with pytest.raises(exceptions.SocaException) as exc_info: + context.accounts.create_user( + user=User(username="mockuser1", email="mockuser1@example.com", uid=1234), + email_verified=True, + ) + assert exc_info.value.error_code == errorcodes.UID_AND_GID_NOT_FOUND + assert ( + "Unable to retrieve UID and GID for User: mockuser1" in exc_info.value.message + ) -def 
test_accounts_crud_create_user(context: AppContext):
+def test_accounts_crud_create_user(context: AppContext, monkeypatch):
     """
     create user
     """
+    monkeypatch.setattr(
+        context.accounts.sssd, "get_uid_and_gid_for_user", lambda x: (1000, 1000)
+    )
+
     created_user = context.accounts.create_user(
         user=User(
             username="accounts_user1",
@@ -204,6 +199,37 @@ def test_accounts_crud_create_user(context: AppContext):
     AccountsTestContext.crud_user = user
+
+
+def test_accounts_crud_create_user_ad_uid_and_gid(context: AppContext, monkeypatch):
+    """
+    create user with AD provided uid and gid
+    """
+    monkeypatch.setattr(context.accounts.sssd, "ldap_id_mapping", lambda x: "False")
+    monkeypatch.setattr(
+        context.accounts.sssd, "get_uid_and_gid_for_user", lambda x: (100, 100)
+    )
+    created_user = context.accounts.create_user(
+        user=User(
+            username="accounts_user2",
+            email="accounts_user2@example.com",
+            password="MockPassword_123!%",
+            uid=100,
+            gid=100,
+            login_shell="/bin/bash",
+            home_dir="home/account_user2",
+            additional_groups=[],
+        ),
+        email_verified=True,
+    )
+
+    assert created_user.username is not None
+    assert created_user.email is not None
+
+    user = context.accounts.get_user(username=created_user.username)
+
+    assert user.uid == 100
+    assert user.gid == 100
+
+
 def test_accounts_crud_get_user(context: AppContext):
     """
     get user
@@ -219,6 +245,34 @@
     assert user.gid == crud_user.gid
+
+
+def test_accounts_crud_get_user_by_email(context: AppContext):
+    """
+    get user by email
+    """
+    assert AccountsTestContext.crud_user is not None
+    crud_user = AccountsTestContext.crud_user
+
+    user = context.accounts.get_user_by_email(email=crud_user.email)
+
+    assert user is not None
+    assert user.username == crud_user.username
+    assert user.email == crud_user.email
+
+
 def test_accounts_crud_modify_user(context: AppContext):
     """
     modify user
@@ -363,10 +417,35 @@ def test_accounts_create_internal_group(context: AppContext):
     )
     assert group is not None
+    assert group.gid is None
     assert group.name == "dummy-internal-group"
     assert group.type == constants.GROUP_TYPE_INTERNAL
+
+
+def test_accounts_create_internal_group_ad_gid(context: AppContext, monkeypatch):
+    """
+    internal group with AD provided gid
+    """
+    monkeypatch.setattr(context.accounts.sssd, "ldap_id_mapping", lambda: "False")
+    group = context.accounts.create_group(
+        Group(
+            title="Dummy Internal Group2",
+            name="dummy-internal-group-2",
+            ds_name="dummy-internal-group-2",
+            gid=100,
+            group_type=constants.GROUP_TYPE_PROJECT,
+            type=constants.GROUP_TYPE_INTERNAL,
+        )
+    )
+    returned_group = context.accounts.get_group(group.name)
+
+    assert returned_group is not None
+
+    assert returned_group.gid == 100
+    assert returned_group.name == "dummy-internal-group-2"
+    assert returned_group.type == constants.GROUP_TYPE_INTERNAL
+
+
 def test_accounts_create_group_with_no_type_param_must_attemp_an_external_group_creation(
     context: AppContext,
 ):
@@ -521,15 +600,18 @@ def test_accounts_create_group_in_ad_none_should_fail(context: AppContext, monke
     )
-def test_accounts_create_group_in_ad_gid_none_should_fail(
+def test_accounts_create_group_in_ad_gid_none_should_fail_ad_gid(
     context: AppContext, monkeypatch
 ):
     """
     group with gid none
     """
     monkeypatch.setattr(context.ldap_client, "is_readonly", lambda: True)
-    response = {"gid": None}
+    monkeypatch.setattr(context.accounts.sssd, "get_gid_for_group", lambda x: None)
+
+    response = {"gid": None, "name": "pqr"}
     monkeypatch.setattr(context.ldap_client, "get_group", lambda x: response)
+    monkeypatch.setattr(context.accounts.sssd,
"ldap_id_mapping", lambda: "False") with pytest.raises(exceptions.SocaException) as exc_info: group = context.accounts.create_group( Group( @@ -540,11 +622,8 @@ def test_accounts_create_group_in_ad_gid_none_should_fail( type="external", ) ) - assert exc_info.value.error_code == errorcodes.INVALID_PARAMS - assert ( - f"Group id is not found in Directory Service: activedirectory" - in exc_info.value.message - ) + assert exc_info.value.error_code == errorcodes.GID_NOT_FOUND + assert f"Unable to retrieve GID for Group: pqr" in exc_info.value.message def test_accounts_create_group_in_ad_with_valid_but_name_none_should_fail( @@ -1529,38 +1608,54 @@ def test_accounts_initiate_auth_user_password_auth_flow( """ initiate auth normal workflow """ - try: - context.accounts.initiate_auth( - request=InitiateAuthRequest( - auth_flow="USER_PASSWORD_AUTH", - username="xyz", - password="abc123", - ), - ) - except Exception as e: - print("failed to user_password_auth in initiate_auth {e}") - monkeypatch.setattr(Utils, "is_empty", lambda x: bool(x == "xyz")) + username = "clusteradmin" + decoded_token = { + "username": username, + } + mock_auth_result = InitiateAuthResult( + auth=AuthResult(access_token=""), + ) + + monkeypatch.setattr( + context.user_pool, "initiate_username_password_auth", lambda a: mock_auth_result + ) + monkeypatch.setattr( + context.token_service, "decode_token", lambda token: decoded_token + ) + result = context.accounts.initiate_auth( + request=InitiateAuthRequest( + auth_flow="USER_PASSWORD_AUTH", + cognito_username=username, + password="abc123", + ), + ) + assert result.role == "admin" + assert result.db_username == "clusteradmin" + + +def test_accounts_initiate_auth_user_password_auth_flow_fail(context: AppContext): + username = "clusteradmin" with pytest.raises(exceptions.SocaException) as exc_info: context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="USER_PASSWORD_AUTH", - username="xyz", + cognito_username="random", password="abc123", 
), ) - assert exc_info.value.error_code == errorcodes.INVALID_PARAMS - assert "username is required" in exc_info.value.message - monkeypatch.setattr(Utils, "is_empty", lambda x: bool(x == "abc123")) + assert exc_info.value.error_code == errorcodes.UNAUTHORIZED_ACCESS + assert "Unauthorized Access" in exc_info.value.message + with pytest.raises(exceptions.SocaException) as exc_info: context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="USER_PASSWORD_AUTH", - username="xyz", - password="abc123", + cognito_username=username, + password="", ), ) assert exc_info.value.error_code == errorcodes.INVALID_PARAMS - assert "password is required" in exc_info.value.message + assert "Invalid params: password is required" in exc_info.value.message def test_accounts_initiate_auth_refresh_token_auth_flow( @@ -1573,18 +1668,18 @@ def test_accounts_initiate_auth_refresh_token_auth_flow( context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="REFRESH_TOKEN_AUTH", - username="xyz", + cognito_username="clusteradmin", refresh_token="abc123", ), ) except Exception as e: print("failed to refresh_token_auth in initiate_auth {e}") - monkeypatch.setattr(Utils, "is_empty", lambda x: bool(x == "xyz")) + monkeypatch.setattr(Utils, "is_empty", lambda x: bool(x == "clusteradmin")) with pytest.raises(exceptions.SocaException) as exc_info: context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="REFRESH_TOKEN_AUTH", - username="xyz", + cognito_username="clusteradmin", refresh_token="abc123", ), ) @@ -1595,7 +1690,7 @@ def test_accounts_initiate_auth_refresh_token_auth_flow( context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="REFRESH_TOKEN_AUTH", - username="xyz", + cognito_username="clusteradmin", refresh_token="abc123", ), ) @@ -1611,7 +1706,7 @@ def test_accounts_initiate_sso_auth_flow(context: AppContext, monkeypatch): context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_AUTH", - username="xyz", + 
cognito_username="xyz", authorization_code="abc123", ), ) @@ -1622,7 +1717,7 @@ def test_accounts_initiate_sso_auth_flow(context: AppContext, monkeypatch): context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_AUTH", - username="xyz", + cognito_username="xyz", authorization_code="def123", ), ) @@ -1634,7 +1729,7 @@ def test_accounts_initiate_sso_auth_flow(context: AppContext, monkeypatch): context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_AUTH", - username="xyz", + cognito_username="xyz", authorization_code="def123", ), ) @@ -1646,7 +1741,7 @@ def test_accounts_initiate_sso_auth_flow(context: AppContext, monkeypatch): context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_AUTH", - username="xyz", + cognito_username="xyz", authorization_code="def123", ), ) @@ -1656,7 +1751,7 @@ def test_accounts_initiate_sso_auth_flow(context: AppContext, monkeypatch): context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_AUTH", - username="xyz", + cognito_username="xyz", authorization_code="def123", ), ) @@ -1673,7 +1768,7 @@ def test_accounts_initiate_sso_refresh_auth_flow(context: AppContext, monkeypatc context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_REFRESH_TOKEN_AUTH", - username="xyz1", + cognito_username="xyz1", authorization_code="abc456", ), ) @@ -1684,7 +1779,7 @@ def test_accounts_initiate_sso_refresh_auth_flow(context: AppContext, monkeypatc context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_REFRESH_TOKEN_AUTH", - username="xyz1", + cognito_username="xyz1", refresh_token="abc456", ), ) @@ -1695,7 +1790,7 @@ def test_accounts_initiate_sso_refresh_auth_flow(context: AppContext, monkeypatc context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_REFRESH_TOKEN_AUTH", - username="xyz1", + cognito_username="xyz1", refresh_token="abc456", ), ) @@ -1706,7 +1801,7 @@ def 
test_accounts_initiate_sso_refresh_auth_flow(context: AppContext, monkeypatc context.accounts.initiate_auth( request=InitiateAuthRequest( auth_flow="SSO_REFRESH_TOKEN_AUTH", - username="xyz1", + cognito_username="xyz1", refresh_token="abc456", ), ) diff --git a/source/tests/unit/idea-cluster-manager/test_api_authorization_service.py b/source/tests/unit/idea-cluster-manager/test_api_authorization_service.py new file mode 100644 index 0000000..68aa307 --- /dev/null +++ b/source/tests/unit/idea-cluster-manager/test_api_authorization_service.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
+ +""" +Test Cases for ApiAuthorizationService +""" + +import unittest +from typing import Optional + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from ideaclustermanager import AppContext +from ideaclustermanager.app.auth.api_authorization_service import ( + ClusterManagerApiAuthorizationService, +) + +from ideadatamodel import User, constants +from ideadatamodel.api.api_model import ApiAuthorizationType + + +@pytest.fixture(scope="class") +def monkeypatch_for_class(request): + request.cls.monkeypatch = MonkeyPatch() + + +@pytest.fixture(scope="class") +def context_for_class(request, context): + request.cls.context = context + + +@pytest.mark.usefixtures("monkeypatch_for_class") +@pytest.mark.usefixtures("context_for_class") +class ApiAuthorizationTests(unittest.TestCase): + def setUp(self): + self.admin_username = "dummy_admin" + self.admin_email = self.admin_username + "@email.com" + self.monkeypatch.setattr( + self.context.accounts.sssd, "ldap_id_mapping", lambda x: "False" + ) + self.monkeypatch.setattr( + self.context.accounts.sssd, + "get_uid_and_gid_for_user", + lambda x: (1000, 1000), + ) + self.context.accounts.create_user( + user=User( + username=self.admin_username, + email=self.admin_email, + role=constants.ADMIN_ROLE, + ), + ) + + self.user_username = "dummy_user" + self.user_email = self.user_username + "@email.com" + self.context.accounts.create_user( + user=User( + username=self.user_username, + email=self.user_email, + role=constants.USER_ROLE, + ), + ) + + def test_api_auth_service_get_authorization_app_passes(self): + token = { + "username": "", + } + api_authorization = self.context.api_authorization_service.get_authorization( + token + ) + assert api_authorization.type == ApiAuthorizationType.APP + assert not api_authorization.username + + def test_api_auth_service_get_authorization_admin_passes(self): + token = {"username": self.admin_username} + api_authorization = self.context.api_authorization_service.get_authorization( + 
token + ) + assert api_authorization.type == ApiAuthorizationType.ADMINISTRATOR + assert api_authorization.username == self.admin_username + + def test_api_auth_service_get_authorization_user_passes(self): + token = {"username": self.user_username} + api_authorization = self.context.api_authorization_service.get_authorization( + token + ) + assert api_authorization.type == ApiAuthorizationType.USER + assert api_authorization.username == self.user_username + + def tearDown(self): + self.context.accounts.delete_user(self.admin_username) + self.context.accounts.delete_user(self.user_username) diff --git a/source/tests/unit/idea-cluster-manager/test_projects.py b/source/tests/unit/idea-cluster-manager/test_projects.py index c3b6b65..e1cbb3d 100644 --- a/source/tests/unit/idea-cluster-manager/test_projects.py +++ b/source/tests/unit/idea-cluster-manager/test_projects.py @@ -16,6 +16,7 @@ from typing import List, Optional import pytest +from _pytest.monkeypatch import MonkeyPatch from ideaclustermanager import AppContext from ideaclustermanager.app.accounts.account_tasks import GroupMembershipUpdatedTask from ideaclustermanager.app.projects.project_tasks import ( @@ -25,6 +26,7 @@ from ideadatamodel import ( CreateProjectRequest, + DeleteProjectRequest, DisableProjectRequest, EnableProjectRequest, GetProjectRequest, @@ -35,6 +37,8 @@ SocaAnyPayload, SocaKeyValue, User, + VirtualDesktopSession, + VirtualDesktopSoftwareStack, constants, errorcodes, exceptions, @@ -99,7 +103,19 @@ def remove_member(context: AppContext, user: User, project: Project): @pytest.fixture(scope="module") -def membership(context: AppContext): +def monkey_session(request): + mp = MonkeyPatch() + yield mp + mp.undo() + + +@pytest.fixture(scope="module") +def membership(context: AppContext, monkey_session): + monkey_session.setattr( + context.accounts.sssd, "get_uid_and_gid_for_user", lambda x: (1000, 1000) + ) + monkey_session.setattr(context.accounts.sssd, "get_gid_for_group", lambda x: 1000) + def 
create_group(group_name: str) -> Group: group = context.accounts.create_group( Group( @@ -112,13 +128,16 @@ def create_group(group_name: str) -> Group: return group def create_project( - project_name: str, project_title: str, group_names: List[str] + project_name: str, project_title: str, group_names: List[str], users=[] ) -> Project: # create project result = context.projects.create_project( CreateProjectRequest( project=Project( - name=project_name, title=project_title, ldap_groups=group_names + name=project_name, + title=project_title, + ldap_groups=group_names, + users=users, ) ) ) @@ -135,6 +154,7 @@ def create_project( assert result.project is not None assert result.project.enabled is True + ProjectsTestContext.project = result.project return result.project @@ -356,6 +376,53 @@ def test_projects_crud_list_projects(context): assert found is not None +def test_projects_crud_delete_project(context, membership): + assert ProjectsTestContext.crud_project is not None + add_member(context, membership.user_1, ProjectsTestContext.crud_project) + enable_project(context, ProjectsTestContext.crud_project) + assert is_memberof(context, membership.user_1, ProjectsTestContext.crud_project) + + # First try to delete a project that is still in use by sessions or software stacks + context.projects.vdc_client.sessions = [VirtualDesktopSession()] + context.projects.vdc_client.software_stacks = [VirtualDesktopSoftwareStack()] + with pytest.raises(exceptions.SocaException) as excinfo: + context.projects.delete_project( + DeleteProjectRequest(project_id=ProjectsTestContext.crud_project.project_id) + ) + assert excinfo.value.error_code == "GENERAL_ERROR" + + result = context.projects.get_project( + GetProjectRequest(project_id=ProjectsTestContext.crud_project.project_id) + ) + assert result.project is not None + for group in ProjectsTestContext.crud_project.ldap_groups: + projects = context.projects.user_projects_dao.get_projects_by_group_name(group) + assert len(projects) > 0 + + # Next 
delete a project which is not in use by any session or software stack + context.projects.vdc_client.sessions = [] + context.projects.vdc_client.software_stacks = [] + context.projects.delete_project( + DeleteProjectRequest(project_id=ProjectsTestContext.crud_project.project_id) + ) + with pytest.raises(exceptions.SocaException) as excinfo: + context.projects.get_project( + GetProjectRequest(project_id=ProjectsTestContext.crud_project.project_id) + ) + assert excinfo.value.error_code == "PROJECT_NOT_FOUND" + + # Project should have been removed from the project-groups table + for group in ProjectsTestContext.crud_project.ldap_groups: + projects = context.projects.user_projects_dao.get_projects_by_group_name(group) + assert len(projects) == 0 + + # Project should have been removed from the user-project table + projects = context.projects.user_projects_dao.get_projects_by_username( + membership.user_1.username + ) + assert len(projects) == 0 + + def test_projects_membership_setup(context, membership): """ check if membership setup data is valid and tests are starting with a clean slate. 
@@ -394,6 +461,7 @@ def test_projects_membership_member_removed(context, membership, monkeypatch): """ monkeypatch.setattr(context.ldap_client, "is_readonly", lambda: True) + remove_member(context, membership.user_1, membership.project_a) assert is_memberof(context, membership.user_1, membership.project_a) is False @@ -444,6 +512,7 @@ def test_projects_membership_multiple_projects(context, membership, monkeypatch) """ monkeypatch.setattr(context.ldap_client, "is_readonly", lambda: True) + # pre-requisites clear_memberships(context, membership, membership.user_1) clear_memberships(context, membership, membership.user_2) diff --git a/source/tests/unit/idea-cluster-manager/test_single_sign_on_helper.py b/source/tests/unit/idea-cluster-manager/test_single_sign_on_helper.py index b88faba..5a562f2 100644 --- a/source/tests/unit/idea-cluster-manager/test_single_sign_on_helper.py +++ b/source/tests/unit/idea-cluster-manager/test_single_sign_on_helper.py @@ -87,6 +87,22 @@ def test_create_or_update_identity_provider_missing_provider_name(context: AppCo assert "provider_name is required" in exc_info.value.message + +@patch.object(saml_payload, "provider_name", "Cognito") +def test_create_or_update_identity_provider_invalid_provider_name(context: AppContext): + """ + create_or_update_identity_provider with an invalid (reserved) provider name + """ + with pytest.raises(exceptions.SocaException) as exc_info: + context.accounts.single_sign_on_helper.create_or_update_identity_provider( + saml_payload + ) + + assert exc_info.value.error_code == errorcodes.INVALID_PARAMS + assert ( + constants.SSO_SOURCE_PROVIDER_NAME_ERROR_MESSAGE == exc_info.value.message + ) + + @patch.object(saml_payload, "provider_type", "") def test_create_or_update_identity_provider_missing_provider_type(context: AppContext): """ diff --git a/source/tests/unit/idea-sdk/test_api_invocation_context.py index c447425..5afed95 100644 --- 
a/source/tests/unit/idea-sdk/test_api_invocation_context.py +++ b/source/tests/unit/idea-sdk/test_api_invocation_context.py @@ -13,14 +13,20 @@ Test Cases for ApiInvocationContext """ -from typing import Dict +import os +from typing import Dict, Mapping, Optional import pytest from ideasdk.api import ApiInvocationContext -from ideasdk.auth import TokenService +from ideasdk.auth.api_authorization_service_base import ApiAuthorizationServiceBase from ideasdk.context import SocaContext, SocaContextOptions +from ideasdk.protocols import TokenServiceProtocol from ideasdk.utils import GroupNameHelper, Utils +from ideatestutils.api_authorization_service.mock_api_authorization_service import ( + MockApiAuthorizationService, +) from ideatestutils.config.mock_config import MockConfig +from ideatestutils.token_service.mock_token_service import MockTokenService from ideadatamodel import constants, errorcodes, exceptions @@ -34,9 +40,11 @@ def context(): def build_invocation_context( context: SocaContext, payload: Dict, + http_headers: Optional[Mapping] = None, invocation_source: str = None, token: Dict = None, - token_service: TokenService = None, + token_service: TokenServiceProtocol = None, + api_authorization_service: ApiAuthorizationServiceBase = None, ) -> ApiInvocationContext: if Utils.is_empty(invocation_source): invocation_source = constants.API_INVOCATION_SOURCE_HTTP @@ -44,11 +52,13 @@ def build_invocation_context( return ApiInvocationContext( context=context, request=payload, + http_headers=http_headers, invocation_source=invocation_source, group_name_helper=GroupNameHelper(context=context), logger=context.logger(), token=token, token_service=token_service, + api_authorization_service=api_authorization_service, ) @@ -83,3 +93,38 @@ def test_api_invocation_context_validate_request_invalid(context): with pytest.raises(exceptions.SocaException) as exc_info: api_context.validate_request() assert exc_info.value.error_code == errorcodes.INVALID_PARAMS + + +def 
test_api_invocation_context_get_authorization_in_test_mode_valid(context): + """ + get API authorization in test mode - valid API authorization + """ + test_username = "username" + mock_token_service = MockTokenService() + mock_api_authorization_service = MockApiAuthorizationService() + api_context = build_invocation_context( + context=context, + payload={ + "header": {"namespace": "Hello.World", "request_id": Utils.uuid()}, + "payload": {}, + }, + http_headers={ + "X_RES_TEST_USERNAME": test_username, + }, + token_service=mock_token_service, + api_authorization_service=mock_api_authorization_service, + ) + os.environ["RES_TEST_MODE"] = "True" + try: + authorization = api_context.get_authorization() + finally: + os.environ.pop("RES_TEST_MODE") + decoded_token = { + "username": test_username, + } + assert authorization.username == test_username + assert ( + authorization.type + == mock_api_authorization_service.get_authorization(decoded_token).type + ) + assert authorization.invocation_source == constants.API_INVOCATION_SOURCE_HTTP diff --git a/source/tests/unit/idea-sdk/test_server.py b/source/tests/unit/idea-sdk/test_server.py index 7b54de2..d19eb3f 100644 --- a/source/tests/unit/idea-sdk/test_server.py +++ b/source/tests/unit/idea-sdk/test_server.py @@ -15,7 +15,11 @@ from ideasdk.api import ApiInvocationContext, BaseAPI from ideasdk.client import SocaClient, SocaClientOptions from ideasdk.context import SocaContext, SocaContextOptions -from ideasdk.protocols import ApiInvokerProtocol, TokenServiceProtocol +from ideasdk.protocols import ( + ApiAuthorizationServiceProtocol, + ApiInvokerProtocol, + TokenServiceProtocol, +) from ideasdk.server import SocaServer, SocaServerOptions from ideasdk.utils import Utils from ideatestutils import MockConfig @@ -134,6 +138,11 @@ def __init__(self): def get_token_service(self) -> Optional[TokenServiceProtocol]: return None + def get_api_authorization_service( + self, + ) -> Optional[ApiAuthorizationServiceProtocol]: + return None + + def 
invoke(self, context: ApiInvocationContext): namespace = context.namespace if namespace.startswith("Calculator."): diff --git a/source/tests/unit/infrastructure/install/conftest.py b/source/tests/unit/infrastructure/install/conftest.py index 4ebed04..04b21b2 100644 --- a/source/tests/unit/infrastructure/install/conftest.py +++ b/source/tests/unit/infrastructure/install/conftest.py @@ -8,7 +8,7 @@ from aws_cdk.assertions import Template from idea.infrastructure.install.installer import Installer -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters from idea.infrastructure.install.stack import InstallStack @@ -52,7 +52,7 @@ def stack( return InstallStack( app, "IDEAInstallStack", - parameters=Parameters(cluster_name=cluster_name), + parameters=RESParameters(cluster_name=cluster_name), registry_name=registry_name, dynamodb_kms_key_alias=dynamodb_kms_key_alias, env=env, diff --git a/source/tests/unit/infrastructure/install/parameters/test_parameters.py b/source/tests/unit/infrastructure/install/parameters/test_parameters.py index 245ce60..6eb2086 100644 --- a/source/tests/unit/infrastructure/install/parameters/test_parameters.py +++ b/source/tests/unit/infrastructure/install/parameters/test_parameters.py @@ -8,12 +8,12 @@ from idea.infrastructure.install.parameters.base import Attributes, Base from idea.infrastructure.install.parameters.common import CommonKey from idea.infrastructure.install.parameters.directoryservice import DirectoryServiceKey -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters from idea.infrastructure.install.stack import InstallStack def test_parameters_are_generated(cluster_name: str) -> None: - parameters = Parameters(cluster_name=cluster_name) + parameters = RESParameters(cluster_name=cluster_name) env = aws_cdk.Environment(account="111111111111", 
region="us-east-1") app = aws_cdk.App() InstallStack( @@ -26,26 +26,31 @@ def test_parameters_are_generated(cluster_name: str) -> None: assert parameters._generated assert parameters.cluster_name == cluster_name assert parameters.get(CommonKey.CLUSTER_NAME).default == cluster_name - assert parameters.get(CommonKey.PRIVATE_SUBNETS).default is None + assert parameters.get(CommonKey.INFRASTRUCTURE_HOST_SUBNETS).default is None def test_parameters_can_be_passed_via_context() -> None: - parameters = Parameters(cluster_name="foo", private_subnets=["a", "b"]) + parameters = RESParameters( + cluster_name="foo", infrastructure_host_subnets=["a", "b"] + ) stack = aws_cdk.Stack() for key, value in parameters.to_context().items(): stack.node.set_context(key, value) - context_params = Parameters.from_context(stack) + context_params = RESParameters.from_context(stack) assert context_params.cluster_name == parameters.cluster_name - assert context_params.private_subnets == parameters.private_subnets + assert ( + context_params.infrastructure_host_subnets + == parameters.infrastructure_host_subnets + ) def test_parameter_list_default_set_correctly() -> None: - parameters = Parameters(private_subnets=["a", "b"]) + parameters = RESParameters(infrastructure_host_subnets=["a", "b"]) parameters.generate(aws_cdk.Stack()) - assert parameters.get(CommonKey.PRIVATE_SUBNETS).default == "a,b" + assert parameters.get(CommonKey.INFRASTRUCTURE_HOST_SUBNETS).default == "a,b" def test_fields_only_includes_base_parameters() -> None: @@ -68,7 +73,7 @@ def test_parameters_only_generates_cfn_parameters_for_base_parameter_attributes( parameters = template.find_parameters(logical_id="*") cfn_keys = set(parameters.keys()) - defined_keys = set(attributes.id.value for _, attributes in Parameters._fields()) + defined_keys = set(attributes.id.value for _, attributes in RESParameters._fields()) assert cfn_keys == defined_keys diff --git a/source/tests/unit/infrastructure/install/test_permissions.py 
b/source/tests/unit/infrastructure/install/test_permissions.py index e3f7fb6..03efd5a 100644 --- a/source/tests/unit/infrastructure/install/test_permissions.py +++ b/source/tests/unit/infrastructure/install/test_permissions.py @@ -42,17 +42,15 @@ def assume_role_policy_document() -> dict[str, Any]: } -@pytest.mark.parametrize("role", ("InstallRole", "UpdateRole", "DeleteRole")) def test_role_creation( stack: InstallStack, template: Template, assume_role_policy_document: dict[str, Any], - role: str, ) -> None: util.assert_resource_name_has_correct_type_and_props( stack, template, - resources=["Installer", "Tasks", "Permissions", role], + resources=["Installer", "Tasks", "Permissions", "PipelineRole"], cfn_type="AWS::IAM::Role", props={ "Properties": { @@ -60,7 +58,7 @@ "RoleName": { "Fn::Join": [ "", - ["Admin-", {"Ref": CommonKey.CLUSTER_NAME}, f"-{role}"], + ["Admin-", {"Ref": CommonKey.CLUSTER_NAME}, "-PipelineRole"], ] }, }, @@ -69,14 +67,20 @@ util.assert_resource_name_has_correct_type_and_props( stack, template, - resources=["Installer", "Tasks", "Permissions", role, "DefaultPolicy"], + resources=[ + "Installer", + "Tasks", + "Permissions", + "PipelineRole", + "DefaultPolicy", + ], cfn_type="AWS::IAM::Policy", props={ "Properties": { "Roles": [ { "Ref": util.get_logical_id( - stack, ["Installer", "Tasks", "Permissions", role] + stack, ["Installer", "Tasks", "Permissions", "PipelineRole"] ) } ] diff --git a/source/tests/unit/infrastructure/install/test_tasks.py b/source/tests/unit/infrastructure/install/test_tasks.py index eaeb2d3..db67e1c 100644 --- a/source/tests/unit/infrastructure/install/test_tasks.py +++ b/source/tests/unit/infrastructure/install/test_tasks.py @@ -68,7 +68,7 @@ def test_create_task_creation( "TaskRoleArn": { "Fn::GetAtt": [ util.get_logical_id( - stack, ["Installer", "Tasks", "Permissions", "InstallRole"] + stack, ["Installer", "Tasks", "Permissions", "PipelineRole"] ), "Arn", ] @@ -104,7 
+104,7 @@ def test_update_task_creation( "TaskRoleArn": { "Fn::GetAtt": [ util.get_logical_id( - stack, ["Installer", "Tasks", "Permissions", "UpdateRole"] + stack, ["Installer", "Tasks", "Permissions", "PipelineRole"] ), "Arn", ] @@ -140,7 +140,7 @@ def test_delete_task_creation( "TaskRoleArn": { "Fn::GetAtt": [ util.get_logical_id( - stack, ["Installer", "Tasks", "Permissions", "DeleteRole"] + stack, ["Installer", "Tasks", "Permissions", "PipelineRole"] ), "Arn", ] diff --git a/source/tests/unit/pipeline/test_pipeline_stack.py b/source/tests/unit/pipeline/test_pipeline_stack.py index cb26b0a..d8adf0b 100644 --- a/source/tests/unit/pipeline/test_pipeline_stack.py +++ b/source/tests/unit/pipeline/test_pipeline_stack.py @@ -5,7 +5,7 @@ import pytest from aws_cdk import assertions -from idea.infrastructure.install.parameters.parameters import Parameters +from idea.infrastructure.install.parameters.parameters import RESParameters from idea.infrastructure.install.stack import PUBLIC_REGISTRY_NAME from idea.pipeline.stack import DeployStage, PipelineStack @@ -26,13 +26,13 @@ def test_pipeline_created(template: assertions.Template) -> None: def test_registry_name_set_correctly_from_context() -> None: # No context should be public registry name - stage = DeployStage(aws_cdk.App(), "Stage", Parameters()) + stage = DeployStage(aws_cdk.App(), "Stage", RESParameters()) assert stage.install_stack.registry_name == PUBLIC_REGISTRY_NAME # context should override app = aws_cdk.App() app.node.set_context("registry_name", "foo") - stage = DeployStage(app, "DeployStage", Parameters()) + stage = DeployStage(app, "DeployStage", RESParameters()) assert stage.install_stack.registry_name == "foo" diff --git a/tasks/admin.py b/tasks/admin.py index 80352e8..b7e647f 100644 --- a/tasks/admin.py +++ b/tasks/admin.py @@ -54,12 +54,6 @@ def test_iam_policies(c, cluster_name, aws_region, aws_profile=None): 'custom-resource-get-user-pool-client-secret.yml' ] }, - { - 'module_name': 
constants.MODULE_ANALYTICS, - 'templates': [ - 'analytics-stream-processing-lambda.yml' - ] - }, { 'module_name': constants.MODULE_DIRECTORYSERVICE, 'templates': [ diff --git a/tasks/idea.py b/tasks/idea.py index 5baf0bd..f9800f9 100644 --- a/tasks/idea.py +++ b/tasks/idea.py @@ -136,11 +136,15 @@ def requirements_dir(self) -> str: @property def administrator_project_dir(self) -> str: return os.path.join(self.project_source_dir, 'idea-administrator') - + @property def administrator_integ_tests_dir(self) -> str: return os.path.join(self.administrator_project_dir, 'src', 'ideaadministrator', 'integration_tests') + @property + def end_to_end_integ_tests_dir(self) -> str: + return os.path.join(self.project_root_dir, 'source', 'tests', 'integration', 'tests') + @property def deployment_ecr_dir(self) -> str: return os.path.join(self.project_deployment_dir, 'ecr') diff --git a/tasks/integ_tests.py b/tasks/integ_tests.py index 39a5352..30c7c92 100644 --- a/tasks/integ_tests.py +++ b/tasks/integ_tests.py @@ -13,10 +13,10 @@ import tasks.idea as idea from invoke import task, Context import os -from typing import List +from typing import List, Optional -def _run_integ_tests( +def _run_component_integ_tests( c: Context, component_name: str, component_src: str, @@ -31,7 +31,44 @@ def _run_integ_tests( """ Currently requires ~/.aws/credentials file to be setup in order to run due to boto being unable to use ~/.aws/config. 
""" - test_params = ["--module", component_name] + params.append(f"module={component_name}") + if cov_report is not None and cov_report in ( + "term", + "term-missing", + "annotate", + "html", + "xml", + "lcov", + ): + params.append(f"cov={package_name}") + params.append(f"cov-report={cov_report}") + + return _run_integ_tests( + c, + component_name, + component_tests_src, + params, + [component_src], + capture_output, + keywords, + test_file + ) + + +def _run_integ_tests( + c: Context, + test_id: str, + tests_src: str, + params: List[str], + additional_python_path: Optional[List[str]] = None, + capture_output: bool = False, + keywords=None, + test_file=None, +) -> int: + """ + Currently requires ~/.aws/credentials file to be setup in order to run due to boto being unable to use ~/.aws/config. + """ + test_params = [] if params is not None: for param in params: kv = param.split("=") @@ -43,33 +80,24 @@ def _run_integ_tests( if value is not None: test_params += [value] - idea.console.print_header_block(f"executing integ tests for: {component_name}") + idea.console.print_header_block(f"executing integ tests for: {test_id}") python_path = [ idea.props.project_root_dir, idea.props.data_model_src, idea.props.sdk_src, idea.props.test_utils_src, ] - if component_src not in python_path: - python_path.append(component_src) - if component_tests_src not in python_path: - python_path.append(component_tests_src) + if tests_src not in python_path: + python_path.append(tests_src) + if additional_python_path: + python_path = list(set(python_path + additional_python_path)) - with c.cd(component_tests_src): + with c.cd(tests_src): cmd = f'pytest -v --disable-warnings {test_file} {" ".join(test_params)}' if capture_output: cmd = f"{cmd} --capture=tee-sys" if keywords is not None: cmd = f'{cmd} -k "{keywords}"' - if cov_report is not None and cov_report in ( - "term", - "term-missing", - "annotate", - "html", - "xml", - "lcov", - ): - cmd = f"{cmd} --cov {package_name} --cov-report 
{cov_report}" idea.console.info(f"> {cmd}") try: @@ -92,7 +120,7 @@ def cluster_manager( """ run cluster-manager integ tests """ - exit_code = _run_integ_tests( + exit_code = _run_component_integ_tests( c=c, component_name="cluster-manager", component_src=idea.props.cluster_manager_src, @@ -105,3 +133,22 @@ def cluster_manager( test_file="run_integ_tests.py", ) raise SystemExit(exit_code) + +@task(iterable=["params"]) +def smoke( + c, keywords=None, params=None, capture_output=False, cov_report=None +): + # type: (Context, str, List[str], bool, str) -> None + """ + run smoke tests + """ + exit_code = _run_integ_tests( + c=c, + test_id="smoke", + tests_src=idea.props.end_to_end_integ_tests_dir, + params=params, + capture_output=capture_output, + keywords=keywords, + test_file="smoke.py", + ) + raise SystemExit(exit_code) diff --git a/tasks/web_portal.py b/tasks/web_portal.py index 7f6854d..4e82990 100644 --- a/tasks/web_portal.py +++ b/tasks/web_portal.py @@ -58,7 +58,6 @@ def add_model(module: str): add_model('virtual_desktop') add_model('cluster_settings') add_model('projects') - add_model('analytics') add_model('app') add_model('email_templates') add_model('notifications') diff --git a/tox.ini b/tox.ini index b6cb97f..a98ac96 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,7 @@ env_list = description = run lint checks skip_install = true deps = - black + black~=24.1.0 pytest isort commands = @@ -81,7 +81,7 @@ commands = [testenv:coverage] description = combine and generate code coverage report -allowlist_externals = +allowlist_externals = coverage cat rm @@ -101,7 +101,7 @@ commands = coverage run --parallel-mode -m pytest -v source/tests/unit/pipeline {posargs} coverage run --parallel-mode -m pytest -v source/tests/unit/infrastructure {posargs} coverage combine - coverage html -i + coverage html -i python source/idea/pipeline/scripts/helpers/generate_cov_report.py cat summary_report.txt @@ -139,3 +139,13 @@ deps = -rrequirements/dev.txt commands = 
integ-tests.cluster-manager: invoke integ-tests.cluster-manager {posargs} + +[testenv:integ-tests.smoke] +description = run smoke tests +skip_install = true +set_env = + LC_CTYPE=en_US.UTF-8 +deps = + -rrequirements/dev.txt +commands = + integ-tests.smoke: invoke integ-tests.smoke {posargs}
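Reviewer note: the `monkey_session` fixture added in `test_projects.py` exists because pytest's built-in `monkeypatch` fixture is function-scoped and therefore cannot back a module-scoped fixture like `membership`. A minimal, self-contained sketch of the same pattern follows; the `Sssd` class is a hypothetical stand-in for `context.accounts.sssd`, not the real RES client:

```python
import pytest
from _pytest.monkeypatch import MonkeyPatch


class Sssd:
    """Hypothetical stand-in for the real SSSD client."""

    def get_uid_and_gid_for_user(self, username):
        raise RuntimeError("would hit the real directory service")


@pytest.fixture(scope="module")
def monkey_session():
    # pytest's `monkeypatch` fixture is function-scoped, so a
    # module-scoped MonkeyPatch is constructed by hand and all of
    # its patches are reverted when the module's tests finish.
    mp = MonkeyPatch()
    yield mp
    mp.undo()


@pytest.fixture(scope="module")
def sssd(monkey_session):
    # Stub out the directory lookup for every test in the module.
    client = Sssd()
    monkey_session.setattr(
        client, "get_uid_and_gid_for_user", lambda x: (1000, 1000)
    )
    return client


def test_uid_and_gid_are_stubbed(sssd):
    assert sssd.get_uid_and_gid_for_user("alice") == (1000, 1000)
```

The same yield-then-`undo()` shape could also let the class-scoped `monkeypatch_for_class` fixture in the new `test_api_authorization_service.py` revert its patches at class teardown, which it currently never does.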
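Reviewer note: `test_api_invocation_context_get_authorization_in_test_mode_valid` toggles the `RES_TEST_MODE` environment variable around the call under test. A context manager keeps that toggle exception-safe, so one failing test cannot leak test mode into the rest of the suite. This is a sketch; the `is_test_mode` helper is a hypothetical illustration of how the flag might be read, not the SDK's actual check:

```python
import os
from contextlib import contextmanager


@contextmanager
def res_test_mode():
    """Temporarily enable the RES_TEST_MODE flag.

    The try/finally guarantees the variable is removed even if the
    body raises, so a failing assertion cannot leave test mode on.
    """
    os.environ["RES_TEST_MODE"] = "True"
    try:
        yield
    finally:
        os.environ.pop("RES_TEST_MODE", None)


def is_test_mode() -> bool:
    # Hypothetical helper mirroring how code might consult the flag.
    return os.environ.get("RES_TEST_MODE", "").lower() == "true"
```

Usage mirrors the test: `with res_test_mode(): authorization = api_context.get_authorization()` leaves the environment clean regardless of whether `get_authorization` succeeds.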