-
Notifications
You must be signed in to change notification settings - Fork 0
/
aws_utils.py
125 lines (100 loc) · 4.78 KB
/
aws_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""AWS Utils and helper."""
import asyncio
import json
import os
import time
from decimal import Decimal

import boto3
from boto3.dynamodb.types import TypeDeserializer
from botocore.exceptions import ClientError

from core import SingletonMeta
from enums import Env


class AWSUtils(metaclass=SingletonMeta):
    """Helpers for AWS Secrets Manager, local AWS credential files and DynamoDB data.

    Instances are created through ``SingletonMeta`` (from ``core``).
    """

    def __init__(self, config=None):
        # NOTE(review): the credential-file helpers read
        # ``self.config.aws_credentials_timeout``; config must be set before use.
        self.config = config

    @staticmethod
    def get_secret(secret_name: str, region_name: str = "eu-central-1") -> str:
        """Fetch the raw ``SecretString`` of a secret from AWS Secrets Manager.

        Args:
            secret_name: Name or ARN of the secret.
            region_name: AWS region hosting the secret.

        Returns:
            The ``SecretString`` field of the ``GetSecretValue`` response.

        Raises:
            botocore.exceptions.ClientError: If the Secrets Manager call fails.
        """
        session = boto3.session.Session()
        client = session.client(service_name="secretsmanager", region_name=region_name)
        try:
            response = client.get_secret_value(SecretId=secret_name)
        except ClientError:
            # Propagated unchanged; the possible error codes are listed at
            # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
            raise
        return response["SecretString"]

    @staticmethod
    def get_secret_value(secret_string: str) -> str:
        """Extract the secret value from a Secrets Manager ``SecretString``.

        A JSON object payload such as ``{"KEY": "value"}`` yields its first value
        when that value is a string. Anything else is returned unchanged. That
        covers plain-text secrets, values injected directly via env vars, and
        JSON that is not an object or whose first value is not a string.

        Raises:
            ValueError: If ``secret_string`` is empty or None.
        """
        if not secret_string:
            raise ValueError("Empty AWS Secret.")
        try:
            payload = json.loads(secret_string)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return secret_string
        if isinstance(payload, dict) and payload:
            first_value = next(iter(payload.values()))
            if isinstance(first_value, str):
                return first_value
        return secret_string

    async def ensure_credentials_file(self, home_dir: str = "/root", poll_interval: float = 5):
        """Wait until ``<home_dir>/.aws/credentials`` exists.

        The file is checked ``config.aws_credentials_timeout`` times,
        ``poll_interval`` seconds apart, and once more after the last wait.

        Raises:
            TimeoutError: If the file still does not exist after the final check.
        """
        credentials_path = f"{home_dir}/.aws/credentials"
        for _ in range(self.config.aws_credentials_timeout):
            if os.path.isfile(credentials_path):
                return
            # asyncio.sleep yields to the event loop; time.sleep would block
            # every other task while waiting.
            await asyncio.sleep(poll_interval)
        # The file may have appeared during the last wait.
        if os.path.isfile(credentials_path):
            return
        raise TimeoutError("Timeout waiting for the AWS credentials")

    @staticmethod
    def load_credentials():
        """Read the AWS key pair from ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``.

        On AWS (see ``is_aws_environment``) both values are passed through
        ``get_secret_value``. Presumably the variables then hold Secrets
        Manager payloads; confirm with the deployment setup.

        Returns:
            Tuple ``(aws_access_key_id, aws_secret_access_key)``.

        Raises:
            ValueError: If either variable is unset or empty.
        """
        aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
        if not aws_access_key_id or not aws_secret_access_key:
            raise ValueError("AWS credentials not found in environment variables.")
        if AWSUtils.is_aws_environment():
            aws_access_key_id = AWSUtils.get_secret_value(aws_access_key_id)
            aws_secret_access_key = AWSUtils.get_secret_value(aws_secret_access_key)
        return aws_access_key_id, aws_secret_access_key

    async def generate_credentials_file(self, home_dir: str = "/root", attempt=0):
        """Write ``<home_dir>/.aws/credentials`` from the environment credentials.

        While the credentials are unavailable, retry once per second as long as
        ``attempt <= config.aws_credentials_timeout``.

        Args:
            home_dir: Directory whose ``.aws`` sub-directory receives the file.
            attempt: Number of retries already consumed.

        Raises:
            ValueError: If the credentials are still missing after all retries.
        """
        timeout = self.config.aws_credentials_timeout
        while True:
            try:
                aws_access_key_id, aws_secret_access_key = self.load_credentials()
            except ValueError as err:
                # load_credentials raises when the variables are missing, so
                # the retry is driven by that exception.
                if attempt > timeout:
                    raise ValueError(
                        "AWS credentials not loaded from environment variables"
                    ) from err
                print(f"Waiting to load the AWS credentials ({attempt}/{timeout})")
                attempt += 1
                await asyncio.sleep(1)
            else:
                break

        aws_config_path = f"{home_dir}/.aws"
        os.makedirs(aws_config_path, exist_ok=True)
        # NOTE(review): the file is created with the process umask; consider
        # restricting it to 0o600 if no other user/container must read it.
        with open(f"{aws_config_path}/credentials", "w", encoding="utf-8") as f:
            f.write(
                "[default]\n"
                f"aws_access_key_id = {aws_access_key_id}\n"
                f"aws_secret_access_key = {aws_secret_access_key}\n"
            )

    @staticmethod
    def is_aws_environment() -> bool:
        """Return True when ``CLOUD_PROVIDER`` equals ``Env.AWS.value``."""
        return os.getenv("CLOUD_PROVIDER") == Env.AWS.value

    def deserialize_dynamodb_json(self, node):
        """Recursively convert DynamoDB data into JSON-serializable values.

        - Lists and dicts are walked recursively.
        - A number type descriptor such as ``{"N": "1.5"}`` becomes a float.
        - ``Decimal`` values become floats, since ``json`` cannot serialize them.
        - Everything else is returned unchanged.
        """
        if isinstance(node, list):
            return [self.deserialize_dynamodb_json(item) for item in node]
        if isinstance(node, Decimal):
            return float(node)
        if isinstance(node, dict):
            # DynamoDB tags numbers with the "N" type descriptor. Only a
            # single-key dict is treated as a descriptor, so ordinary maps are
            # walked instead.
            # NOTE(review): a user map whose only key is "N" is still ambiguous.
            if len(node) == 1 and "N" in node:
                return float(TypeDeserializer().deserialize(node))
            return {key: self.deserialize_dynamodb_json(value) for key, value in node.items()}
        return node