diff --git a/src/api-service/__app__/onefuzzlib/orm.py b/src/api-service/__app__/onefuzzlib/orm.py index 6e4257e24f..0f30da4224 100644 --- a/src/api-service/__app__/onefuzzlib/orm.py +++ b/src/api-service/__app__/onefuzzlib/orm.py @@ -40,7 +40,7 @@ from typing_extensions import Protocol from .azure.table import get_client -from .secrets import save_to_keyvault +from .secrets import delete_remote_secret_data, save_to_keyvault from .telemetry import track_event_filtered from .updates import queue_update @@ -249,6 +249,28 @@ def hide_secrets(data: B, hider: Callable[[SecretData], SecretData]) -> B: return data +# NOTE: the actual deletion must come from the `deleter` callback function +def delete_secrets(data: B, deleter: Callable[[SecretData], None]) -> None: + for field in data.__fields__: + field_data = getattr(data, field) + if isinstance(field_data, SecretData): + deleter(field_data) + elif isinstance(field_data, BaseModel): + delete_secrets(field_data, deleter) + elif isinstance(field_data, list): + for entry in field_data: + if isinstance(entry, BaseModel): + delete_secrets(entry, deleter) + elif isinstance(entry, SecretData): + deleter(entry) + elif isinstance(field_data, dict): + for value in field_data.values(): + if isinstance(value, BaseModel): + delete_secrets(value, deleter) + elif isinstance(value, SecretData): + deleter(value) + + # NOTE: if you want to include Timestamp in a model that uses ORMMixin, # it must be maintained as part of the model. class ORMMixin(ModelMixin): @@ -363,6 +385,8 @@ def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error] def delete(self) -> None: partition_key, row_key = self.get_keys() + delete_secrets(self, delete_remote_secret_data) + client = get_client() try: client.delete_entity( diff --git a/src/api-service/__app__/onefuzzlib/secrets.py b/src/api-service/__app__/onefuzzlib/secrets.py index eaba556e53..f7904b0706 100644 --- a/src/api-service/__app__/onefuzzlib/secrets.py +++ b/src/api-service/__app__/onefuzzlib/secrets.py @@ -80,3 +80,8 @@ def delete_secret(secret_url: str) -> None: (vault_url, secret_name) = parse_secret_url(secret_url) keyvault_client = get_keyvault_client(vault_url) keyvault_client.begin_delete_secret(secret_name) + + +def delete_remote_secret_data(data: SecretData) -> None: + if isinstance(data.secret, SecretAddress): + delete_secret(data.secret.url) diff --git a/src/api-service/tests/test_secrets.py b/src/api-service/tests/test_secrets.py index e484cf9ee1..9792a3c72d 100644 --- a/src/api-service/tests/test_secrets.py +++ b/src/api-service/tests/test_secrets.py @@ -6,6 +6,7 @@ import json import pathlib import unittest +from typing import Dict, List from onefuzztypes.enums import OS, ContainerType from onefuzztypes.job_templates import ( @@ -26,8 +27,9 @@ ) from onefuzztypes.primitives import Container from onefuzztypes.requests import NotificationCreate +from pydantic import BaseModel -from __app__.onefuzzlib.orm import hide_secrets +from __app__.onefuzzlib.orm import delete_secrets, hide_secrets def hider(secret_data: SecretData) -> SecretData: @@ -36,7 +38,72 @@ def hider(secret_data: SecretData) -> SecretData: return secret_data +class WithSecret(BaseModel): + a: SecretData[str] + + +class WithList(BaseModel): + a: List[WithSecret] + + +class WithDict(BaseModel): + a: Dict[str, SecretData[str]] + b: Dict[str, WithSecret] + + +class Nested(BaseModel): + a: WithSecret + b: WithDict + c: WithList + + class TestSecret(unittest.TestCase): + def test_delete(self) -> None: + self.count = 0 + + def deleter(secret_data: SecretData) -> None: + self.count += 1 + + delete_secrets(WithSecret(a=SecretData("a")), deleter) + self.assertEqual(self.count, 1) + + delete_secrets( + WithList( + a=[ + WithSecret(a=SecretData("a")), + WithSecret(a=SecretData("b")), + ] + ), + deleter, + ) + self.assertEqual(self.count, 3) + + delete_secrets( + WithDict( + a={"a": SecretData("a"), "b": SecretData("b")}, + b={ + "a": WithSecret(a=SecretData("a")), + "b": WithSecret(a=SecretData("a")), + }, + ), + deleter, + ) + self.assertEqual(self.count, 7) + + delete_secrets( + Nested( + a=WithSecret(a=SecretData("a")), + b=WithDict( + a={"a": SecretData("a")}, b={"a": WithSecret(a=SecretData("a"))} + ), + c=WithList( + a=[WithSecret(a=SecretData("a")), WithSecret(a=SecretData("b"))] + ), + ), + deleter, + ) + self.assertEqual(self.count, 12) + def test_hide(self) -> None: notification = Notification( container=Container("data"),