Skip to content

Commit

Permalink
Reimplement content refinement function calls from #24953
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Aug 27, 2024
1 parent 5b80837 commit 1e5ceda
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
cached_text_content_.clear();
content_invalidation_token_.clear();
is_video_ = false;
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

} // namespace ai_chat
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class AssociatedContentDriver

// To be called when a page navigation is detected and a new conversation
// is expected.
void OnNewPage(int64_t navigation_id);
void OnNewPage(int64_t navigation_id) override;

// Begin the alternative content fetching (print preview / OCR) by
// sending a message to an observer with access to the layer this can
Expand Down
137 changes: 137 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
#include <vector>

#include "base/containers/fixed_flat_set.h"
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "base/no_destructor.h"
#include "base/strings/string_split.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/types/expected.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
#include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/associated_archive_content.h"
#include "brave/components/ai_chat/core/browser/constants.h"
#include "brave/components/ai_chat/core/browser/leo_local_models_updater.h"
#include "brave/components/ai_chat/core/browser/model_service.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
Expand Down Expand Up @@ -87,11 +93,95 @@ const std::string& GetActionTypeQuestion(mojom::ActionType action_type) {
return iter->second;
}

uint32_t GetMaxContentLengthForModel(const mojom::Model& model) {
return model.options->is_custom_model_options()
? kCustomModelMaxPageContentLength
: model.options->get_leo_model_options()->max_page_content_length;
}

} // namespace

ConversationHandler::AssociatedContentDelegate::AssociatedContentDelegate()
: text_embedder_(nullptr, base::OnTaskRunnerDeleter(nullptr)) {}

ConversationHandler::AssociatedContentDelegate::~AssociatedContentDelegate() =
default;

void ConversationHandler::AssociatedContentDelegate::OnNewPage(
int64_t navigation_id) {
pending_top_similarity_requests_.clear();
if (text_embedder_) {
text_embedder_->CancelAllTasks();
text_embedder_.reset();
}
}

void ConversationHandler::AssociatedContentDelegate::
GetTopSimilarityWithPromptTilContextLimit(
const std::string& prompt,
const std::string& text,
uint32_t context_limit,
TextEmbedder::TopSimilarityCallback callback) {
// Run immediately if already initialized
if (text_embedder_ && text_embedder_->IsInitialized()) {
text_embedder_->GetTopSimilarityWithPromptTilContextLimit(
prompt, text, context_limit, std::move(callback));
return;
}

// Will have to wait for initialization to complete, store params for calling
// later.
pending_top_similarity_requests_.emplace_back(prompt, text, context_limit,
std::move(callback));

// Start initialization if not already started
if (!text_embedder_) {
base::FilePath universal_qa_model_path =
LeoLocalModelsUpdaterState::GetInstance()->GetUniversalQAModel();
// Tasks in TextEmbedder are run on |embedder_task_runner|. The
// text_embedder_ must be deleted on that sequence to guarantee that pending
// tasks can safely be executed.
scoped_refptr<base::SequencedTaskRunner> embedder_task_runner =
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::USER_BLOCKING});
text_embedder_ = TextEmbedder::Create(
base::FilePath(universal_qa_model_path), embedder_task_runner);
if (!text_embedder_) {
auto& item = pending_top_similarity_requests_.back();
std::move(std::get<3>(item))
.Run(base::unexpected("Failed to create TextEmbedder"));
pending_top_similarity_requests_.pop_back();
return;
}
text_embedder_->Initialize(
base::BindOnce(&ConversationHandler::AssociatedContentDelegate::
OnTextEmbedderInitialized,
weak_ptr_factory_.GetWeakPtr()));
}
}

void ConversationHandler::AssociatedContentDelegate::OnTextEmbedderInitialized(
bool initialized) {
if (!initialized) {
VLOG(1) << "Failed to initialize TextEmbedder";
for (auto& callback_info : pending_top_similarity_requests_) {
std::move(std::get<3>(callback_info))
.Run(base::unexpected<std::string>(
"Failed to initialize TextEmbedder"));
}
pending_top_similarity_requests_.clear();
return;
}

CHECK(text_embedder_ && text_embedder_->IsInitialized());
for (auto& callback_info : pending_top_similarity_requests_) {
text_embedder_->GetTopSimilarityWithPromptTilContextLimit(
std::get<0>(callback_info), std::get<1>(callback_info),
std::get<2>(callback_info), std::move(std::get<3>(callback_info)));
}
pending_top_similarity_requests_.clear();
}

