From ed9b6c2ac7e076f2e220c2336917c8541f0f485d Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:10:23 -0500 Subject: [PATCH] fix(langchain_tools_demo): Add ID Token credential flow for GCE (#198) Fixes: https://github.com/GoogleCloudPlatform/genai-databases-retrieval-app/issues/190 --- docs/run_langchain_demo.md | 1 + langchain_tools_demo/tools.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/run_langchain_demo.md b/docs/run_langchain_demo.md index 5036a93f..a2674aa7 100644 --- a/docs/run_langchain_demo.md +++ b/docs/run_langchain_demo.md @@ -18,6 +18,7 @@ ```bash gcloud auth application-default login ``` + * Tip: if you are running into `403` error, check to make sure the service account you are using has the `Cloud Run Invoker` IAM in the retrieval service project. 1. Change into the `langchain_tools_demo` directory: diff --git a/langchain_tools_demo/tools.py b/langchain_tools_demo/tools.py index b0632209..ec14307e 100644 --- a/langchain_tools_demo/tools.py +++ b/langchain_tools_demo/tools.py @@ -17,6 +17,7 @@ import aiohttp import google.oauth2.id_token # type: ignore +from google.auth import compute_engine # type: ignore from google.auth.transport.requests import Request # type: ignore from langchain.agents.agent import ExceptionTool # type: ignore from langchain.tools import StructuredTool @@ -34,9 +35,19 @@ def get_id_token(): global CREDENTIALS if CREDENTIALS is None: CREDENTIALS, _ = google.auth.default() + if not hasattr(CREDENTIALS, "id_token"): + # Use Compute Engine default credential + CREDENTIALS = compute_engine.IDTokenCredentials( + request=Request(), + target_audience=BASE_URL, + use_metadata_identity_endpoint=True, + ) if not CREDENTIALS.valid: CREDENTIALS.refresh(Request()) - return CREDENTIALS.id_token + if hasattr(CREDENTIALS, "id_token"): + return CREDENTIALS.id_token + else: + return CREDENTIALS.token def get_headers(client: aiohttp.ClientSession):