Skip to content

Commit

Permalink
AIChat refactor to support standalone and persistent conversations (#…
Browse files Browse the repository at this point in the history
…24921)

* AI Chat: Introduce AIChatService, ConversationHandler, and direct bindings from UI

Modified AIChat WebUI to directly bind to both AIChatService and ConversationHandler for most operations.
Lays the groundwork for conversations to be independent of web content. In fact, most of this functionality is also within this PR. Conversation persistence (in-memory) is guarded behind a feature flag.

Refactor print preview extractor to be passed to AIChatTabHelper directly
Also removed max page content consideration since:
1) Model can be changed after fetching the content
2) Multiple callbacks could occur with different page content limits from different conversations
3) We need to know the (reasonable) total content length to report the percentage of content that has been sent to Leo.
I did consider sending the requested page content maximum length for each GetContent call, but that does not solve all the issues. Since there is a maximum limit of 20 print preview pages, it seems it's ok.

kAIChatHistory flag -> AIChatHistory

* ConversationHandler doesn't need to deal with navigation ID

* test fix

* AIChatTabHelper params instead of multiple tests. Always trim content.

* test and review feedback - comments, id->uuid, page-navigation-tests

* fix for android build

* fix same-document back/forward navigation by considering page title changes during navigation

* ios refactor

* Fix compiling on iOS. Fix Service registration crash.

* Fix crashes on iOS. Fix logic so AIChat on iOS works correctly. Fix models list conversion to iOS.

* fix ConversationHandler::GenerateQuestions, refactor non-conversation rewriting out of ConversationHandler, test ConversationHandler::GetState

* feedback

* fix AIChatRenderViewContextMenuBrowserTest

* don't wait for client connection before submitting human message

* AIChatService::MaybeAssociateContentWithConversation

* feedback

* android HandleVoiceRecognition now optionally passes ConversationId to target a specific conversation

* feedback

* fix ModelService migrating from chat-claude-instant default model pref value

* feedback

* format

* rebase fixes

* AIChatTabHelper refine and test retry logic

* fix android compile

* fix android again

* no channel_info new string

* associatedcontentdriver - remove is_page_text_fetch_in_progress_

* ConversationHandler::HasAnyHistory ignores staged entries

* AIChatService: erase from content_conversation map, and test it

* MaybeUnlink should check if client is connected

* fix android again?

* ChromeAutocompleteProviderClient should check AIChatService exists

---------

Co-authored-by: Brandon T <JustBrandonT@gmail.com>
  • Loading branch information
petemill and Brandon-T authored Sep 20, 2024
1 parent 36ceeea commit a0424cf
Show file tree
Hide file tree
Showing 125 changed files with 9,302 additions and 6,634 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ private static void openURL(String url) {
}

