Skip to content

Commit

Permalink
fix change model for existing conversations when default changes
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Sep 11, 2024
1 parent f7fa4d5 commit e8ce0d2
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 6 deletions.
1 change: 1 addition & 0 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,7 @@ void ConversationHandler::OnDefaultModelChanged(const std::string& old_key,
const std::string& new_key) {
// When default model changes, change any conversation that
// has that model.
DVLOG(1) << "Default model changed from " << old_key << " to " << new_key;
if (model_key_ == old_key) {
ChangeModel(new_key);
}
Expand Down
1 change: 0 additions & 1 deletion components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ class ConversationHandler : public mojom::ConversationHandler,
raw_ptr<AIChatFeedbackAPI> feedback_api_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

// Temporary
base::ScopedObservation<ModelService, ModelService::Observer>
models_observer_{this};

Expand Down
6 changes: 3 additions & 3 deletions components/ai_chat/core/browser/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,17 @@ void ModelService::SetDefaultModelKey(const std::string& new_key) {
// Don't continue migrating if user choses another default in the meantime
is_migrating_claude_instant_ = false;

const std::string& current_default_key = GetDefaultModelKey();
const std::string previous_default_key = GetDefaultModelKey();

if (current_default_key == new_key) {
if (previous_default_key == new_key) {
// Nothing to do
return;
}

pref_service_->SetString(kDefaultModelKey, new_key);

for (auto& obs : observers_) {
obs.OnDefaultModelChanged(current_default_key, new_key);
obs.OnDefaultModelChanged(previous_default_key, new_key);
}
}

Expand Down
54 changes: 52 additions & 2 deletions components/ai_chat/core/browser/model_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,69 @@
#include <string>
#include <utility>

#include "base/scoped_observation.h"
#include "base/test/scoped_feature_list.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/ai_chat/core/common/pref_names.h"
#include "components/os_crypt/sync/os_crypt_mocker.h"
#include "components/prefs/testing_pref_service.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ai_chat {

namespace {
using ::testing::_;
using ::testing::NiceMock;

class MockModelServiceObserver : public ModelService::Observer {
public:
MockModelServiceObserver() = default;
~MockModelServiceObserver() override = default;

void Observe(ModelService* model_service) {
models_observer_.Observe(model_service);
}

MOCK_METHOD(void,
OnDefaultModelChanged,
(const std::string&, const std::string&),
(override));

private:
base::ScopedObservation<ModelService, ModelService::Observer>
models_observer_{this};
};

} // namespace

class ModelServiceTest : public ::testing::Test {
public:
void SetUp() override {
OSCryptMocker::SetUp();
prefs::RegisterProfilePrefs(pref_service_.registry());
prefs::RegisterProfilePrefsForMigration(pref_service_.registry());
ModelService::RegisterProfilePrefs(pref_service_.registry());
observer_ = std::make_unique<NiceMock<MockModelServiceObserver>>();
}

ModelService* GetService() {
if (!service_) {
service_ = std::make_unique<ModelService>(&pref_service_);
observer_->Observe(service_.get());
}
return service_.get();
}

void TearDown() override { OSCryptMocker::TearDown(); }
void TearDown() override {
OSCryptMocker::TearDown();
observer_.reset();
}

protected:
TestingPrefServiceSimple pref_service_;
std::unique_ptr<NiceMock<MockModelServiceObserver>> observer_;

private:
std::unique_ptr<ModelService> service_;
Expand Down Expand Up @@ -116,7 +149,11 @@ TEST_F(ModelServiceTestWithDifferentPremiumModel,
TEST_F(ModelServiceTestWithDifferentPremiumModel,
MigrateToPremiumDefaultModel_UserModified) {
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-leo-expanded");
EXPECT_CALL(*observer_,
OnDefaultModelChanged("chat-leo-expanded", "chat-basic"))
.Times(1);
GetService()->SetDefaultModelKey("chat-basic");
testing::Mock::VerifyAndClearExpectations(observer_.get());
GetService()->OnPremiumStatus(mojom::PremiumStatus::Active);
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-basic");
}
Expand Down Expand Up @@ -162,12 +199,25 @@ TEST_F(ModelServiceTest, AddAndModifyCustomModel) {
kAPIKey);
}

TEST_F(ModelServiceTest, ChangeDefaultModelKey) {
TEST_F(ModelServiceTest, ChangeDefaultModelKey_GoodKey) {
GetService()->SetDefaultModelKey("chat-basic");
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-basic");
EXPECT_CALL(*observer_,
OnDefaultModelChanged("chat-basic", "chat-leo-expanded"))
.Times(1);
GetService()->SetDefaultModelKey("chat-leo-expanded");
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-leo-expanded");
testing::Mock::VerifyAndClearExpectations(observer_.get());
}

TEST_F(ModelServiceTest, ChangeDefaultModelKey_IncorrectKey) {
GetService()->SetDefaultModelKey("chat-basic");
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-basic");
EXPECT_CALL(*observer_, OnDefaultModelChanged(_, _)).Times(0);
GetService()->SetDefaultModelKey("bad-key");
// Default model key should not change if the key is invalid.
EXPECT_EQ(GetService()->GetDefaultModelKey(), "chat-basic");
testing::Mock::VerifyAndClearExpectations(observer_.get());
}

} // namespace ai_chat

0 comments on commit e8ce0d2

Please sign in to comment.