Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: support set rag only #7057

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,29 @@ class ChatBloc extends Bloc<ChatEvent, ChatState> {
},
didReceiveChatSettings: (settings) {
emit(
state.copyWith(selectedSourceIds: settings.ragIds),
state.copyWith(
selectedSourceIds: settings.ragIds.ragIds,
onlyUseSelectedSources: settings.ragOnly,
),
);
},
updateSelectedSources: (selectedSourcesIds) async {
emit(state.copyWith(selectedSourceIds: selectedSourcesIds));
final ragIds = RepeatedRagId(ragIds: selectedSourcesIds);
final payload = UpdateChatSettingsPB.create()
..chatId = ChatId(value: chatId)
..ragIds = ragIds;

await AIEventUpdateChatSettings(payload)
.send()
.onFailure(Log.error);
},
setRagOnly: (ragOnly) async {
emit(state.copyWith(onlyUseSelectedSources: ragOnly));
final payload = UpdateChatSettingsPB.create()
..chatId = ChatId(value: chatId)
..ragOnly = ragOnly;

final payload = UpdateChatSettingsPB(
chatId: ChatId(value: chatId),
ragIds: selectedSourcesIds,
);
await AIEventUpdateChatSettings(payload)
.send()
.onFailure(Log.error);
Expand Down Expand Up @@ -575,6 +588,10 @@ class ChatEvent with _$ChatEvent {
required List<String> selectedSourcesIds,
}) = _UpdateSelectedSources;

const factory ChatEvent.setRagOnly({
required bool ragOnly,
}) = _SetRagOnly;