@CalledByNative
private static void handleVoiceRecognition(
WebContents chatWindowWebContents, WebContents contextWebContents) {
private static void handleVoiceRecognition(WebContents webContents, String conversation_uuid) {
new BraveLeoVoiceRecognitionHandler(
chatWindowWebContents.getTopLevelNativeWindow(), contextWebContents)
webContents.getTopLevelNativeWindow(), webContents, conversation_uuid)
.startVoiceRecognition();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ public static void verifySubscription(Callback callback) {
}

public static void openLeoQuery(
WebContents webContents, String query, boolean openLeoChatWindow) {
WebContents webContents,
String conversationUuid,
String query,
boolean openLeoChatWindow) {
try {
BraveLeoUtilsJni.get().openLeoQuery(webContents, query);
BraveLeoUtilsJni.get().openLeoQuery(webContents, conversationUuid, query);
if (openLeoChatWindow) {
BraveActivity activity = BraveActivity.getBraveActivity();
activity.openBraveLeo();
Expand Down Expand Up @@ -106,6 +109,6 @@ public static void bringMainActivityOnTop() {

@NativeMethods
public interface Natives {
void openLeoQuery(WebContents webContents, String query);
void openLeoQuery(WebContents webContents, String conversationUuid, String query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class BraveLeoVoiceRecognitionHandler {
private static final String TAG = "LeoVoiceRecognition";
private WindowAndroid mWindowAndroid;
private WebContents mContextWebContents;
private String mConversationUuid;

/** Callback for when we receive voice search results after initiating voice recognition. */
class VoiceRecognitionCompleteCallback implements WindowAndroid.IntentCallback {
Expand Down Expand Up @@ -66,7 +67,8 @@ private void handleTranscriptionResult(Intent data) {
if (TextUtils.isEmpty(topResultQuery)) {
return;
}
BraveLeoUtils.openLeoQuery(mContextWebContents, topResultQuery, false);
BraveLeoUtils.openLeoQuery(
mContextWebContents, mConversationUuid, topResultQuery, false);
}
}

Expand Down Expand Up @@ -96,9 +98,10 @@ public float getConfidence() {
}

public BraveLeoVoiceRecognitionHandler(
WindowAndroid windowAndroid, WebContents contextWebContents) {
WindowAndroid windowAndroid, WebContents contextWebContents, String conversationUuid) {
mWindowAndroid = windowAndroid;
mContextWebContents = contextWebContents;
mConversationUuid = conversationUuid;
}

private List<BraveLeoVoiceRecognitionHandler.VoiceResult> convertBundleToVoiceResults(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ public boolean isLeoEnabled() {
}

@Override
public void openLeoQuery(WebContents webContents, String query) {
public void openLeoQuery(WebContents webContents, String conversationUuid, String query) {
mDelegate.clearOmniboxFocus();
BraveLeoUtils.openLeoQuery(webContents, query, true);
BraveLeoUtils.openLeoQuery(webContents, conversationUuid, query, true);
}

@Override
Expand All @@ -166,7 +166,7 @@ void onVoiceResults(@Nullable List<VoiceRecognitionHandler.VoiceResult> voiceRes
// Remove the start word from the query and process it.
topResultQuery =
topResultQuery.substring(LEO_START_WORD_UPPER_CASE.length()).trim();
openLeoQuery(tab.getWebContents(), topResultQuery);
openLeoQuery(tab.getWebContents(), "", topResultQuery);

// Clear the voice results to prevent the query from being processed by Chromium
// since it's already handled by Leo.
Expand Down
1 change: 1 addition & 0 deletions browser/DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include_rules = [
"+media/webrtc", # For webrtc media switches.
"+mojo/public",
"+net",
"+printing/buildflags/buildflags.h",
"+sandbox/mac",
"+sandbox/policy",
"+services/audio/public",
Expand Down
18 changes: 10 additions & 8 deletions browser/ai_chat/ai_chat_browsertests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,21 @@ class AiChatBrowserTest : public InProcessBrowserTest {
std::string FetchPageContent() {
std::string content;
base::RunLoop run_loop;
ai_chat::FetchPageContent(
browser()->tab_strip_model()->GetActiveWebContents(), "",
base::BindLambdaForTesting(
[&run_loop, &content](std::string page_content, bool is_video,
std::string invalidation_token) {
content = std::move(page_content);
run_loop.Quit();
}));
page_content_fetcher_ = std::make_unique<PageContentFetcher>(
browser()->tab_strip_model()->GetActiveWebContents());
page_content_fetcher_->FetchPageContent(
"", base::BindLambdaForTesting(
[&run_loop, &content](std::string page_content, bool is_video,
std::string invalidation_token) {
content = std::move(page_content);
run_loop.Quit();
}));
run_loop.Run();
return content;
}

private:
std::unique_ptr<PageContentFetcher> page_content_fetcher_;
content::ContentMockCertVerifier mock_cert_verifier_;
net::EmbeddedTestServer https_server_{net::EmbeddedTestServer::TYPE_HTTPS};
};
Expand Down
2 changes: 1 addition & 1 deletion browser/ai_chat/ai_chat_metrics_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ IN_PROC_BROWSER_TEST_F(AIChatMetricsTest, ContextMenuActions) {
ai_chat_metrics_->RecordEnabled(
true, true,
base::BindLambdaForTesting(
[&](mojom::PageHandler::GetPremiumStatusCallback callback) {
[&](mojom::Service::GetPremiumStatusCallback callback) {
std::move(callback).Run(mojom::PremiumStatus::Active, nullptr);
}));
histogram_tester_.ExpectUniqueSample(kMostUsedContextMenuActionHistogramName,
Expand Down
125 changes: 62 additions & 63 deletions browser/ai_chat/ai_chat_render_view_context_menu_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@

#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "brave/app/brave_command_ids.h"
#include "brave/browser/ai_chat/ai_chat_service_factory.h"
#include "brave/browser/ui/brave_browser.h"
#include "brave/browser/ui/sidebar/sidebar_controller.h"
#include "brave/browser/ui/sidebar/sidebar_model.h"
#include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h"
#include "brave/components/ai_chat/core/browser/utils.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/constants/brave_paths.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu_browsertest_util.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu_test_util.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
Expand All @@ -44,40 +47,21 @@ using ::testing::_;

namespace ai_chat {

class MockEngineConsumer : public EngineConsumer {
public:
MOCK_METHOD(void,
GenerateQuestionSuggestions,
(const bool&, const std::string&, SuggestedQuestionsCallback),
(override));
MOCK_METHOD(void,
GenerateAssistantResponse,
(const bool&,
const std::string&,
const ConversationHistory&,
const std::string&,
GenerationDataCallback,
GenerationCompletedCallback),
(override));
MOCK_METHOD(void,
GenerateRewriteSuggestion,
(std::string,
const std::string&,
GenerationDataCallback,
GenerationCompletedCallback),
(override));
MOCK_METHOD(void, SanitizeInput, (std::string&), (override));
MOCK_METHOD(void, ClearAllQueries, (), (override));
MOCK_METHOD(void,
UpdateModelOptions,
(const mojom::ModelOptions&),
(override));
};
namespace {

void ExecuteRewriteCommand(RenderViewContextMenu* context_menu) {
// Calls EngineConsumer::GenerateRewriteSuggestion
context_menu->ExecuteCommand(IDC_AI_CHAT_CONTEXT_SHORTEN, 0);
context_menu->Cancel();
}

} // namespace

class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
public:
AIChatRenderViewContextMenuBrowserTest()
: https_server_(net::EmbeddedTestServer::TYPE_HTTPS) {}
: ai_engine_(std::make_unique<MockEngineConsumer>()),
https_server_(net::EmbeddedTestServer::TYPE_HTTPS) {}

~AIChatRenderViewContextMenuBrowserTest() override = default;

Expand Down Expand Up @@ -114,28 +98,13 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {

void TestRewriteInPlace(
content::WebContents* web_contents,
MockEngineConsumer* mock_engine,
const std::string& element_id,
const std::string& expected_selected_text,
const std::vector<std::string>& received_data,
base::expected<std::string, mojom::APIError> completed_result,
const std::string& expected_updated_text) {
base::RunLoop run_loop;
// Verify that rewrite is requested
EXPECT_CALL(*mock_engine, GenerateRewriteSuggestion(_, _, _, _))
.WillOnce([&](std::string text, const std::string& question,
EngineConsumer::GenerationDataCallback data_callback,
EngineConsumer::GenerationCompletedCallback callback) {
ASSERT_TRUE(callback);
ASSERT_TRUE(data_callback);
for (const auto& data : received_data) {
auto event = mojom::ConversationEntryEvent::NewCompletionEvent(
mojom::CompletionEvent::New(data));
data_callback.Run(std::move(event));
}
std::move(callback).Run(completed_result);
run_loop.Quit();
});
MockEngineConsumer* ai_engine;

// Select text in the element and create context menu to execute a rewrite
// command.
Expand All @@ -154,15 +123,42 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
base::StringPrintf("getRectY('%s')", element_id.c_str()))
.ExtractInt();

// Calls ConversationDriver::SubmitSelectedText
ContextMenuNotificationObserver context_menu_observer(
IDC_AI_CHAT_CONTEXT_SHORTEN);
RenderViewContextMenu::RegisterMenuShownCallbackForTesting(
base::BindLambdaForTesting([&](RenderViewContextMenu* context_menu) {
auto* brave_context_menu =
static_cast<BraveRenderViewContextMenu*>(context_menu);
brave_context_menu->SetAIEngineForTesting(
std::make_unique<MockEngineConsumer>());
ai_engine = static_cast<MockEngineConsumer*>(
brave_context_menu->GetAIEngineForTesting());
// Verify that rewrite is requested
EXPECT_CALL(*ai_engine, GenerateRewriteSuggestion(_, _, _, _))
.WillOnce(
[&](std::string text, const std::string& question,
EngineConsumer::GenerationDataCallback data_callback,
EngineConsumer::GenerationCompletedCallback callback) {
ASSERT_TRUE(callback);
ASSERT_TRUE(data_callback);
for (const auto& data : received_data) {
auto event =
mojom::ConversationEntryEvent::NewCompletionEvent(
mojom::CompletionEvent::New(data));
data_callback.Run(std::move(event));
}
std::move(callback).Run(completed_result);
run_loop.Quit();
});
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&ExecuteRewriteCommand, context_menu));
}));

web_contents->GetPrimaryMainFrame()
->GetRenderViewHost()
->GetWidget()
->ShowContextMenuAtPoint(gfx::Point(x, y), ui::MENU_SOURCE_MOUSE);
run_loop.Run();
testing::Mock::VerifyAndClearExpectations(mock_engine);
EXPECT_NE(ai_engine, nullptr);
testing::Mock::VerifyAndClearExpectations(ai_engine);

// Verify that the text is rewritten as expected.
std::string updated_text =
Expand All @@ -187,6 +183,7 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
}

private:
std::unique_ptr<MockEngineConsumer> ai_engine_;
content::ContentMockCertVerifier mock_cert_verifier_;
net::test_server::EmbeddedTestServer https_server_;
};
Expand All @@ -205,38 +202,40 @@ IN_PROC_BROWSER_TEST_F(AIChatRenderViewContextMenuBrowserTest, RewriteInPlace) {
ai_chat::AIChatTabHelper::FromWebContents(contents);
ASSERT_TRUE(helper);

helper->SetEngineForTesting(std::make_unique<MockEngineConsumer>());
auto* mock_engine =
static_cast<MockEngineConsumer*>(helper->GetEngineForTesting());
ConversationHandler* conversation_handler =
ai_chat::AIChatServiceFactory::GetInstance()
->GetForBrowserContext(browser()->profile())
->GetOrCreateConversationHandlerForContent(helper->GetContentId(),
helper->GetWeakPtr());
ASSERT_TRUE(conversation_handler);

// Test rewriting textarea value and verify that the response tag is ignored
// by BraveRenderViewContextMenu
TestRewriteInPlace(contents, mock_engine, "textarea", "I'm textarea.",
TestRewriteInPlace(contents, "textarea", "I'm textarea.",
{"O", "OK", "<", "</", "</r", "</re", "</response"}, "",
"OK");

// Do the same again to make sure it still works at the second time.
TestRewriteInPlace(contents, mock_engine, "textarea", "OK",
{"O", "OK", "OK2"}, "", "OK2");
TestRewriteInPlace(contents, "textarea", "OK", {"O", "OK", "OK2"}, "", "OK2");

// Select text in text input and create context menu to execute a rewrite cmd.
// Verify that the text is rewritten.
TestRewriteInPlace(contents, mock_engine, "input_text", "I'm input.",
{"O", "OK", "OK3"}, "", "OK3");
TestRewriteInPlace(contents, mock_engine, "contenteditable",
"I'm contenteditable.", {"O", "OK", "OK4"}, "", "OK4");
TestRewriteInPlace(contents, "input_text", "I'm input.", {"O", "OK", "OK3"},
"", "OK3");
TestRewriteInPlace(contents, "contenteditable", "I'm contenteditable.",
{"O", "OK", "OK4"}, "", "OK4");

// Error case handling tests and verify that the text is not rewritten.
// 1) Get error in completed callback immediately.
EXPECT_FALSE(IsAIChatSidebarActive());
TestRewriteInPlace(contents, mock_engine, "textarea", "OK2", {},
TestRewriteInPlace(contents, "textarea", "OK2", {},
base::unexpected(mojom::APIError::ConnectionIssue), "OK2");
EXPECT_TRUE(IsAIChatSidebarActive());
GetSidebarController()->DeactivateCurrentPanel();

EXPECT_FALSE(IsAIChatSidebarActive());
// 2) Get partial streaming responses then error in completed callback.
TestRewriteInPlace(contents, mock_engine, "textarea", "OK2", {"N", "O"},
TestRewriteInPlace(contents, "textarea", "OK2", {"N", "O"},
base::unexpected(mojom::APIError::ConnectionIssue), "OK2");
EXPECT_TRUE(IsAIChatSidebarActive());
}
Expand Down
Loading

0 comments on commit a0424cf

Please sign in to comment.