From be8ac6b3182bd46a24fd3dfca553a48f168a2de1 Mon Sep 17 00:00:00 2001 From: Jialiang Xu <48697394+liamjxu@users.noreply.github.com> Date: Fri, 20 Dec 2024 09:56:56 -0800 Subject: [PATCH] Add Encryption for GPQA (#3216) --- helm-frontend/src/components/Instances.tsx | 16 +- helm-frontend/src/routes/Run.tsx | 49 +++++ .../services/getDisplayPredictionsByName.ts | 74 ++++++- .../src/services/getDisplayRequestsByName.ts | 75 ++++++- helm-frontend/src/services/getInstances.ts | 80 +++++++- helm-frontend/src/types/EncryptionDataMap.ts | 8 + scripts/decrypt_scenario_states.py | 191 ++++++++++++++++++ scripts/encrypt_scenario_states.py | 187 +++++++++++++++++ 8 files changed, 669 insertions(+), 11 deletions(-) create mode 100644 helm-frontend/src/types/EncryptionDataMap.ts create mode 100644 scripts/decrypt_scenario_states.py create mode 100644 scripts/encrypt_scenario_states.py diff --git a/helm-frontend/src/components/Instances.tsx b/helm-frontend/src/components/Instances.tsx index 03c5135750..c85f2e929a 100644 --- a/helm-frontend/src/components/Instances.tsx +++ b/helm-frontend/src/components/Instances.tsx @@ -23,9 +23,15 @@ interface Props { runName: string; suite: string; metricFieldMap: MetricFieldMap; + userAgreed: boolean; } -export default function Instances({ runName, suite, metricFieldMap }: Props) { +export default function Instances({ + runName, + suite, + metricFieldMap, + userAgreed, +}: Props) { const [searchParams, setSearchParams] = useSearchParams(); const [instances, setInstances] = useState([]); const [displayPredictionsMap, setDisplayPredictionsMap] = useState< @@ -43,9 +49,9 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) { const [instancesResp, displayPredictions, displayRequests] = await Promise.all([ - getInstances(runName, signal, suite), - getDisplayPredictionsByName(runName, signal, suite), - getDisplayRequestsByName(runName, signal, suite), + getInstances(runName, signal, suite, userAgreed), + getDisplayPredictionsByName(runName, signal, suite, userAgreed), + getDisplayRequestsByName(runName, signal, suite, userAgreed), ]); setInstances(instancesResp); @@ -93,7 +99,7 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) { void fetchData(); return () => controller.abort(); - }, [runName, suite]); + }, [runName, suite, userAgreed]); const pagedInstances = instances.slice( (currentInstancesPage - 1) * INSTANCES_PAGE_SIZE, diff --git a/helm-frontend/src/routes/Run.tsx b/helm-frontend/src/routes/Run.tsx index 17fdf799da..49a4eba224 100644 --- a/helm-frontend/src/routes/Run.tsx +++ b/helm-frontend/src/routes/Run.tsx @@ -37,6 +37,9 @@ export default function Run() { MetricFieldMap | undefined >({}); + const [agreeInput, setAgreeInput] = useState(""); + const [userAgreed, setUserAgreed] = useState(false); + useEffect(() => { const controller = new AbortController(); async function fetchData() { @@ -93,6 +96,16 @@ export default function Run() { return ; } + // Handler for agreement + const handleAgreement = () => { + if (agreeInput.trim() === "Yes, I agree") { + setUserAgreed(true); + } else { + setUserAgreed(false); + alert("Please type 'Yes, I agree' exactly."); + } + }; + return ( <>
@@ -178,11 +191,47 @@ export default function Run() {
+ + {activeTab === 0 && runName.includes("gpqa") && !userAgreed && ( +
+
+

+ The GPQA dataset instances are encrypted by default to comply with + the following request: +

+
+ “We ask that you do not reveal examples from this dataset in plain + text or images online, to minimize the risk of these instances being + included in foundation model training corpora.” +
+

+ If you agree to this condition, please type{" "} + "Yes, I agree" in the box below and then click{" "} + Decrypt. +

+
+ setAgreeInput(e.target.value)} + className="input input-bordered" + placeholder='Type "Yes, I agree"' + /> + +
+
+
+ )} + {activeTab === 0 ? ( ) : ( { + const decodeBase64 = (str: string) => + Uint8Array.from(atob(str), (c) => c.charCodeAt(0)); + + const cryptoKey = await window.crypto.subtle.importKey( + "raw", + decodeBase64(key), + "AES-GCM", + true, + ["decrypt"], + ); + + const combinedCiphertext = new Uint8Array([ + ...decodeBase64(ciphertext), + ...decodeBase64(tag), + ]); + + const ivArray = decodeBase64(iv); + + const decrypted = await window.crypto.subtle.decrypt( + { name: "AES-GCM", iv: ivArray }, + cryptoKey, + combinedCiphertext, + ); + + return new TextDecoder().decode(decrypted); +} + export default async function getDisplayPredictionsByName( runName: string, signal: AbortSignal, suite?: string, + userAgreed?: boolean, ): Promise { try { - const displayPrediction = await fetch( + const response = await fetch( getBenchmarkEndpoint( `/runs/${ suite || getBenchmarkSuite() @@ -16,8 +51,43 @@ export default async function getDisplayPredictionsByName( ), { signal }, ); + const displayPredictions = (await response.json()) as DisplayPrediction[]; + + if (runName.includes("gpqa") && userAgreed) { + const encryptionResponse = await fetch( + getBenchmarkEndpoint( + `/runs/${ + suite || getBenchmarkSuite() + }/${runName}/encryption_data.json`, + ), + { signal }, + ); + const encryptionData = + (await encryptionResponse.json()) as EncryptionDataMap; + + for (const prediction of displayPredictions) { + const encryptedText = prediction.predicted_text; + const encryptionDetails = encryptionData[encryptedText]; + + if (encryptionDetails) { + try { + prediction.predicted_text = await decryptField( + encryptionDetails.ciphertext, + encryptionDetails.key, + encryptionDetails.iv, + encryptionDetails.tag, + ); + } catch (error) { + console.error( + `Failed to decrypt predicted_text for instance_id: ${prediction.instance_id}`, + error, + ); + } + } + } + } - return (await displayPrediction.json()) as DisplayPrediction[]; + return displayPredictions; } catch (error) { if (error instanceof Error && error.name === "AbortError") { console.log(error); diff --git a/helm-frontend/src/services/getDisplayRequestsByName.ts b/helm-frontend/src/services/getDisplayRequestsByName.ts index a996fba47a..a5d3794ad2 100644 --- a/helm-frontend/src/services/getDisplayRequestsByName.ts +++ b/helm-frontend/src/services/getDisplayRequestsByName.ts @@ -1,14 +1,50 @@ import type DisplayRequest from "@/types/DisplayRequest"; +import { EncryptionDataMap } from "@/types/EncryptionDataMap"; import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint"; import getBenchmarkSuite from "@/utils/getBenchmarkSuite"; +// Helper function for decryption +async function decryptField( + ciphertext: string, + key: string, + iv: string, + tag: string, +): Promise { + const decodeBase64 = (str: string) => + Uint8Array.from(atob(str), (c) => c.charCodeAt(0)); + + const cryptoKey = await window.crypto.subtle.importKey( + "raw", + decodeBase64(key), + "AES-GCM", + true, + ["decrypt"], + ); + + const combinedCiphertext = new Uint8Array([ + ...decodeBase64(ciphertext), + ...decodeBase64(tag), + ]); + + const ivArray = decodeBase64(iv); + + const decrypted = await window.crypto.subtle.decrypt( + { name: "AES-GCM", iv: ivArray }, + cryptoKey, + combinedCiphertext, + ); + + return new TextDecoder().decode(decrypted); +} + export default async function getDisplayRequestsByName( runName: string, signal: AbortSignal, suite?: string, + userAgreed?: boolean, ): Promise { try { - const displayRequest = await fetch( + const response = await fetch( getBenchmarkEndpoint( `/runs/${ suite || getBenchmarkSuite() @@ -16,8 +52,43 @@ export default async function getDisplayRequestsByName( ), { signal }, ); + const displayRequests = (await response.json()) as DisplayRequest[]; + + if (runName.startsWith("gpqa") && userAgreed) { + const encryptionResponse = await fetch( + getBenchmarkEndpoint( + `/runs/${ + suite || getBenchmarkSuite() + }/${runName}/encryption_data.json`, + ), + { signal }, + ); + const encryptionData = + (await encryptionResponse.json()) as EncryptionDataMap; + + for (const request of displayRequests) { + const encryptedPrompt = request.request.prompt; + const encryptionDetails = encryptionData[encryptedPrompt]; + + if (encryptionDetails) { + try { + request.request.prompt = await decryptField( + encryptionDetails.ciphertext, + encryptionDetails.key, + encryptionDetails.iv, + encryptionDetails.tag, + ); + } catch (error) { + console.error( + `Failed to decrypt prompt for instance_id: ${request.instance_id}`, + error, + ); + } + } + } + } - return (await displayRequest.json()) as DisplayRequest[]; + return displayRequests; } catch (error) { if (error instanceof Error && error.name !== "AbortError") { console.log(error); diff --git a/helm-frontend/src/services/getInstances.ts b/helm-frontend/src/services/getInstances.ts index 77ca42a055..45c370bd39 100644 --- a/helm-frontend/src/services/getInstances.ts +++ b/helm-frontend/src/services/getInstances.ts @@ -1,21 +1,97 @@ import Instance from "@/types/Instance"; +import { EncryptionDataMap } from "@/types/EncryptionDataMap"; import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint"; import getBenchmarkSuite from "@/utils/getBenchmarkSuite"; +// Helper function for decryption +async function decryptField( + ciphertext: string, + key: string, + iv: string, + tag: string, +): Promise { + // Convert Base64 strings to Uint8Array + const decodeBase64 = (str: string) => + Uint8Array.from(atob(str), (c) => c.charCodeAt(0)); + + const cryptoKey = await window.crypto.subtle.importKey( + "raw", + decodeBase64(key), + "AES-GCM", + true, + ["decrypt"], + ); + + const combinedCiphertext = new Uint8Array([ + ...decodeBase64(ciphertext), + ...decodeBase64(tag), + ]); + + const ivArray = decodeBase64(iv); + + const decrypted = await window.crypto.subtle.decrypt( + { name: "AES-GCM", iv: ivArray }, + cryptoKey, + combinedCiphertext, + ); + + return new TextDecoder().decode(decrypted); +} + export default async function getInstancesByRunName( runName: string, signal: AbortSignal, suite?: string, + userAgreed?: boolean, ): Promise { try { - const instances = await fetch( + const response = await fetch( getBenchmarkEndpoint( `/runs/${suite || getBenchmarkSuite()}/${runName}/instances.json`, ), { signal }, ); + const instances = (await response.json()) as Instance[]; + + if (runName.includes("gpqa") && userAgreed) { + const encryptionResponse = await fetch( + getBenchmarkEndpoint( + `/runs/${ + suite || getBenchmarkSuite() + }/${runName}/encryption_data.json`, + ), + { signal }, + ); + const encryptionData = + (await encryptionResponse.json()) as EncryptionDataMap; + + for (const instance of instances) { + const inputEncryption = encryptionData[instance.input.text]; + if (inputEncryption) { + instance.input.text = "encrypted"; + instance.input.text = await decryptField( + inputEncryption.ciphertext, + inputEncryption.key, + inputEncryption.iv, + inputEncryption.tag, + ); + } + + for (const reference of instance.references) { + const referenceEncryption = encryptionData[reference.output.text]; + if (referenceEncryption) { + reference.output.text = await decryptField( + referenceEncryption.ciphertext, + referenceEncryption.key, + referenceEncryption.iv, + referenceEncryption.tag, + ); + } + } + } + } - return (await instances.json()) as Instance[]; + return instances; } catch (error) { if (error instanceof Error && error.name !== "AbortError") { console.log(error); diff --git a/helm-frontend/src/types/EncryptionDataMap.ts b/helm-frontend/src/types/EncryptionDataMap.ts new file mode 100644 index 0000000000..61f8b65d70 --- /dev/null +++ b/helm-frontend/src/types/EncryptionDataMap.ts @@ -0,0 +1,8 @@ +export default interface EncryptionDetails { + ciphertext: string; + key: string; + iv: string; + tag: string; +} + +export type EncryptionDataMap = Record; diff --git a/scripts/decrypt_scenario_states.py b/scripts/decrypt_scenario_states.py new file mode 100644 index 0000000000..0e4276c0df --- /dev/null +++ b/scripts/decrypt_scenario_states.py @@ -0,0 +1,191 @@ +import argparse +import dataclasses +import json +import os +import base64 +from typing import Dict, Optional +from tqdm import tqdm +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + +from helm.benchmark.adaptation.request_state import RequestState +from helm.benchmark.adaptation.scenario_state import ScenarioState +from helm.benchmark.scenarios.scenario import Instance, Reference +from helm.common.codec import from_json, to_json +from helm.common.hierarchical_logger import hlog +from helm.common.request import Request, RequestResult + +_SCENARIO_STATE_FILE_NAME = "scenario_state.json" +_DECRYPTED_SCENARIO_STATE_FILE_NAME = "decrypted_scenario_state.json" +_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME = "encryption_data.json" + + +class HELMDecryptor: + def __init__(self, encryption_data_mapping: Dict[str, Dict[str, str]]): + """ + encryption_data_mapping is a dict like: + { + "[encrypted_text_0]": { + "ciphertext": "...", + "key": "...", + "iv": "...", + "tag": "..." + }, + ... + } + """ + self.encryption_data_mapping = encryption_data_mapping + + def decrypt_text(self, text: str) -> str: + if text.startswith("[encrypted_text_") and text.endswith("]"): + data = self.encryption_data_mapping.get(text) + if data is None: + # If not found in encryption data, return as-is or raise error + raise ValueError(f"No decryption data found for {text}") + + ciphertext = base64.b64decode(data["ciphertext"]) + key = base64.b64decode(data["key"]) + iv = base64.b64decode(data["iv"]) + tag = base64.b64decode(data["tag"]) + + cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend()) + decryptor = cipher.decryptor() + plaintext = decryptor.update(ciphertext) + decryptor.finalize() + return plaintext.decode("utf-8") + else: + # Not an encrypted placeholder, return as is. + return text + + +def read_scenario_state(scenario_state_path: str) -> ScenarioState: + if not os.path.exists(scenario_state_path): + raise ValueError(f"Could not load ScenarioState from {scenario_state_path}") + with open(scenario_state_path) as f: + return from_json(f.read(), ScenarioState) + + +def write_scenario_state(scenario_state_path: str, scenario_state: ScenarioState) -> None: + with open(scenario_state_path, "w") as f: + f.write(to_json(scenario_state)) + + +def read_encryption_data(encryption_data_path: str) -> Dict[str, Dict[str, str]]: + if not os.path.exists(encryption_data_path): + raise ValueError(f"Could not load encryption data from {encryption_data_path}") + with open(encryption_data_path) as f: + return json.load(f) + + +def decrypt_reference(reference: Reference, decryptor: HELMDecryptor) -> Reference: + decrypted_output = dataclasses.replace(reference.output, text=decryptor.decrypt_text(reference.output.text)) + return dataclasses.replace(reference, output=decrypted_output) + + +def decrypt_instance(instance: Instance, decryptor: HELMDecryptor) -> Instance: + decrypted_input = dataclasses.replace(instance.input, text=decryptor.decrypt_text(instance.input.text)) + decrypted_references = [decrypt_reference(reference, decryptor) for reference in instance.references] + return dataclasses.replace(instance, input=decrypted_input, references=decrypted_references) + + +def decrypt_request(request: Request, decryptor: HELMDecryptor) -> Request: + # The encryption script sets request.messages and multimodal_prompt to None, so we don't need to decrypt them + return dataclasses.replace(request, prompt=decryptor.decrypt_text(request.prompt)) + + +def decrypt_output_mapping( + output_mapping: Optional[Dict[str, str]], decryptor: HELMDecryptor +) -> Optional[Dict[str, str]]: + if output_mapping is None: + return None + return {key: decryptor.decrypt_text(val) for key, val in output_mapping.items()} + + +def decrypt_result(result: Optional[RequestResult], decryptor: HELMDecryptor) -> Optional[RequestResult]: + if result is None: + return None + + decrypted_completions = [ + dataclasses.replace(completion, text=decryptor.decrypt_text(completion.text)) + for completion in result.completions + ] + return dataclasses.replace(result, completions=decrypted_completions) + + +def decrypt_request_state(request_state: RequestState, decryptor: HELMDecryptor) -> RequestState: + return dataclasses.replace( + request_state, + instance=decrypt_instance(request_state.instance, decryptor), + request=decrypt_request(request_state.request, decryptor), + output_mapping=decrypt_output_mapping(request_state.output_mapping, decryptor), + result=decrypt_result(request_state.result, decryptor), + ) + + +def decrypt_scenario_state(scenario_state: ScenarioState, decryptor: HELMDecryptor) -> ScenarioState: + decrypted_request_states = [decrypt_request_state(rs, decryptor) for rs in scenario_state.request_states] + return dataclasses.replace(scenario_state, request_states=decrypted_request_states) + + +def modify_scenario_state_for_run(run_path: str) -> None: + scenario_state_path = os.path.join(run_path, _SCENARIO_STATE_FILE_NAME) + encryption_data_path = os.path.join(run_path, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME) + + scenario_state = read_scenario_state(scenario_state_path) + encryption_data_mapping = read_encryption_data(encryption_data_path) + decryptor = HELMDecryptor(encryption_data_mapping) + + decrypted_scenario_state = decrypt_scenario_state(scenario_state, decryptor) + decrypted_scenario_state_path = os.path.join(run_path, _DECRYPTED_SCENARIO_STATE_FILE_NAME) + write_scenario_state(decrypted_scenario_state_path, decrypted_scenario_state) + + +def modify_scenario_states_for_suite(run_suite_path: str, scenario: str) -> None: + scenario_prefix = scenario if scenario != "all" else "" + run_dir_names = sorted( + [ + p + for p in os.listdir(run_suite_path) + if p != "eval_cache" + and p != "groups" + and os.path.isdir(os.path.join(run_suite_path, p)) + and p.startswith(scenario_prefix) + ] + ) + for run_dir_name in tqdm(run_dir_names, disable=None): + scenario_state_path: str = os.path.join(run_suite_path, run_dir_name, _SCENARIO_STATE_FILE_NAME) + encryption_data_path = os.path.join(run_suite_path, run_dir_name, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME) + if not os.path.exists(scenario_state_path): + hlog(f"WARNING: {run_dir_name} doesn't have {_SCENARIO_STATE_FILE_NAME}, skipping") + continue + if not os.path.exists(encryption_data_path): + hlog(f"WARNING: {run_dir_name} doesn't have {_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME}, skipping") + continue + run_path: str = os.path.join(run_suite_path, run_dir_name) + modify_scenario_state_for_run(run_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output" + ) + parser.add_argument( + "--suite", + type=str, + help="Name of the suite this decryption should go under.", + ) + parser.add_argument( + "--scenario", + type=str, + default="all", + help="Name of the scenario this decryption should go under. Default is all.", + ) + args = parser.parse_args() + output_path = args.output_path + suite = args.suite + run_suite_path = os.path.join(output_path, "runs", suite) + modify_scenario_states_for_suite(run_suite_path, scenario=args.scenario) + + +if __name__ == "__main__": + main() diff --git a/scripts/encrypt_scenario_states.py b/scripts/encrypt_scenario_states.py new file mode 100644 index 0000000000..5db14b5419 --- /dev/null +++ b/scripts/encrypt_scenario_states.py @@ -0,0 +1,187 @@ +"""Encrypts prompts from scenario state. + +This script modifies all scenario_state.json files in place within a suite to +encrypting all prompts, instance input text, and instance reference output text +from the `ScenarioState`s. + +This is used when the scenario contains prompts that should not be displayed, +in order to reduce the chance of data leakage or to comply with data privacy +requirements. + +After running this, you must re-run helm-summarize on the suite in order to +update other JSON files used by the web frontend.""" + +import argparse +import dataclasses +import os +import base64 +from typing import Dict, Optional +from tqdm import tqdm +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + +from helm.benchmark.adaptation.request_state import RequestState +from helm.benchmark.adaptation.scenario_state import ScenarioState +from helm.benchmark.scenarios.scenario import Instance, Reference +from helm.common.codec import from_json, to_json +from helm.common.hierarchical_logger import hlog +from helm.common.request import Request, RequestResult +from helm.common.general import write + + +_SCENARIO_STATE_FILE_NAME = "scenario_state.json" +_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME = "encryption_data.json" + + +class HELMEncryptor: + def __init__(self, key, iv): + self.key = key + self.iv = iv + self.encryption_data_mapping = {} + self.idx = 0 + + def encrypt_text(self, text: str) -> str: + cipher = Cipher(algorithms.AES(self.key), modes.GCM(self.iv), backend=default_backend()) + encryptor = cipher.encryptor() + ciphertext = encryptor.update(text.encode()) + encryptor.finalize() + ret_text = f"[encrypted_text_{self.idx}]" + + res = { + "ciphertext": base64.b64encode(ciphertext).decode(), + "key": base64.b64encode(self.key).decode(), + "iv": base64.b64encode(self.iv).decode(), + "tag": base64.b64encode(encryptor.tag).decode(), + } + assert ret_text not in self.encryption_data_mapping + self.encryption_data_mapping[ret_text] = res + self.idx += 1 + return ret_text + + +def read_scenario_state(scenario_state_path: str) -> ScenarioState: + if not os.path.exists(scenario_state_path): + raise ValueError(f"Could not load ScenarioState from {scenario_state_path}") + with open(scenario_state_path) as f: + return from_json(f.read(), ScenarioState) + + +def write_scenario_state(scenario_state_path: str, scenario_state: ScenarioState) -> None: + with open(scenario_state_path, "w") as f: + f.write(to_json(scenario_state)) + + +def encrypt_reference(reference: Reference) -> Reference: + global encryptor + encrypted_output = dataclasses.replace(reference.output, text=encryptor.encrypt_text(reference.output.text)) + return dataclasses.replace(reference, output=encrypted_output) + + +def encrypt_instance(instance: Instance) -> Instance: + global encryptor + encrypted_input = dataclasses.replace(instance.input, text=encryptor.encrypt_text(instance.input.text)) + encrypted_references = [encrypt_reference(reference) for reference in instance.references] + return dataclasses.replace(instance, input=encrypted_input, references=encrypted_references) + + +def encrypt_request(request: Request) -> Request: + global encryptor + return dataclasses.replace( + request, prompt=encryptor.encrypt_text(request.prompt), messages=None, multimodal_prompt=None + ) + + +def encrypt_output_mapping(output_mapping: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]: + if output_mapping is None: + return None + return {key: encryptor.encrypt_text(val) for key, val in output_mapping.items()} + + +def encrypt_result(result: Optional[RequestResult]) -> Optional[RequestResult]: + if result is None: + return None + + encrypted_results = [ + dataclasses.replace(completion, text=encryptor.encrypt_text(completion.text)) + for completion in result.completions + ] + return dataclasses.replace(result, completions=encrypted_results) + + +def encrypt_request_state(request_state: RequestState) -> RequestState: + return dataclasses.replace( + request_state, + instance=encrypt_instance(request_state.instance), + request=encrypt_request(request_state.request), + output_mapping=encrypt_output_mapping(request_state.output_mapping), + result=encrypt_result(request_state.result), + ) + + +def encrypt_scenario_state(scenario_state: ScenarioState) -> ScenarioState: + encrypted_request_states = [encrypt_request_state(request_state) for request_state in scenario_state.request_states] + return dataclasses.replace(scenario_state, request_states=encrypted_request_states) + + +def modify_scenario_state_for_run(run_path: str) -> None: + scenario_state_path = os.path.join(run_path, _SCENARIO_STATE_FILE_NAME) + scenario_state = read_scenario_state(scenario_state_path) + encrypted_scenario_state = encrypt_scenario_state(scenario_state) + write_scenario_state(scenario_state_path, encrypted_scenario_state) + + +def modify_scenario_states_for_suite(run_suite_path: str, scenario: str) -> None: + """Load the runs in the run suite path.""" + # run_suite_path can contain subdirectories that are not runs (e.g. eval_cache, groups) + # so filter them out. + scenario_prefix = scenario if scenario != "all" else "" + run_dir_names = sorted( + [ + p + for p in os.listdir(run_suite_path) + if p != "eval_cache" + and p != "groups" + and os.path.isdir(os.path.join(run_suite_path, p)) + and p.startswith(scenario_prefix) + ] + ) + for run_dir_name in tqdm(run_dir_names, disable=None): + scenario_state_path: str = os.path.join(run_suite_path, run_dir_name, _SCENARIO_STATE_FILE_NAME) + if not os.path.exists(scenario_state_path): + hlog(f"WARNING: {run_dir_name} doesn't have {_SCENARIO_STATE_FILE_NAME}, skipping") + continue + run_path: str = os.path.join(run_suite_path, run_dir_name) + modify_scenario_state_for_run(run_path) + + # Write the encryption data to a file + encryption_data_path = os.path.join(run_path, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME) + write(encryption_data_path, to_json(encryptor.encryption_data_mapping)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output" + ) + parser.add_argument( + "--suite", + type=str, + help="Name of the suite this encryption should go under.", + ) + parser.add_argument( + "--scenario", + type=str, + default="all", + help="Name of the scenario this encryption should go under. Default is all.", + ) + args = parser.parse_args() + output_path = args.output_path + suite = args.suite + run_suite_path = os.path.join(output_path, "runs", suite) + modify_scenario_states_for_suite(run_suite_path, scenario=args.scenario) + + +if __name__ == "__main__": + key = os.urandom(32) # 256-bit key + iv = os.urandom(12) # 96-bit IV (suitable for AES-GCM) + encryptor = HELMEncryptor(key, iv) + main()