// send message
const factory ChatEvent.sendMessage({
required String message,
Expand Down Expand Up @@ -611,12 +628,14 @@ class ChatEvent with _$ChatEvent {
@freezed
class ChatState with _$ChatState {
const factory ChatState({
required bool onlyUseSelectedSources,
required List<String> selectedSourceIds,
required LoadChatMessageStatus loadingState,
required PromptResponseState promptResponseState,
}) = _ChatState;

factory ChatState.initial() => const ChatState(
onlyUseSelectedSources: false,
selectedSourceIds: [],
loadingState: LoadChatMessageStatus.loading,
promptResponseState: PromptResponseState.ready,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class ChatSettingsCubit extends Cubit<ChatSettingsState> {
List<ChatSource> selectedSources = [];
String filter = '';

void updateOnlyUseSelectedSources(bool onlyUseSelectedSources) {
if (state.onlyUseSelectedSources != onlyUseSelectedSources) {
emit(state.copyWith(onlyUseSelectedSources: onlyUseSelectedSources));
}
}

void updateSelectedSources(List<String> newSelectedSourceIds) {
selectedSourceIds = [...newSelectedSourceIds];
}
Expand Down Expand Up @@ -380,11 +386,13 @@ class ChatSettingsCubit extends Cubit<ChatSettingsState> {
@freezed
class ChatSettingsState with _$ChatSettingsState {
const factory ChatSettingsState({
required bool onlyUseSelectedSources,
required List<ChatSource> visibleSources,
required List<ChatSource> selectedSources,
}) = _ChatSettingsState;

factory ChatSettingsState.initial() => const ChatSettingsState(
onlyUseSelectedSources: false,
visibleSources: [],
selectedSources: [],
);
Expand Down
22 changes: 16 additions & 6 deletions frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import 'package:appflowy_result/appflowy_result.dart';
import 'package:desktop_drop/desktop_drop.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:flutter_chat_core/flutter_chat_core.dart';
Expand Down Expand Up @@ -296,12 +297,21 @@ class _ChatContentPage extends StatelessWidget {
),
);
},
onUpdateSelectedSources: (ids) {
chatBloc.add(
ChatEvent.updateSelectedSources(
selectedSourcesIds: ids,
),
);
onUpdateSelectedSources: (ragOnly, ids) {
if (ragOnly != chatBloc.state.onlyUseSelectedSources) {
chatBloc.add(
ChatEvent.setRagOnly(
ragOnly: ragOnly,
),
);
}
if (!listEquals(ids, chatBloc.state.selectedSourceIds)) {
chatBloc.add(
ChatEvent.updateSelectedSources(
selectedSourcesIds: ids,
),
);
}
},
)
: MobileAIPromptInput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DesktopAIPromptInput extends StatefulWidget {
final bool isStreaming;
final void Function() onStopStreaming;
final void Function(String, Map<String, dynamic>) onSubmitted;
final void Function(List<String>) onUpdateSelectedSources;
final void Function(bool, List<String>) onUpdateSelectedSources;

@override
State<DesktopAIPromptInput> createState() => _DesktopAIPromptInputState();
Expand Down Expand Up @@ -497,7 +497,7 @@ class _PromptBottomActions extends StatelessWidget {
final SendButtonState sendButtonState;
final void Function() onSendPressed;
final void Function() onStopStreaming;
final void Function(List<String>) onUpdateSelectedSources;
final void Function(bool, List<String>) onUpdateSelectedSources;

@override
Widget build(BuildContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PromptInputDesktopSelectSourcesButton extends StatefulWidget {
required this.onUpdateSelectedSources,
});

final void Function(List<String>) onUpdateSelectedSources;
final void Function(bool, List<String>) onUpdateSelectedSources;

@override
State<PromptInputDesktopSelectSourcesButton> createState() =>
Expand All @@ -43,9 +43,10 @@ class _PromptInputDesktopSelectSourcesButtonState
void initState() {
super.initState();
WidgetsBinding.instance.addPostFrameCallback((_) {
cubit.updateSelectedSources(
context.read<ChatBloc>().state.selectedSourceIds,
);
final chatBlocState = context.read<ChatBloc>().state;
cubit
..updateSelectedSources(chatBlocState.selectedSourceIds)
..updateOnlyUseSelectedSources(chatBlocState.onlyUseSelectedSources);
});
}

Expand Down Expand Up @@ -80,6 +81,7 @@ class _PromptInputDesktopSelectSourcesButtonState
return BlocListener<ChatBloc, ChatState>(
listener: (context, state) {
cubit
..updateOnlyUseSelectedSources(state.onlyUseSelectedSources)
..updateSelectedSources(state.selectedSourceIds)
..updateSelectedStatus();
},
Expand All @@ -95,7 +97,10 @@ class _PromptInputDesktopSelectSourcesButtonState
}
},
onClose: () {
widget.onUpdateSelectedSources(cubit.selectedSourceIds);
widget.onUpdateSelectedSources(
cubit.state.onlyUseSelectedSources,
cubit.selectedSourceIds,
);
if (spaceView != null) {
context.read<ChatSettingsCubit>().refreshSources(spaceView);
}
Expand Down Expand Up @@ -188,6 +193,29 @@ class _PopoverContent extends StatelessWidget {
context.read<ChatSettingsCubit>().updateFilter(value),
),
),
Container(
margin: const EdgeInsets.fromLTRB(8, 0, 8, 8),
height: 30,
child: FlowyButton(
text: FlowyText(
LocaleKeys.chat_onlyUseRags.tr(),
overflow: TextOverflow.ellipsis,
),
onTap: () {
context
.read<ChatSettingsCubit>()
.updateOnlyUseSelectedSources(
!state.onlyUseSelectedSources,
);
},
rightIcon: state.onlyUseSelectedSources
? FlowySvg(
FlowySvgs.check_s,
color: Theme.of(context).colorScheme.primary,
)
: null,
),
),
_buildDivider(),
Flexible(
child: ListView(
Expand Down
3 changes: 2 additions & 1 deletion frontend/resources/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@
"addToPageButton": "Add to page",
"addToPageTitle": "Add message to...",
"addToNewPage": "Add to a new page",
"addToNewPageName": "Messages extracted from \"{}\""
"addToNewPageName": "Messages extracted from \"{}\"",
"onlyUseRags": "Selected sources only"
},
"trash": {
"text": "Trash",
Expand Down
64 changes: 36 additions & 28 deletions frontend/rust-lib/flowy-ai/src/ai_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::notification::{chat_notification_builder, ChatNotification};
use flowy_storage_pub::storage::StorageService;
use lib_infra::async_trait::async_trait;
use lib_infra::util::timestamp;
use serde_json::json;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use tracing::{error, info, trace};
Expand Down Expand Up @@ -344,12 +345,12 @@ impl AIManager {
Ok(())
}

pub async fn get_rag_ids(&self, chat_id: &str) -> FlowyResult<Vec<String>> {
pub async fn get_chat_settings(&self, chat_id: &str) -> FlowyResult<ChatSettingsPB> {
if let Some(settings) = self
.store_preferences
.get_object::<ChatSettings>(&setting_store_key(chat_id))
{
return Ok(settings.rag_ids);
return Ok(settings.into());
}

let settings = refresh_chat_setting(
Expand All @@ -359,49 +360,58 @@ impl AIManager {
chat_id,
)
.await?;
Ok(settings.rag_ids)
Ok(settings.into())
}

pub async fn update_rag_ids(&self, chat_id: &str, rag_ids: Vec<String>) -> FlowyResult<()> {
pub async fn update_settings(
&self,
chat_id: &str,
rag_ids: Option<Vec<String>>,
rag_only: Option<bool>,
) -> FlowyResult<()> {
info!("[Chat] update chat:{} rag ids: {:?}", chat_id, rag_ids);

let workspace_id = self.user_service.workspace_id()?;
let update_setting = UpdateChatParams {
name: None,
metadata: None,
rag_ids: Some(rag_ids.clone()),
metadata: rag_only.map(|rag_only| json!({"rag_only": rag_only})),
rag_ids: rag_ids.clone(),
};

self
.cloud_service_wm
.update_chat_settings(&workspace_id, chat_id, update_setting)
.await?;

let chat_setting_store_key = setting_store_key(chat_id);

if let Some(settings) = self
if let Some(mut settings) = self
.store_preferences
.get_object::<ChatSettings>(&chat_setting_store_key)
{
if let Err(err) = self.store_preferences.set_object(
&chat_setting_store_key,
&ChatSettings {
rag_ids: rag_ids.clone(),
..settings
},
) {
if let Some(rag_only) = rag_only {
settings.metadata = json!({"rag_only": rag_only});
}

if let Some(rag_ids) = rag_ids {
settings.rag_ids = rag_ids.clone();
let user_service = self.user_service.clone();
let external_service = self.external_service.clone();
tokio::spawn(async move {
if let Ok(workspace_id) = user_service.workspace_id() {
let _ = external_service
.sync_rag_documents(&workspace_id, rag_ids)
.await;
}
});
}

if let Err(err) = self
.store_preferences
.set_object(&chat_setting_store_key, &settings)
{
error!("failed to set chat settings: {}", err);
}
}

let user_service = self.user_service.clone();
let external_service = self.external_service.clone();
tokio::spawn(async move {
if let Ok(workspace_id) = user_service.workspace_id() {
let _ = external_service
.sync_rag_documents(&workspace_id, rag_ids)
.await;
}
});
Ok(())
}
}
Expand Down Expand Up @@ -438,9 +448,7 @@ async fn refresh_chat_setting(
}

chat_notification_builder(chat_id, ChatNotification::DidUpdateChatSettings)
.payload(ChatSettingsPB {
rag_ids: settings.rag_ids.clone(),
})
.payload(ChatSettingsPB::from(settings.clone()))
.send();

Ok(settings)
Expand Down
36 changes: 33 additions & 3 deletions frontend/rust-lib/flowy-ai/src/entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::collections::HashMap;

use crate::local_ai::local_llm_resource::PendingResource;
use flowy_ai_pub::cloud::{
ChatMessage, LLMModel, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
ChatMessage, ChatSettings, LLMModel, RelatedQuestion, RepeatedChatMessage,
RepeatedRelatedQuestion,
};
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use lib_infra::validator_fn::required_not_empty_str;
Expand Down Expand Up @@ -542,7 +543,27 @@ pub struct CreateChatContextPB {
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct ChatSettingsPB {
#[pb(index = 1)]
pub rag_ids: Vec<String>,
pub rag_ids: RepeatedRagId,

#[pb(index = 2)]
pub rag_only: bool,
}

impl From<ChatSettings> for ChatSettingsPB {
fn from(value: ChatSettings) -> Self {
let rag_ids = RepeatedRagId {
rag_ids: value.rag_ids.clone(),
};

let rag_only = value
.metadata
.as_object()
.and_then(|map| map.get("rag_only"))
.and_then(|value| value.as_bool())
.unwrap_or_default();

Self { rag_ids, rag_only }
}
}

#[derive(Default, ProtoBuf, Clone, Debug, Validate)]
Expand All @@ -551,6 +572,15 @@ pub struct UpdateChatSettingsPB {
#[validate(nested)]
pub chat_id: ChatId,

#[pb(index = 2)]
#[pb(index = 2, one_of)]
pub rag_ids: Option<RepeatedRagId>,

#[pb(index = 3, one_of)]
pub rag_only: Option<bool>,
}

#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct RepeatedRagId {
#[pb(index = 1)]
pub rag_ids: Vec<String>,
}
Loading
Loading