Skip to content

Commit

Permalink
Add Encryption for GPQA (#3216)
Browse files Browse the repository at this point in the history
  • Loading branch information
liamjxu authored Dec 20, 2024
1 parent b01f5f6 commit be8ac6b
Show file tree
Hide file tree
Showing 8 changed files with 669 additions and 11 deletions.
16 changes: 11 additions & 5 deletions helm-frontend/src/components/Instances.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ interface Props {
runName: string;
suite: string;
metricFieldMap: MetricFieldMap;
userAgreed: boolean;
}

export default function Instances({ runName, suite, metricFieldMap }: Props) {
export default function Instances({
runName,
suite,
metricFieldMap,
userAgreed,
}: Props) {
const [searchParams, setSearchParams] = useSearchParams();
const [instances, setInstances] = useState<Instance[]>([]);
const [displayPredictionsMap, setDisplayPredictionsMap] = useState<
Expand All @@ -43,9 +49,9 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) {

const [instancesResp, displayPredictions, displayRequests] =
await Promise.all([
getInstances(runName, signal, suite),
getDisplayPredictionsByName(runName, signal, suite),
getDisplayRequestsByName(runName, signal, suite),
getInstances(runName, signal, suite, userAgreed),
getDisplayPredictionsByName(runName, signal, suite, userAgreed),
getDisplayRequestsByName(runName, signal, suite, userAgreed),
]);
setInstances(instancesResp);

Expand Down Expand Up @@ -93,7 +99,7 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) {
void fetchData();

return () => controller.abort();
}, [runName, suite]);
}, [runName, suite, userAgreed]);

