Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Storing secrets in azure keyvault (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
chkeita authored Jan 25, 2021
1 parent dc31ffc commit 3f2883d
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 29 deletions.
4 changes: 0 additions & 4 deletions src/api-service/__app__/notifications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

def get(req: func.HttpRequest) -> func.HttpResponse:
entries = Notification.search()
for entry in entries:
entry.config.redact()

return ok(entries)


Expand All @@ -46,7 +43,6 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(entry, context="notification delete")

entry.delete()
entry.config.redact()
return ok(entry)


Expand Down
7 changes: 7 additions & 0 deletions src/api-service/__app__/onefuzzlib/azure/creds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.subscription import SubscriptionClient
from memoization import cached
Expand Down Expand Up @@ -134,3 +136,8 @@ def get_scaleset_principal_id() -> UUID:
client = mgmt_client_factory(ResourceManagementClient)
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
return UUID(uid.properties["principalId"])


@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
4 changes: 3 additions & 1 deletion src/api-service/__app__/onefuzzlib/notifications/ado.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from onefuzztypes.models import ADOTemplate, Report
from onefuzztypes.primitives import Container

from ..secrets import get_secret_string_value
from .common import Render, fail_task


Expand Down Expand Up @@ -54,7 +55,8 @@ def __init__(
):
self.config = config
self.renderer = Render(container, filename, report)
self.client = get_ado_client(self.config.base_url, self.config.auth_token)
auth_token = get_secret_string_value(self.config.auth_token)
self.client = get_ado_client(self.config.base_url, auth_token)
self.project = self.render(self.config.project)

def render(self, template: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from github3.exceptions import GitHubException
from github3.issues import Issue
from onefuzztypes.enums import GithubIssueSearchMatch
from onefuzztypes.models import GithubIssueTemplate, Report
from onefuzztypes.models import GithubAuth, GithubIssueTemplate, Report
from onefuzztypes.primitives import Container

from ..secrets import get_secret_obj
from .common import Render, fail_task


Expand All @@ -26,9 +27,12 @@ def __init__(
):
self.config = config
self.report = report
self.gh = login(
username=config.auth.user, password=config.auth.personal_access_token
)
if isinstance(config.auth.secret, GithubAuth):
auth = config.auth.secret
else:
auth = get_secret_obj(config.auth.secret.url, GithubAuth)

self.gh = login(username=auth.user, password=auth.personal_access_token)
self.renderer = Render(container, filename, report)

def render(self, field: str) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/api-service/__app__/onefuzzlib/notifications/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from onefuzztypes.primitives import Container

from ..azure.containers import auth_download_url
from ..secrets import get_secret_string_value
from ..tasks.config import get_setup_container
from ..tasks.main import Task

Expand Down Expand Up @@ -46,7 +47,8 @@ def send_teams_webhook(
if text:
message["sections"].append({"text": text})

response = requests.post(config.url, json=message)
config_url = get_secret_string_value(config.url)
response = requests.post(config_url, json=message)
if not response.ok:
logging.error("webhook failed %s %s", response.status_code, response.content)

Expand Down
39 changes: 36 additions & 3 deletions src/api-service/__app__/onefuzzlib/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -33,11 +34,12 @@
UpdateType,
VmState,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Error, SecretData
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from typing_extensions import Protocol

from ..onefuzzlib.secrets import save_to_keyvault
from .azure.table import get_client
from .telemetry import track_event_filtered
from .updates import queue_update
Expand Down Expand Up @@ -268,18 +270,49 @@ def get_keys(self) -> Tuple[KEY, KEY]:

return (partition_key, row_key)

@classmethod
def hide_secrets(
cls,
model: BaseModel,
hider: Callable[["SecretData"], None],
visited: Set[int] = set(),
) -> None:
if id(model) in visited:
return

visited.add(id(model))
for field in model.__fields__:
field_data = getattr(model, field)
if isinstance(field_data, SecretData):
hider(field_data)
elif isinstance(field_data, List):
if len(field_data) > 0:
if not isinstance(field_data[0], BaseModel):
continue
for data in field_data:
cls.hide_secrets(data, hider, visited)
elif isinstance(field_data, dict):
for key in field_data:
if not isinstance(field_data[key], BaseModel):
continue
cls.hide_secrets(field_data[key], hider, visited)
else:
if isinstance(field_data, BaseModel):
cls.hide_secrets(field_data, hider, visited)

def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
self.__class__.hide_secrets(self, save_to_keyvault)
# TODO: migrate to an inspect.signature() model
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
for key in raw:
if not isinstance(raw[key], (str, int)):
raw[key] = json.dumps(raw[key])

# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
for field in self.__fields__:
if field not in raw:
continue
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
if self.__fields__[field].type_ == datetime:
raw[field] = getattr(self, field)

Expand Down
79 changes: 79 additions & 0 deletions src/api-service/__app__/onefuzzlib/secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


from typing import Tuple, Type, TypeVar, cast
from urllib.parse import urlparse
from uuid import uuid4

from azure.keyvault.secrets import KeyVaultSecret
from onefuzztypes.models import SecretAddress, SecretData
from pydantic import BaseModel

from .azure.creds import get_instance_name, get_keyvault_client

A = TypeVar("A", bound=BaseModel)


def save_to_keyvault(secret_data: SecretData) -> None:
if isinstance(secret_data.secret, SecretAddress):
return

secret_name = str(uuid4())
if isinstance(secret_data.secret, str):
secret_value = secret_data.secret
elif isinstance(secret_data.secret, BaseModel):
secret_value = secret_data.secret.json()
else:
raise Exception("invalid secret data")

kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
secret_data.secret = SecretAddress(url=kv.id)


def get_secret_string_value(self: SecretData[str]) -> str:
if isinstance(self.secret, SecretAddress):
secret = get_secret(self.secret.url)
return cast(str, secret.value)
else:
return self.secret


def get_keyvault_address() -> str:
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
return f"https://{get_instance_name()}-vault.vault.azure.net"


def store_in_keyvault(
keyvault_url: str, secret_name: str, secret_value: str
) -> KeyVaultSecret:
keyvault_client = get_keyvault_client(keyvault_url)
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
return kvs


def parse_secret_url(secret_url: str) -> Tuple[str, str]:
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
u = urlparse(secret_url)
vault_url = f"{u.scheme}://{u.netloc}"
secret_name = u.path.split("/")[2]
return (vault_url, secret_name)


def get_secret(secret_url: str) -> KeyVaultSecret:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
return keyvault_client.get_secret(secret_name)


def get_secret_obj(secret_url: str, model: Type[A]) -> A:
secret = get_secret(secret_url)
return model.parse_raw(secret.value)


def delete_secret(secret_url: str) -> None:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
keyvault_client.begin_delete_secret(secret_name).wait()
91 changes: 91 additions & 0 deletions src/api-service/tests/test_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
import unittest

from onefuzztypes.enums import OS, ContainerType
from onefuzztypes.job_templates import (
JobTemplate,
JobTemplateIndex,
JobTemplateNotification,
)
from onefuzztypes.models import (
JobConfig,
Notification,
NotificationConfig,
SecretAddress,
SecretData,
TeamsTemplate,
)
from onefuzztypes.primitives import Container

from __app__.onefuzzlib.orm import ORMMixin


class TestSecret(unittest.TestCase):
def test_hide(self) -> None:
def hider(secret_data: SecretData) -> None:
if not isinstance(secret_data.secret, SecretAddress):
secret_data.secret = SecretAddress(url="blah blah")

notification = Notification(
container=Container("data"),
config=TeamsTemplate(url=SecretData(secret="http://test")),
)
ORMMixin.hide_secrets(notification, hider)

if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")

def test_hide_nested_list(self) -> None:
def hider(secret_data: SecretData) -> None:
if not isinstance(secret_data.secret, SecretAddress):
secret_data.secret = SecretAddress(url="blah blah")

job_template_index = JobTemplateIndex(
name="test",
template=JobTemplate(
os=OS.linux,
job=JobConfig(name="test", build="test", project="test", duration=1),
tasks=[],
notifications=[
JobTemplateNotification(
container_type=ContainerType.unique_inputs,
notification=NotificationConfig(
config=TeamsTemplate(url=SecretData(secret="http://test"))
),
)
],
user_fields=[],
),
)
ORMMixin.hide_secrets(job_template_index, hider)
notification = job_template_index.template.notifications[0].notification
if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")

def test_read_secret(self) -> None:
json_data = """
{
"notification_id": "b52b24d1-eec6-46c9-b06a-818a997da43c",
"container": "data",
"config" : {"url": {"secret": {"url": "http://test"}}}
}
"""
data = json.loads(json_data)
notification = Notification.parse_obj(data)
self.assertIsInstance(notification.config, TeamsTemplate)
if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")
3 changes: 3 additions & 0 deletions src/cli/onefuzz/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import sys
import time
from dataclasses import asdict, is_dataclass
from enum import Enum
from typing import (
Any,
Expand Down Expand Up @@ -381,6 +382,8 @@ def serialize(data: Any) -> Any:
return str(data)
if isinstance(data, (int, str)):
return data
if is_dataclass(data):
return {serialize(a): serialize(b) for (a, b) in asdict(data).items()}

raise Exception("unknown type %s" % type(data))

Expand Down
Loading

0 comments on commit 3f2883d

Please sign in to comment.