Skip to content

Commit

Permalink
refactor search query fetching
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Aug 31, 2024
1 parent bfa93cc commit a8143f6
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 11 deletions.
4 changes: 4 additions & 0 deletions components/ai_chat/content/browser/ai_chat_tab_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "base/strings/string_util.h"
#include "brave/components/ai_chat/content/browser/page_content_fetcher.h"
#include "brave/components/ai_chat/content/browser/pdf_utils.h"
#include "brave/components/ai_chat/core/browser/associated_content_driver.h"
#include "brave/components/ai_chat/core/browser/constants.h"
#include "brave/components/ai_chat/core/browser/utils.h"
#include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h"
Expand Down Expand Up @@ -88,6 +89,9 @@ AIChatTabHelper::AIChatTabHelper(
ExtractPrintPreviewContentFunction extract_print_preview_content_function)
: content::WebContentsObserver(web_contents),
content::WebContentsUserData<AIChatTabHelper>(*web_contents),
AssociatedContentDriver(web_contents->GetBrowserContext()
->GetDefaultStoragePartition()
->GetURLLoaderFactoryForBrowserProcess()),
extract_print_preview_content_function_(
extract_print_preview_content_function) {
CHECK(extract_print_preview_content_function);
Expand Down
123 changes: 120 additions & 3 deletions components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,46 @@
#include "base/one_shot_event.h"
#include "base/ranges/algorithm.h"
#include "base/strings/string_util.h"
#include "brave/brave_domains/service_domains.h"
#include "brave/components/ai_chat/core/browser/brave_search_responses.h"
#include "brave/components/ai_chat/core/browser/constants.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"

#include "brave/components/ai_chat/core/browser/utils.h"
#include "net/base/url_util.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace ai_chat {

// Parameterless constructor: its only work is to create the one-shot event
// stored in |on_page_text_fetch_complete_|.
// NOTE(review): in this diff rendering this looks like the pre-change
// constructor that the overload taking a URL loader factory replaces - confirm.
AssociatedContentDriver::AssociatedContentDriver()
: on_page_text_fetch_complete_(std::make_unique<base::OneShotEvent>()) {}
namespace {

// Returns the network traffic annotation tag used when constructing the
// APIRequestHelper that fetches the search query summary. The annotation is
// defined under the unique id "ai_chat_associated_content_driver" and declares
// that cookies are not allowed. The text inside the raw string literal is part
// of the annotation itself, so it must not be reformatted or commented.
net::NetworkTrafficAnnotationTag
GetSearchQuerySummaryNetworkTrafficAnnotationTag() {
return net::DefineNetworkTrafficAnnotation(
"ai_chat_associated_content_driver",
R"(
semantics {
sender: "Brave Leo AI Chat"
description:
"This sender is used to get search query summary from Brave search."
trigger:
"Triggered by uses of Brave Leo AI Chat on Brave Search SERP."
data:
"User's search query and the corresponding summary."
destination: WEBSITE
}
policy {
cookies_allowed: NO
policy_exception_justification:
"Not implemented."
}
)");
}

} // namespace

// Constructs the driver.
//
// |url_loader_factory| is stored in |url_loader_factory_| and is later passed
// to the APIRequestHelper created in OnSearchSummarizerKeyFetched(). The
// parameter is received by value, so it is moved into the member rather than
// copied; this avoids an extra atomic ref-count increment/decrement on the
// factory.
//
// Also creates the one-shot event stored in |on_page_text_fetch_complete_|.
AssociatedContentDriver::AssociatedContentDriver(
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
    : url_loader_factory_(std::move(url_loader_factory)),
      on_page_text_fetch_complete_(std::make_unique<base::OneShotEvent>()) {}

AssociatedContentDriver::~AssociatedContentDriver() {
for (auto& conversation : associated_conversations_) {
Expand Down Expand Up @@ -139,6 +172,89 @@ bool AssociatedContentDriver::GetCachedIsVideo() {
return is_video_;
}

// Asynchronously produces the staged entries for the current page and passes
// them to |callback|.
//
// The only known source of staged entries is a Brave Search results page. For
// any other page |callback| is run immediately with std::nullopt. For a search
// results page the summarizer key is requested first and the work continues in
// OnSearchSummarizerKeyFetched().
void AssociatedContentDriver::GetStagedEntriesFromContent(
    ConversationHandler::GetStagedEntriesCallback callback) {
  const bool is_search_results_page = IsBraveSearchSERP(GetPageURL());
  if (is_search_results_page) {
    // Bind the navigation id as it is right now so the continuation can tell
    // whether the page changed while the key was being fetched.
    auto on_key_fetched = base::BindOnce(
        &AssociatedContentDriver::OnSearchSummarizerKeyFetched,
        weak_ptr_factory_.GetWeakPtr(), std::move(callback),
        current_navigation_id_);
    GetSearchSummarizerKey(std::move(on_key_fetched));
    return;
  }

  // Not a Brave Search results page: there is nothing to stage.
  std::move(callback).Run(std::nullopt);
}

// Continuation of GetStagedEntriesFromContent(), invoked with the summarizer
// key for the page.
//
// |callback|      - receives the final staged-entries result.
// |navigation_id| - navigation id captured when the fetch was started.
// |key|           - summarizer key; may be absent or empty.
//
// Runs |callback| with std::nullopt when no usable key is available or when
// |navigation_id| no longer matches |current_navigation_id_|. Otherwise issues
// a GET request for the query summary; the response is handled in
// OnSearchQuerySummaryFetched().
void AssociatedContentDriver::OnSearchSummarizerKeyFetched(
    ConversationHandler::GetStagedEntriesCallback callback,
    int64_t navigation_id,
    const std::optional<std::string>& key) {
  const bool has_usable_key = key.has_value() && !key->empty();
  const bool is_same_navigation = navigation_id == current_navigation_id_;
  if (!has_usable_key || !is_same_navigation) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  // The request helper is created on first use and then reused.
  if (api_request_helper_ == nullptr) {
    api_request_helper_ =
        std::make_unique<api_request_helper::APIRequestHelper>(
            GetSearchQuerySummaryNetworkTrafficAnnotationTag(),
            url_loader_factory_);
  }

  // https://search.brave.com/api/chatllm/raw_data?key=<key>
  const std::string endpoint = base::StrCat(
      {url::kHttpsScheme, url::kStandardSchemeSeparator,
       brave_domains::GetServicesDomain(kBraveSearchURLPrefix),
       "/api/chatllm/raw_data"});
  const GURL endpoint_url(endpoint);
  CHECK(endpoint_url.is_valid());
  const GURL request_url =
      net::AppendQueryParameter(endpoint_url, "key", *key);

  // Forward the same navigation id so the response handler can repeat the
  // staleness check once the network request completes.
  auto on_summary_fetched = base::BindOnce(
      &AssociatedContentDriver::OnSearchQuerySummaryFetched,
      weak_ptr_factory_.GetWeakPtr(), std::move(callback), navigation_id);
  api_request_helper_->Request("GET", request_url, "", "application/json",
                               std::move(on_summary_fetched), {}, {});
}

// Handles the response of the search query summary request.
//
// |callback|      - receives the parsed summary, or std::nullopt on failure.
// |navigation_id| - navigation id captured when the fetch was started.
// |result|        - outcome of the API request.
//
// |callback| gets std::nullopt when the response code is not 2XX, when
// |navigation_id| no longer matches |current_navigation_id_|, or when the
// response body cannot be parsed into a summary.
void AssociatedContentDriver::OnSearchQuerySummaryFetched(
    ConversationHandler::GetStagedEntriesCallback callback,
    int64_t navigation_id,
    api_request_helper::APIRequestResult result) {
  const bool is_same_navigation = navigation_id == current_navigation_id_;
  if (result.Is2XXResponseCode() && is_same_navigation) {
    // A failed parse yields an empty optional, which is exactly what is
    // reported to the caller in that case, so no separate branch is needed.
    std::move(callback).Run(
        ParseSearchQuerySummaryResponse(result.value_body()));
    return;
  }

  std::move(callback).Run(std::nullopt);
}

// static
// Converts the body of a query summary response into a SearchQuerySummary.
//
// Only the first conversation entry and the first answer of that entry are
// used. Returns std::nullopt when |value| cannot be converted into a
// QuerySummaryResponse, when it has no conversation entries, or when the first
// entry has no answers.
std::optional<SearchQuerySummary>
AssociatedContentDriver::ParseSearchQuerySummaryResponse(
    const base::Value& value) {
  auto parsed_response =
      brave_search_responses::QuerySummaryResponse::FromValue(value);
  if (!parsed_response) {
    return std::nullopt;
  }

  const auto& conversation = parsed_response->conversation;
  if (conversation.empty()) {
    return std::nullopt;
  }

  const auto& first_entry = conversation[0];
  const auto& answers = first_entry.answer;
  if (answers.empty()) {
    return std::nullopt;
  }

  return SearchQuerySummary(first_entry.query, answers[0].text);
}

void AssociatedContentDriver::OnFaviconImageDataChanged() {
for (auto& conversation : associated_conversations_) {
conversation->OnFaviconImageDataChanged();
Expand All @@ -161,6 +277,7 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
cached_text_content_.clear();
content_invalidation_token_.clear();
is_video_ = false;
api_request_helper_.reset();
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

Expand Down
24 changes: 23 additions & 1 deletion components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"
#include "brave/components/ai_chat/core/browser/model_service.h"
#include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h"
#include "brave/components/api_request_helper/api_request_helper.h"

class PrefService;

Expand All @@ -35,7 +37,8 @@ class AssociatedContentDriver
virtual void OnAssociatedContentNavigated(int new_navigation_id) {}
};

AssociatedContentDriver();
explicit AssociatedContentDriver(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
~AssociatedContentDriver() override;

AssociatedContentDriver(const AssociatedContentDriver&) = delete;
Expand All @@ -55,6 +58,8 @@ class AssociatedContentDriver
ConversationHandler::GetPageContentCallback callback) override;
std::string_view GetCachedTextContent() override;
bool GetCachedIsVideo() override;
void GetStagedEntriesFromContent(
ConversationHandler::GetStagedEntriesCallback callback) override;
// // Implementer should use alternative method of page content fetching
// void PrintPreviewFallback(ConversationHandler::GetPageContentCallback
// callback) override;
Expand All @@ -66,6 +71,9 @@ class AssociatedContentDriver
protected:
virtual GURL GetPageURL() const = 0;
virtual std::u16string GetPageTitle() const = 0;
// Get the summarizer-key meta tag content from the Brave Search SERP, if it exists.
virtual void GetSearchSummarizerKey(
mojom::PageContentExtractor::GetSearchSummarizerKeyCallback callback) = 0;

// Implementer should fetch content from the "page" associated with this
// conversation.
Expand Down Expand Up @@ -120,12 +128,26 @@ class AssociatedContentDriver
ConversationHandler::GetPageContentCallback callback,
int64_t navigation_id);

void OnSearchSummarizerKeyFetched(
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
const std::optional<std::string>& key);
void OnSearchQuerySummaryFetched(
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
api_request_helper::APIRequestResult result);
static std::optional<SearchQuerySummary> ParseSearchQuerySummaryResponse(
const base::Value& value);

raw_ptr<PrefService> pref_service_;
raw_ptr<AIChatMetrics> ai_chat_metrics_;
std::unique_ptr<AIChatCredentialManager> credential_manager_;

scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

// Used for fetching search query summary.
std::unique_ptr<api_request_helper::APIRequestHelper> api_request_helper_;

base::ObserverList<Observer> observers_;
std::unique_ptr<base::OneShotEvent> on_page_text_fetch_complete_;
bool is_page_text_fetch_in_progress_ = false;
Expand Down
Loading

0 comments on commit a8143f6

Please sign in to comment.