const pagedInstances = instances.slice(
(currentInstancesPage - 1) * INSTANCES_PAGE_SIZE,
Expand Down
49 changes: 49 additions & 0 deletions helm-frontend/src/routes/Run.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ export default function Run() {
MetricFieldMap | undefined
>({});

const [agreeInput, setAgreeInput] = useState("");
const [userAgreed, setUserAgreed] = useState(false);

useEffect(() => {
const controller = new AbortController();
async function fetchData() {
Expand Down Expand Up @@ -93,6 +96,16 @@ export default function Run() {
return <Loading />;
}

// Checks the typed confirmation phrase (surrounding whitespace ignored) and
// records the outcome; a mismatch resets the agreement and prompts the user.
const handleAgreement = () => {
  const typedExactPhrase = agreeInput.trim() === "Yes, I agree";
  setUserAgreed(typedExactPhrase);
  if (!typedExactPhrase) {
    alert("Please type 'Yes, I agree' exactly.");
  }
};

return (
<>
<div className="flex justify-between gap-8 mb-12">
Expand Down Expand Up @@ -178,11 +191,47 @@ export default function Run() {
</Tab>
</Tabs>
</div>

{activeTab === 0 && runName.includes("gpqa") && !userAgreed && (
<div className="mb-8">
<hr className="my-4" />
<p className="mb-4">
The GPQA dataset instances are encrypted by default to comply with
the following request:
</p>
<blockquote className="italic border-l-4 border-gray-300 pl-4 text-gray-700 mb-4">
“We ask that you do not reveal examples from this dataset in plain
text or images online, to minimize the risk of these instances being
included in foundation model training corpora.”
</blockquote>
<p className="mb-4">
If you agree to this condition, please type{" "}
<strong>"Yes, I agree"</strong> in the box below and then click{" "}
<strong>Decrypt</strong>.
</p>
<div className="flex gap-2 mt-2">
<input
type="text"
value={agreeInput}
onChange={(e) => setAgreeInput(e.target.value)}
className="input input-bordered"
placeholder='Type "Yes, I agree"'
/>
<button onClick={handleAgreement} className="btn btn-primary">
Decrypt
</button>
</div>
<hr className="my-4" />
</div>
)}

{activeTab === 0 ? (
<Instances
key={userAgreed ? "instances-agreed" : "instances-not-agreed"}
runName={runName}
suite={runSuite}
metricFieldMap={metricFieldMap}
userAgreed={userAgreed} // Pass the boolean to Instances
/>
) : (
<RunMetrics
Expand Down
74 changes: 72 additions & 2 deletions helm-frontend/src/services/getDisplayPredictionsByName.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,93 @@
import type DisplayPrediction from "@/types/DisplayPrediction";
import { EncryptionDataMap } from "@/types/EncryptionDataMap";
import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint";
import getBenchmarkSuite from "@/utils/getBenchmarkSuite";

/**
 * Decrypts a single AES-GCM encrypted field.
 *
 * All four arguments are base64-encoded strings. The authentication tag is
 * supplied separately and is appended to the ciphertext before decryption,
 * because the Web Crypto AES-GCM implementation expects the tag as the
 * trailing bytes of the encrypted input.
 *
 * @param ciphertext - Base64-encoded ciphertext (without the tag).
 * @param key - Base64-encoded raw AES key.
 * @param iv - Base64-encoded initialization vector.
 * @param tag - Base64-encoded authentication tag.
 * @returns The decrypted bytes decoded as text by `TextDecoder` (UTF-8 default).
 * @throws Rejects if base64 decoding, key import, or decryption fails
 *   (for example when the tag does not authenticate the ciphertext).
 */
async function decryptField(
  ciphertext: string,
  key: string,
  iv: string,
  tag: string,
): Promise<string> {
  const decodeBase64 = (str: string) =>
    Uint8Array.from(atob(str), (c) => c.charCodeAt(0));

  // `globalThis.crypto` resolves in windows, workers and other runtimes that
  // expose Web Crypto, whereas `window` only exists on the main browser thread.
  const subtle = globalThis.crypto.subtle;

  // The key is only ever used to decrypt here, so it does not need to be
  // exportable again.
  const cryptoKey = await subtle.importKey(
    "raw",
    decodeBase64(key),
    "AES-GCM",
    false,
    ["decrypt"],
  );

  // Build ciphertext || tag in one preallocated buffer.
  const ciphertextBytes = decodeBase64(ciphertext);
  const tagBytes = decodeBase64(tag);
  const combinedCiphertext = new Uint8Array(
    ciphertextBytes.length + tagBytes.length,
  );
  combinedCiphertext.set(ciphertextBytes, 0);
  combinedCiphertext.set(tagBytes, ciphertextBytes.length);

  const decrypted = await subtle.decrypt(
    { name: "AES-GCM", iv: decodeBase64(iv) },
    cryptoKey,
    combinedCiphertext,
  );

  return new TextDecoder().decode(decrypted);
}

export default async function getDisplayPredictionsByName(
runName: string,
signal: AbortSignal,
suite?: string,
userAgreed?: boolean,
): Promise<DisplayPrediction[]> {
try {
const displayPrediction = await fetch(
const response = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/display_predictions.json`,
),
{ signal },
);
const displayPredictions = (await response.json()) as DisplayPrediction[];

if (runName.includes("gpqa") && userAgreed) {
const encryptionResponse = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/encryption_data.json`,
),
{ signal },
);
const encryptionData =
(await encryptionResponse.json()) as EncryptionDataMap;

for (const prediction of displayPredictions) {
const encryptedText = prediction.predicted_text;
const encryptionDetails = encryptionData[encryptedText];

if (encryptionDetails) {
try {
prediction.predicted_text = await decryptField(
encryptionDetails.ciphertext,
encryptionDetails.key,
encryptionDetails.iv,
encryptionDetails.tag,
);
} catch (error) {
console.error(
`Failed to decrypt predicted_text for instance_id: ${prediction.instance_id}`,
error,
);
}
}
}
}

return (await displayPrediction.json()) as DisplayPrediction[];
return displayPredictions;
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
console.log(error);
Expand Down
75 changes: 73 additions & 2 deletions helm-frontend/src/services/getDisplayRequestsByName.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,94 @@
import type DisplayRequest from "@/types/DisplayRequest";
import { EncryptionDataMap } from "@/types/EncryptionDataMap";
import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint";
import getBenchmarkSuite from "@/utils/getBenchmarkSuite";

/**
 * Decrypts a single AES-GCM encrypted field.
 *
 * All four arguments are base64-encoded strings. The authentication tag is
 * supplied separately and is appended to the ciphertext before decryption,
 * because the Web Crypto AES-GCM implementation expects the tag as the
 * trailing bytes of the encrypted input.
 *
 * @param ciphertext - Base64-encoded ciphertext (without the tag).
 * @param key - Base64-encoded raw AES key.
 * @param iv - Base64-encoded initialization vector.
 * @param tag - Base64-encoded authentication tag.
 * @returns The decrypted bytes decoded as text by `TextDecoder` (UTF-8 default).
 * @throws Rejects if base64 decoding, key import, or decryption fails
 *   (for example when the tag does not authenticate the ciphertext).
 */
async function decryptField(
  ciphertext: string,
  key: string,
  iv: string,
  tag: string,
): Promise<string> {
  const decodeBase64 = (str: string) =>
    Uint8Array.from(atob(str), (c) => c.charCodeAt(0));

  // `globalThis.crypto` resolves in windows, workers and other runtimes that
  // expose Web Crypto, whereas `window` only exists on the main browser thread.
  const subtle = globalThis.crypto.subtle;

  // The key is only ever used to decrypt here, so it does not need to be
  // exportable again.
  const cryptoKey = await subtle.importKey(
    "raw",
    decodeBase64(key),
    "AES-GCM",
    false,
    ["decrypt"],
  );

  // Build ciphertext || tag in one preallocated buffer.
  const ciphertextBytes = decodeBase64(ciphertext);
  const tagBytes = decodeBase64(tag);
  const combinedCiphertext = new Uint8Array(
    ciphertextBytes.length + tagBytes.length,
  );
  combinedCiphertext.set(ciphertextBytes, 0);
  combinedCiphertext.set(tagBytes, ciphertextBytes.length);

  const decrypted = await subtle.decrypt(
    { name: "AES-GCM", iv: decodeBase64(iv) },
    cryptoKey,
    combinedCiphertext,
  );

  return new TextDecoder().decode(decrypted);
}

export default async function getDisplayRequestsByName(
runName: string,
signal: AbortSignal,
suite?: string,
userAgreed?: boolean,
): Promise<DisplayRequest[]> {
try {
const displayRequest = await fetch(
const response = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/display_requests.json`,
),
{ signal },
);
const displayRequests = (await response.json()) as DisplayRequest[];

if (runName.startsWith("gpqa") && userAgreed) {
const encryptionResponse = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/encryption_data.json`,
),
{ signal },
);
const encryptionData =
(await encryptionResponse.json()) as EncryptionDataMap;

for (const request of displayRequests) {
const encryptedPrompt = request.request.prompt;
const encryptionDetails = encryptionData[encryptedPrompt];

if (encryptionDetails) {
try {
request.request.prompt = await decryptField(
encryptionDetails.ciphertext,
encryptionDetails.key,
encryptionDetails.iv,
encryptionDetails.tag,
);
} catch (error) {
console.error(
`Failed to decrypt prompt for instance_id: ${request.instance_id}`,
error,
);
}
}
}
}

return (await displayRequest.json()) as DisplayRequest[];
return displayRequests;
} catch (error) {
if (error instanceof Error && error.name !== "AbortError") {
console.log(error);
Expand Down
Loading

0 comments on commit be8ac6b

Please sign in to comment.