ConversationHandler::ConversationHandler(
const mojom::Conversation* conversation,
AIChatService* ai_chat_service,
Expand Down Expand Up @@ -840,6 +930,27 @@ void ConversationHandler::PerformAssistantGeneration(
auto data_completed_callback =
base::BindOnce(&ConversationHandler::OnEngineCompletionComplete,
weak_ptr_factory_.GetWeakPtr(), associated_content_uuid);

bool should_refine_page_content =
features::IsPageContentRefineEnabled() &&
page_content.length() > GetMaxContentLengthForModel(GetCurrentModel()) &&
input != l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_PAGE);
if (should_refine_page_content && associated_content_delegate_) {
DVLOG(2) << "Asking to refine content, which is of length: "
<< page_content.length();
associated_content_delegate_->GetTopSimilarityWithPromptTilContextLimit(
input, page_content, GetMaxContentLengthForModel(GetCurrentModel()),
base::BindOnce(&ConversationHandler::OnGetRefinedPageContent,
weak_ptr_factory_.GetWeakPtr(), input,
std::move(data_received_callback),
std::move(data_completed_callback), page_content,
is_video));
return;
} else if (!should_refine_page_content && is_content_refined_) {
is_content_refined_ = false;
OnAssociatedContentInfoChanged();
}

engine_->GenerateAssistantResponse(is_video, page_content, chat_history_,
input, std::move(data_received_callback),
std::move(data_completed_callback));
Expand Down Expand Up @@ -982,6 +1093,31 @@ void ConversationHandler::OnGeneratePageContentComplete(
OnAssociatedContentInfoChanged();
}

void ConversationHandler::OnGetRefinedPageContent(
const std::string& input,
EngineConsumer::GenerationDataCallback data_received_callback,
EngineConsumer::GenerationCompletedCallback data_completed_callback,
std::string page_content,
bool is_video,
base::expected<std::string, std::string> refined_page_content) {
std::string page_content_to_use = std::move(page_content);
if (refined_page_content.has_value()) {
page_content_to_use = std::move(refined_page_content.value());
is_content_refined_ = true;
OnAssociatedContentInfoChanged();
} else {
VLOG(1) << "Failed to get refined page content: "
<< refined_page_content.error();
if (is_content_refined_) {
is_content_refined_ = false;
OnAssociatedContentInfoChanged();
}
}
engine_->GenerateAssistantResponse(
is_video, page_content_to_use, chat_history_, input,
std::move(data_received_callback), std::move(data_completed_callback));
}

void ConversationHandler::OnEngineCompletionDataReceived(
int navigation_id,
mojom::ConversationEntryEventPtr result) {
Expand Down Expand Up @@ -1141,6 +1277,7 @@ void ConversationHandler::BuildAssociatedContentInfo() {
}
associated_content_info_->content_used_percentage =
GetContentUsedPercentage();
associated_content_info_->is_content_refined = is_content_refined_;
associated_content_info_->is_content_association_possible = true;
} else {
associated_content_info_->is_content_association_possible = false;
Expand Down
48 changes: 47 additions & 1 deletion components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

Expand All @@ -17,6 +18,7 @@
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/browser/model_service.h"
#include "brave/components/ai_chat/core/browser/text_embedder.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
Expand Down Expand Up @@ -54,6 +56,7 @@ class ConversationHandler : public mojom::ConversationHandler,
// Supplements a conversation with associated page content
class AssociatedContentDelegate {
public:
AssociatedContentDelegate();
virtual ~AssociatedContentDelegate();
virtual void AddRelatedConversation(ConversationHandler* conversation) {}
virtual void OnRelatedConversationDestroyed(
Expand All @@ -74,6 +77,38 @@ class ConversationHandler : public mojom::ConversationHandler,
// fetch for the content.
virtual std::string_view GetCachedTextContent() = 0;
virtual bool GetCachedIsVideo() = 0;

void GetTopSimilarityWithPromptTilContextLimit(
const std::string& prompt,
const std::string& text,
uint32_t context_limit,
TextEmbedder::TopSimilarityCallback callback);

void SetTextEmbedderForTesting(
std::unique_ptr<TextEmbedder, base::OnTaskRunnerDeleter>
text_embedder) {
text_embedder_ = std::move(text_embedder);
}
TextEmbedder* GetTextEmbedderForTesting() { return text_embedder_.get(); }

protected:
// Content has navigated
virtual void OnNewPage(int64_t navigation_id);

private:
void OnTextEmbedderInitialized(bool initialized);

// Owned by this class so that all associated conversation can benefit from
// a single cache as page content is unlikely to change between messages
// and conversations.
std::unique_ptr<TextEmbedder, base::OnTaskRunnerDeleter> text_embedder_;
std::vector<std::tuple<const std::string&,
const std::string&,
uint32_t,
TextEmbedder::TopSimilarityCallback>>
pending_top_similarity_requests_;

base::WeakPtrFactory<AssociatedContentDelegate> weak_ptr_factory_{this};
};

class Observer : public base::CheckedObserver {
Expand Down Expand Up @@ -194,6 +229,10 @@ class ConversationHandler : public mojom::ConversationHandler,
void OnModelRemoved(const std::string& removed_key) override;

private:
FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, LeoLocalModelsUpdater);
FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedder);
FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedderInitialized);

void InitEngine();
void BuildAssociatedContentInfo();
bool IsContentAssociationPossible();
Expand All @@ -218,7 +257,13 @@ class ConversationHandler : public mojom::ConversationHandler,
std::string contents_text,
bool is_video,
std::string invalidation_token);
void OnExistingGeneratePageContentComplete(GetPageContentCallback callback);
void OnGetRefinedPageContent(
const std::string& input,
EngineConsumer::GenerationDataCallback data_received_callback,
EngineConsumer::GenerationCompletedCallback data_completed_callback,
std::string page_content,
bool is_video,
base::expected<std::string, std::string> refined_page_content);
void OnEngineCompletionDataReceived(int associated_content_uuid,
mojom::ConversationEntryEventPtr result);
void OnEngineCompletionComplete(int associated_content_uuid,
Expand Down Expand Up @@ -274,6 +319,7 @@ class ConversationHandler : public mojom::ConversationHandler,
// but change AssociatedContentDelegate as the active Tab navigates to
// different pages.
bool should_send_page_contents_ = false;
bool is_content_refined_ = false;

bool is_print_preview_fallback_requested_ = false;

Expand Down

0 comments on commit 1e5ceda

Please sign in to comment.