From 20f3116bd99bb6a92f258f86b3716454e05dbb45 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Fri, 2 May 2025 15:37:40 +0800 Subject: [PATCH] refactor: Model select (#7875) * refactor: model select * refactor: add test * fix: add source * fix: add source * chore: notify all unset source * chore: fix test --- .../settings/ai/settings_ai_bloc.dart | 5 +- .../setting_ai_view/model_selection.dart | 4 +- frontend/rust-lib/flowy-ai-pub/src/cloud.rs | 16 + frontend/rust-lib/flowy-ai/src/ai_manager.rs | 414 ++++++--------- .../rust-lib/flowy-ai/src/event_handler.rs | 3 +- frontend/rust-lib/flowy-ai/src/lib.rs | 4 +- .../flowy-ai/src/local_ai/controller.rs | 8 +- .../src/middleware/chat_service_mw.rs | 11 + .../rust-lib/flowy-ai/src/model_select.rs | 471 ++++++++++++++++++ .../flowy-ai/src/model_select_test.rs | 434 ++++++++++++++++ .../src/offline/offline_message_sync.rs | 11 + frontend/rust-lib/flowy-ai/src/util.rs | 3 - .../rust-lib/flowy-core/src/app_life_cycle.rs | 37 +- .../src/deps_resolve/cloud_service_impl.rs | 12 + .../rust-lib/flowy-core/src/server_layer.rs | 18 +- .../flowy-server/src/af_cloud/impls/chat.rs | 18 +- .../src/local_server/impls/chat.rs | 9 + 17 files changed, 1170 insertions(+), 308 deletions(-) create mode 100644 frontend/rust-lib/flowy-ai/src/model_select.rs create mode 100644 frontend/rust-lib/flowy-ai/src/model_select_test.rs delete mode 100644 frontend/rust-lib/flowy-ai/src/util.rs diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart index 9409494b57..4034463ae5 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart @@ -11,7 +11,7 @@ import 'package:freezed_annotation/freezed_annotation.dart'; part 'settings_ai_bloc.freezed.dart'; -const String aiModelsGlobalActiveModel = "ai_models_global_active_model"; +const String aiModelsGlobalActiveModel = "global_active_model"; class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> { SettingsAIBloc( @@ -75,9 +75,6 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> { ); }, selectModel: (AIModelPB model) async { - if (!model.isLocal) { - await _updateUserWorkspaceSetting(model: model.name); - } await AIEventUpdateSelectedModel( UpdateSelectedModelPB( source: aiModelsGlobalActiveModel, diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart index 7357c2951c..a0ba9bcb34 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart @@ -17,8 +17,6 @@ class AIModelSelection extends StatelessWidget { @override Widget build(BuildContext context) { return BlocBuilder<SettingsAIBloc, SettingsAIState>( - buildWhen: (previous, current) => - previous.availableModels != current.availableModels, builder: (context, state) { final models = state.availableModels?.models; if (models == null) { @@ -44,7 +42,7 @@ class AIModelSelection extends StatelessWidget { ), Flexible( child: SettingsDropdown<AIModelPB>( - key: const Key('_AIModelSelection'), + key: ValueKey(selectedModel.name), onChanged: (model) => context .read<SettingsAIBloc>() .add(SettingsAIEvent.selectModel(model)), diff --git
a/frontend/rust-lib/flowy-ai-pub/src/cloud.rs b/frontend/rust-lib/flowy-ai-pub/src/cloud.rs index 2292e0f332..d8bfa19a24 100644 --- a/frontend/rust-lib/flowy-ai-pub/src/cloud.rs +++ b/frontend/rust-lib/flowy-ai-pub/src/cloud.rs @@ -33,6 +33,17 @@ pub struct AIModel { pub desc: String, } +impl AIModel { + /// Create a new model instance + pub fn new(name: impl Into<String>, description: impl Into<String>, is_local: bool) -> Self { + Self { + name: name.into(), + desc: description.into(), + is_local, + } + } +} + impl From<AvailableModel> for AIModel { fn from(value: AvailableModel) -> Self { let desc = value @@ -175,4 +186,9 @@ pub trait ChatCloudService: Send + Sync + 'static { async fn get_available_models(&self, workspace_id: &Uuid) -> Result<ModelList, FlowyError>; async fn get_workspace_default_model(&self, workspace_id: &Uuid) -> Result<String, FlowyError>; + async fn set_workspace_default_model( + &self, + workspace_id: &Uuid, + model: &str, + ) -> Result<(), FlowyError>; } diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index 3d59a92b6a..056dfeacf4 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -9,27 +9,26 @@ use flowy_ai_pub::persistence::read_chat_metadata; use std::collections::HashMap; use dashmap::DashMap; -use flowy_ai_pub::cloud::{ - AIModel, ChatCloudService, ChatSettings, UpdateChatParams, DEFAULT_AI_MODEL_NAME, -}; +use flowy_ai_pub::cloud::{AIModel, ChatCloudService, ChatSettings, UpdateChatParams}; use flowy_error::{ErrorCode, FlowyError, FlowyResult}; use flowy_sqlite::kv::KVStorePreferences; +use crate::model_select::{ + LocalAiSource, LocalModelStorageImpl, ModelSelectionControl, ServerAiSource, + ServerModelStorageImpl, SourceKey, GLOBAL_ACTIVE_MODEL_KEY, +}; use crate::notification::{chat_notification_builder, ChatNotification}; -use crate::util::ai_available_models_key; -use flowy_ai_pub::cloud::ai_dto::AvailableModel; use flowy_ai_pub::persistence::{ batch_insert_collab_metadata, batch_select_collab_metadata, AFCollabMetadata, }; use flowy_ai_pub::user_service::AIUserService; use flowy_storage_pub::storage::StorageService; use lib_infra::async_trait::async_trait; -use lib_infra::util::timestamp; use serde_json::json; use std::path::PathBuf; use std::str::FromStr; use std::sync::{Arc, Weak}; -use tokio::sync::RwLock; +use tokio::sync::Mutex; use tracing::{error, info, instrument, trace}; use uuid::Uuid; @@ -52,14 +51,6 @@ pub trait AIExternalService: Send + Sync + 'static { async fn notify_did_send_message(&self, chat_id: &Uuid, message: &str) -> Result<(), FlowyError>; } -#[derive(Debug, Default)] -struct ServerModelsCache { - models: Vec<AvailableModel>, - timestamp: Option<i64>, -} - -pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model"; - pub struct AIManager { pub cloud_service_wm: Arc<ChatServiceMiddleware>, pub user_service: Arc<dyn AIUserService>, @@ -67,7 +58,7 @@ pub struct AIManager { chats: Arc<DashMap<Uuid, Arc<Chat>>>, pub local_ai: Arc<LocalAIController>, pub store_preferences: Arc<KVStorePreferences>, - server_models: Arc<RwLock<ServerModelsCache>>, + model_control: Mutex<ModelSelectionControl>, } impl Drop for AIManager { fn drop(&mut self) { @@ -97,6 +88,10 @@ impl AIManager { local_ai.clone(), storage_service, )); + let mut model_control = ModelSelectionControl::new(); + model_control.set_local_storage(LocalModelStorageImpl(store_preferences.clone())); + model_control.set_server_storage(ServerModelStorageImpl(cloud_service_wm.clone())); + model_control.add_source(Box::new(ServerAiSource::new(cloud_service_wm.clone()))); Self { cloud_service_wm, @@ -105,7 +100,7 @@ impl AIManager { local_ai, external_service, store_preferences, - server_models:
Arc::new(Default::default()), + model_control: Mutex::new(model_control), } } @@ -134,10 +129,6 @@ impl AIManager { info!("[AI Manager] Local AI is running but not enabled, shutting it down"); let local_ai = self.local_ai.clone(); tokio::spawn(async move { - // Wait for 5 seconds to allow other services to initialize - // TODO: pick a right time to start plugin service. Maybe [UserStatusCallback::did_launch] - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - if let Err(err) = local_ai.toggle_plugin(false).await { error!("[AI Manager] failed to shutdown local AI: {:?}", err); } @@ -150,10 +141,6 @@ impl AIManager { info!("[AI Manager] Local AI is enabled but not running, starting it now"); let local_ai = self.local_ai.clone(); tokio::spawn(async move { - // Wait for 5 seconds to allow other services to initialize - // TODO: pick a right time to start plugin service. Maybe [UserStatusCallback::did_launch] - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - if let Err(err) = local_ai.toggle_plugin(true).await { error!("[AI Manager] failed to start local AI: {:?}", err); } @@ -167,19 +154,41 @@ impl AIManager { } } + async fn prepare_local_ai(&self, workspace_id: &Uuid) { + self + .local_ai + .reload_ollama_client(&workspace_id.to_string()); + self + .model_control + .lock() + .await + .add_source(Box::new(LocalAiSource::new(self.local_ai.clone()))); + } + #[instrument(skip_all, err)] pub async fn on_launch_if_authenticated(&self, workspace_id: &Uuid) -> Result<(), FlowyError> { + let is_enabled = self + .local_ai + .is_enabled_on_workspace(&workspace_id.to_string()); + + info!("local ai is enabled: {}", is_enabled); + if is_enabled { + self.prepare_local_ai(workspace_id).await; + } else { + self.model_control.lock().await.remove_local_source(); + } + self.reload_with_workspace_id(workspace_id).await; Ok(()) } pub async fn initialize_after_sign_in(&self, workspace_id: &Uuid) -> Result<(), FlowyError> { - self.reload_with_workspace_id(workspace_id).await; + self.on_launch_if_authenticated(workspace_id).await?; Ok(()) } pub async fn initialize_after_sign_up(&self, workspace_id: &Uuid) -> Result<(), FlowyError> { - self.reload_with_workspace_id(workspace_id).await; + self.on_launch_if_authenticated(workspace_id).await?; Ok(()) } @@ -188,7 +197,7 @@ impl AIManager { &self, workspace_id: &Uuid, ) -> Result<(), FlowyError> { - self.reload_with_workspace_id(workspace_id).await; + self.on_launch_if_authenticated(workspace_id).await?; Ok(()) } @@ -309,7 +318,7 @@ impl AIManager { ) -> Result<ChatMessagePB, FlowyError> { let chat = self.get_or_create_chat_instance(&params.chat_id).await?; let ai_model = self.get_active_model(&params.chat_id.to_string()).await; - let question = chat.stream_chat_message(&params, ai_model).await?; + let question = chat.stream_chat_message(&params, Some(ai_model)).await?; let _ = self .external_service .notify_did_send_message(&params.chat_id, &params.message) .await; @@ -332,15 +341,16 @@ impl AIManager { let model = match model { None => self.get_active_model(&chat_id.to_string()).await, - Some(model) => Some(model.into()), + Some(model) => model.into(), }; chat - .stream_regenerate_response(question_message_id, answer_stream_port, format, model) + .stream_regenerate_response(question_message_id, answer_stream_port, format, Some(model)) .await?; Ok(()) } pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> { + let workspace_id = self.user_service.workspace_id()?; let old_settings = self.local_ai.get_local_ai_setting(); // Only restart if the server URL has changed and
local AI is not running let need_restart = old_settings.ollama_server_url != setting.ollama_server_url; self .local_ai .update_local_ai_setting(setting.clone()) .await?; // Handle model change if needed - let model_changed = old_settings.chat_model_name != setting.chat_model_name; - if model_changed { - info!( - "[AI Plugin] update global active model, previous: {}, current: {}", - old_settings.chat_model_name, setting.chat_model_name - ); - let model = AIModel::local(setting.chat_model_name, "".to_string()); - self - .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) - .await?; - } + info!( + "[AI Plugin] update global active model, previous: {}, current: {}", + old_settings.chat_model_name, setting.chat_model_name + ); + let model = AIModel::local(setting.chat_model_name, "".to_string()); + self + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) + .await?; if need_restart { - self.local_ai.reload_ollama_client(); + self + .local_ai + .reload_ollama_client(&workspace_id.to_string()); self.local_ai.restart_plugin().await; } Ok(()) } - async fn get_workspace_select_model(&self) -> FlowyResult<String> { + #[instrument(skip_all, level = "debug")] + pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> { let workspace_id = self.user_service.workspace_id()?; - let model = self - .cloud_service_wm - .get_workspace_default_model(&workspace_id) + let source_key = SourceKey::new(source.clone()); + self + .model_control + .lock() + .await + .set_active_model(&workspace_id, &source_key, model.clone()) .await?; - if model.is_empty() { - return Ok(DEFAULT_AI_MODEL_NAME.to_string()); - } - Ok(model) - } - - async fn get_server_available_models(&self) -> FlowyResult<Vec<AvailableModel>> { - let workspace_id = self.user_service.workspace_id()?; - let now = timestamp(); - - // First, try reading from the cache with expiration check - let should_fetch = { - let cached_models = self.server_models.read().await; - cached_models.models.is_empty() || cached_models.timestamp.map_or(true, |ts| now - ts >= 300) - }; - - if !should_fetch { - // Cache is still valid, return cached data - let cached_models = self.server_models.read().await; - return Ok(cached_models.models.clone()); - } - - // Cache miss or expired: fetch from the cloud.
- match self - .cloud_service_wm - .get_available_models(&workspace_id) - .await - { - Ok(list) => { - let models = list.models; - if let Err(err) = self.update_models_cache(&models, now).await { - error!("Failed to update models cache: {}", err); - } - - Ok(models) - }, - Err(err) => { - error!("Failed to fetch available models: {}", err); - - // Return cached data if available, even if expired - let cached_models = self.server_models.read().await; - if !cached_models.models.is_empty() { - info!("Returning expired cached models due to fetch failure"); - return Ok(cached_models.models.clone()); - } - - // If no cached data, return empty list - Ok(Vec::new()) - }, - } - } - - async fn update_models_cache( - &self, - models: &[AvailableModel], - timestamp: i64, - ) -> FlowyResult<()> { - match self.server_models.try_write() { - Ok(mut cache) => { - cache.models = models.to_vec(); - cache.timestamp = Some(timestamp); - Ok(()) - }, - Err(_) => { - // Handle lock acquisition failure - Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache")) - }, - } - } - - pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> { - let source_key = ai_available_models_key(&source); info!( - "[Model Selection] update {} selected model: {:?} for key:{}", - source, model, source_key + "[Model Selection] selected model: {:?} for key:{}", + model, + source_key.storage_id() ); - self - .store_preferences - .set_object::<AIModel>(&source_key, &model)?; - chat_notification_builder(&source_key, ChatNotification::DidUpdateSelectedModel) - .payload(AIModelPB::from(model)) - .send(); + let mut notify_source = vec![source.clone()]; + if source == GLOBAL_ACTIVE_MODEL_KEY { + let ids = self + .model_control + .lock() + .await + .get_all_unset_sources() + .await; + info!("[Model Selection] notify all unset sources: {:?}", ids); + notify_source.extend(ids); + } + + trace!("[Model Selection] notify sources: {:?}", notify_source); + for source in notify_source { + chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel) + .payload(AIModelPB::from(model.clone())) + .send(); + } + Ok(()) } - #[instrument(skip_all, level = "debug")] + #[instrument(skip_all, level = "debug", err)] pub async fn toggle_local_ai(&self) -> FlowyResult<()> { let enabled = self.local_ai.toggle_local_ai().await?; + let workspace_id = self.user_service.workspace_id()?; if enabled { + self.prepare_local_ai(&workspace_id).await; + if let Some(name) = self.local_ai.get_plugin_chat_model() { - info!("Set global active model to local ai: {}", name); let model = AIModel::local(name, "".to_string()); - self + info!( + "[Model Selection] Set global active model to local ai: {}", + model.name + ); + if let Err(err) = self .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) - .await?; + .await + { + error!( + "[Model Selection] Failed to set global active model: {}", + err + ); + } } } else { - info!("Set global active model to default"); - let global_active_model = self.get_workspace_select_model().await?; - let models = self.get_server_available_models().await?; - if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) { - self - .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model)) - .await?; + let mut model_control = self.model_control.lock().await; + model_control.remove_local_source(); + + let model = model_control.get_global_active_model(&workspace_id).await; + let mut notify_source =
model_control.get_all_unset_sources().await; + notify_source.push(GLOBAL_ACTIVE_MODEL_KEY.to_string()); + drop(model_control); + + trace!( + "[Model Selection] notify sources: {:?}, model:{}, when disable local ai", + notify_source, + model.name + ); + for source in notify_source { + chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel) + .payload(AIModelPB::from(model.clone())) + .send(); } } Ok(()) } - pub async fn get_active_model(&self, source: &str) -> Option<AIModel> { - let mut model = self - .store_preferences - .get_object::<AIModel>(&ai_available_models_key(source)); - - match model { - None => { - model = self - .store_preferences - .get_object::<AIModel>(&ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY)); - - model - }, - Some(mut model) => { - let mut all_models = vec![]; - if let Ok(m) = self.get_server_available_models().await { - all_models.extend(m.into_iter().map(AIModel::from)); - } - - let local_models = self.local_ai.get_all_chat_local_models().await; - all_models.extend(local_models); - if !all_models.contains(&model) { - model = AIModel::default() - } - Some(model) + pub async fn get_active_model(&self, source: &str) -> AIModel { + match self.user_service.workspace_id() { + Ok(workspace_id) => { + let source_key = SourceKey::new(source.to_string()); + self + .model_control + .lock() + .await + .get_active_model(&workspace_id, &source_key) + .await }, + Err(_) => AIModel::default(), } } pub async fn get_local_available_models(&self) -> FlowyResult<ModelSelectionPB> { let setting = self.local_ai.get_local_ai_setting(); - let mut models = self.local_ai.get_all_chat_local_models().await; - let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); + let workspace_id = self.user_service.workspace_id()?; + let mut models = self + .model_control + .lock() + .await + .get_local_models(&workspace_id) + .await; + let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); if models.is_empty() { models.push(selected_model.clone()); } @@ -545,108 +514,23 @@ impl AIManager { return self.get_local_available_models().await; } - // Fetch server models - let mut all_models: Vec<AIModel> = self - .get_server_available_models() - .await?
- .into_iter() - .map(AIModel::from) - .collect(); - - trace!("[Model Selection]: Available models: {:?}", all_models); - - // Add local models if enabled - if self.local_ai.is_enabled() { - if setting_only { - let setting = self.local_ai.get_local_ai_setting(); - all_models.push(AIModel::local(setting.chat_model_name, "".to_string())); - } else { - let local_models = self.local_ai.get_all_chat_local_models().await; - trace!( - "[Model Selection]: Available Local models: {:?}", - local_models - .iter() - .map(|m| m.name.as_str()) - .collect::<Vec<_>>() - ); - all_models.extend(local_models); - } - } - - // Return early if no models available - if all_models.is_empty() { - return Ok(ModelSelectionPB { - models: Vec::new(), - selected_model: AIModelPB::default(), - }); - } - - // Get server active model (only once) - let server_active_model = self - .get_workspace_select_model() - .await - .map(|m| AIModel::server(m, "".to_string())) - .unwrap_or_else(|_| AIModel::default()); - - trace!( - "[Model Selection] server active model: {:?}", - server_active_model - ); - - // Use server model as default if it exists in available models - let default_model = if all_models - .iter() - .any(|m| m.name == server_active_model.name) - { - server_active_model.clone() + let workspace_id = self.user_service.workspace_id()?; + let local_model_name = if setting_only { + Some(self.local_ai.get_local_ai_setting().chat_model_name) } else { - AIModel::default() + None }; - // Get user's previously selected model - let user_selected_model = match self.get_active_model(&source).await { - Some(model) => { - trace!("[Model Selection] user previous select model: {:?}", model); - model - }, - None => { - // When no selected model and local AI is active, use local AI model - all_models - .iter() - .find(|m| m.is_local) - .cloned() - .unwrap_or_else(|| default_model.clone()) - }, - }; + let source_key = SourceKey::new(source); + let model_control = self.model_control.lock().await; + let active_model = model_control + .get_active_model(&workspace_id, &source_key) + .await; + let all_models = model_control + .get_models_with_specific_local_model(&workspace_id, local_model_name) + .await; + drop(model_control); - trace!( - "[Model Selection] all models: {:?}", - all_models - .iter() - .map(|m| m.name.as_str()) - .collect::<Vec<_>>() - ); - - // Determine final active model - use user's selection if available, otherwise default - let active_model = all_models - .iter() - .find(|m| m.name == user_selected_model.name) - .cloned() - .unwrap_or(default_model.clone()); - - // Update stored preference if changed - if active_model.name != user_selected_model.name { - if let Err(err) = self - .update_selected_model(source, active_model.clone()) - .await - { - error!("[Model Selection] failed to update selected model: {}", err); - } - } - - trace!("[Model Selection] final active model: {:?}", active_model); - - // Create response with one transformation pass Ok(ModelSelectionPB { models: all_models.into_iter().map(AIModelPB::from).collect(), selected_model: AIModelPB::from(active_model), }) } @@ -717,7 +601,9 @@ impl AIManager { ) -> Result<RepeatedRelatedQuestionPB, FlowyError> { let chat = self.get_or_create_chat_instance(chat_id).await?; let ai_model = self.get_active_model(&chat_id.to_string()).await; - let resp = chat.get_related_question(message_id, ai_model).await?; + let resp = chat + .get_related_question(message_id, Some(ai_model)) + .await?; Ok(resp) } diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index
fd7bab3298..bdeeb58e89 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -98,6 +98,7 @@ pub(crate) async fn get_source_model_selection_handler( data_result_ok(models) } +#[tracing::instrument(level = "debug", skip_all, err)] pub(crate) async fn update_selected_model_handler( data: AFPluginData<UpdateSelectedModelPB>, ai_manager: AFPluginState<Weak<AIManager>>, @@ -192,7 +193,7 @@ pub(crate) async fn start_complete_text_handler( let data = data.into_inner(); let ai_manager = upgrade_ai_manager(ai_manager)?; let ai_model = ai_manager.get_active_model(&data.object_id).await; - let task = tools.create_complete_task(data, ai_model).await?; + let task = tools.create_complete_task(data, Some(ai_model)).await?; data_result_ok(task) } diff --git a/frontend/rust-lib/flowy-ai/src/lib.rs b/frontend/rust-lib/flowy-ai/src/lib.rs index e0c9d1ac89..fc8620c394 100644 --- a/frontend/rust-lib/flowy-ai/src/lib.rs +++ b/frontend/rust-lib/flowy-ai/src/lib.rs @@ -13,9 +13,11 @@ pub mod local_ai; #[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))] pub mod embeddings; mod middleware; +mod model_select; +#[cfg(test)] +mod model_select_test; pub mod notification; pub mod offline; mod protobuf; mod search; mod stream_message; -mod util; diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs index 878e2eaa23..53e551bf5f 100644 --- a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs +++ b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs @@ -59,7 +59,7 @@ pub struct LocalAIController { current_chat_id: ArcSwapOption<Uuid>, store_preferences: Weak<KVStorePreferences>, user_service: Arc<dyn AIUserService>, - ollama: ArcSwapOption<Ollama>, + pub(crate) ollama: ArcSwapOption<Ollama>, } impl Deref for LocalAIController { @@ -169,8 +169,8 @@ impl LocalAIController { } } - pub fn reload_ollama_client(&self) { - if !self.is_enabled() { + pub fn reload_ollama_client(&self, workspace_id: &str) { + if !self.is_enabled_on_workspace(workspace_id) { return; } @@ -268,11 +268,11 @@ impl LocalAIController { } pub fn is_enabled_on_workspace(&self, workspace_id: &str) -> bool { - let key = local_ai_enabled_key(workspace_id); if !get_operating_system().is_desktop() { return false; } + let key = local_ai_enabled_key(workspace_id); match self.upgrade_store_preferences() { Ok(store) => store.get_bool(&key).unwrap_or(false), Err(_) => false, diff --git a/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs b/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs index 74f5d5560b..67e525730f 100644 --- a/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs +++ b/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs @@ -372,4 +372,15 @@ impl ChatCloudService for ChatServiceMiddleware { .get_workspace_default_model(workspace_id) .await } + + async fn set_workspace_default_model( + &self, + workspace_id: &Uuid, + model: &str, + ) -> Result<(), FlowyError> { + self + .cloud_service + .set_workspace_default_model(workspace_id, model) + .await + } } diff --git a/frontend/rust-lib/flowy-ai/src/model_select.rs b/frontend/rust-lib/flowy-ai/src/model_select.rs new file mode 100644 index 0000000000..1b01818d1d --- /dev/null +++ b/frontend/rust-lib/flowy-ai/src/model_select.rs @@ -0,0 +1,471 @@ +use crate::local_ai::controller::LocalAIController; +use arc_swap::ArcSwapOption; +use flowy_ai_pub::cloud::{AIModel, ChatCloudService}; +use flowy_error::{FlowyError, FlowyResult}; +use flowy_sqlite::kv::KVStorePreferences; +use lib_infra::async_trait::async_trait;
+use lib_infra::util::timestamp; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{error, info, trace}; +use uuid::Uuid; + +type Model = AIModel; +pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model"; + +/// Manages multiple sources and provides operations for model selection +pub struct ModelSelectionControl { + sources: Vec<Box<dyn ModelSource>>, + default_model: Model, + local_storage: ArcSwapOption<Box<dyn UserModelStorage>>, + server_storage: ArcSwapOption<Box<dyn UserModelStorage>>, + unset_sources: RwLock<HashSet<String>>, +} + +impl ModelSelectionControl { + /// Create a new, empty manager; sources and storage backends are attached at runtime + pub fn new() -> Self { + let default_model = Model::default(); + Self { + sources: Vec::new(), + default_model, + local_storage: ArcSwapOption::new(None), + server_storage: ArcSwapOption::new(None), + unset_sources: Default::default(), + } + } + + /// Replace the local storage backend at runtime + pub fn set_local_storage(&self, storage: impl UserModelStorage + 'static) { + self.local_storage.store(Some(Arc::new(Box::new(storage)))); + } + + /// Replace the server storage backend at runtime + pub fn set_server_storage(&self, storage: impl UserModelStorage + 'static) { + self.server_storage.store(Some(Arc::new(Box::new(storage)))); + } + + /// Add a new model source at runtime + pub fn add_source(&mut self, source: Box<dyn ModelSource>) { + info!("[Model Selection] Adding source: {}", source.source_name()); + // remove existing source with the same name + self + .sources + .retain(|s| s.source_name() != source.source_name()); + + self.sources.push(source); + } + + /// Remove the local source, if present + pub fn remove_local_source(&mut self) { + info!("[Model Selection] Removing local source"); + self + .sources + .retain(|source| source.source_name() != "local"); + } + + /// Asynchronously aggregate models from all sources, or return the default if none found + pub async fn get_models(&self, workspace_id: &Uuid) -> Vec<Model> { + let mut models = Vec::new(); + for source in &self.sources { + let mut list = source.list_chat_models(workspace_id).await; + models.append(&mut list); + } + if models.is_empty() { + vec![self.default_model.clone()] + } else { + models + } + } + + /// Fetches all server-side models and, if specified, a single local model by name. + /// + /// First collects models from any source named `"server"`. Then it fetches all local models + /// (from the `"local"` source) and: + /// - If `local_model_name` is `Some(name)`, it will append exactly that local model + /// if it exists. + /// - If `local_model_name` is `None`, it will append *all* local models.
+ /// + pub async fn get_models_with_specific_local_model( + &self, + workspace_id: &Uuid, + local_model_name: Option<String>, + ) -> Vec<Model> { + let mut models = Vec::new(); + // add server models + for source in &self.sources { + if source.source_name() == "server" { + let mut list = source.list_chat_models(workspace_id).await; + models.append(&mut list); + } + } + + // check input local model present in local models + let local_models = self.get_local_models(workspace_id).await; + match local_model_name { + Some(name) => { + local_models.into_iter().for_each(|model| { + if model.name == name { + models.push(model); + } + }); + }, + None => { + models.extend(local_models); + }, + } + + models + } + + pub async fn get_local_models(&self, workspace_id: &Uuid) -> Vec<Model> { + for source in &self.sources { + if source.source_name() == "local" { + return source.list_chat_models(workspace_id).await; + } + } + vec![] + } + + pub async fn get_all_unset_sources(&self) -> Vec<String> { + let unset_sources = self.unset_sources.read().await; + unset_sources.iter().cloned().collect() + } + + pub async fn get_global_active_model(&self, workspace_id: &Uuid) -> Model { + self + .get_active_model( + workspace_id, + &SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()), + ) + .await + } + + /// Retrieves the active model: first tries local storage, then server storage. Ensures validity in the model list. + /// If neither storage yields a valid model, falls back to default. + pub async fn get_active_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Model { + let available = self.get_models(workspace_id).await; + // Try local storage + if let Some(storage) = self.local_storage.load_full() { + trace!("[Model Selection] Checking local storage"); + if let Some(local_model) = storage.get_selected_model(workspace_id, source_key).await { + trace!("[Model Selection] Found local model: {}", local_model.name); + if available.iter().any(|m| m.name == local_model.name) { + return local_model; + } else { + trace!( + "[Model Selection] Local {} not found in available list, available: {:?}", + local_model.name, + available.iter().map(|m| &m.name).collect::<Vec<_>>() + ); + } + } else { + self + .unset_sources + .write() + .await + .insert(source_key.key.clone()); + } + } + + // use local model if user doesn't set the model for given source + if self + .sources + .iter() + .any(|source| source.source_name() == "local") + { + trace!("[Model Selection] Checking global active model"); + let global_source = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()); + if let Some(storage) = self.local_storage.load_full() { + if let Some(local_model) = storage + .get_selected_model(workspace_id, &global_source) + .await + { + trace!( + "[Model Selection] Found global active model: {}", + local_model.name + ); + if available.iter().any(|m| m.name == local_model.name) { + return local_model; + } + } + } + } + + // Try server storage + if let Some(storage) = self.server_storage.load_full() { + trace!("[Model Selection] Checking server storage"); + if let Some(server_model) = storage.get_selected_model(workspace_id, source_key).await { + trace!( + "[Model Selection] Found server model: {}", + server_model.name + ); + if available.iter().any(|m| m.name == server_model.name) { + return server_model; + } else { + trace!( + "[Model Selection] Server {} not found in available list, available: {:?}", + server_model.name, + available.iter().map(|m| &m.name).collect::<Vec<_>>() + ); + } + } + } + // Fallback: default + info!( + "[Model Selection] No active model found,
using default: {}", + self.default_model.name + ); + self.default_model.clone() + } + + /// Sets the active model in both local and server storage + pub async fn set_active_model( + &self, + workspace_id: &Uuid, + source_key: &SourceKey, + model: Model, + ) -> Result<(), FlowyError> { + info!( + "[Model Selection] active model: {} for source: {}", + model.name, source_key.key + ); + self.unset_sources.write().await.remove(&source_key.key); + + let available = self.get_models(workspace_id).await; + if available.contains(&model) { + // Update local storage + if let Some(storage) = self.local_storage.load_full() { + storage + .set_selected_model(workspace_id, source_key, model.clone()) + .await?; + } + + // Update server storage + if let Some(storage) = self.server_storage.load_full() { + storage + .set_selected_model(workspace_id, source_key, model) + .await?; + } + Ok(()) + } else { + Err( + FlowyError::internal() + .with_context(format!("Model '{:?}' not found in available list", model)), + ) + } + } +} + +/// Namespaced key for model selection storage +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SourceKey { + key: String, +} + +impl SourceKey { + /// Create a new SourceKey + pub fn new(key: String) -> Self { + Self { key } + } + + /// Combine the UUID key with a model's is_local flag and name to produce a storage identifier + pub fn storage_id(&self) -> String { + format!("ai_models_{}", self.key) + } +} + +/// A trait that defines an asynchronous source of AI models +#[async_trait] +pub trait ModelSource: Send + Sync { + /// Identifier for this source (e.g., "local" or "server") + fn source_name(&self) -> &'static str; + + /// Asynchronously returns a list of available models from this source + async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec; +} + +pub struct LocalAiSource { + controller: Arc, +} + +impl LocalAiSource { + pub fn new(controller: Arc) -> Self { + Self { controller } + } +} + +#[async_trait] +impl ModelSource for LocalAiSource { + fn source_name(&self) -> &'static str { + "local" + } + + async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec { + match self.controller.ollama.load_full() { + None => vec![], + Some(ollama) => ollama + .list_local_models() + .await + .map(|models| { + models + .into_iter() + .filter(|m| !m.name.contains("embed")) + .map(|m| AIModel::local(m.name, String::new())) + .collect() + }) + .unwrap_or_default(), + } + } +} + +/// A server-side AI source (e.g., cloud API) +#[derive(Debug, Default)] +struct ServerModelsCache { + models: Vec, + timestamp: Option, +} + +pub struct ServerAiSource { + cached_models: Arc>, + cloud_service: Arc, +} + +impl ServerAiSource { + pub fn new(cloud_service: Arc) -> Self { + Self { + cached_models: Arc::new(Default::default()), + cloud_service, + } + } + + async fn update_models_cache(&self, models: &[Model], timestamp: i64) -> FlowyResult<()> { + match self.cached_models.try_write() { + Ok(mut cache) => { + cache.models = models.to_vec(); + cache.timestamp = Some(timestamp); + Ok(()) + }, + Err(_) => { + Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache")) + }, + } + } +} + +#[async_trait] +impl ModelSource for ServerAiSource { + fn source_name(&self) -> &'static str { + "server" + } + + async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec { + let now = timestamp(); + let should_fetch = { + let cached = self.cached_models.read().await; + cached.models.is_empty() || cached.timestamp.map_or(true, |ts| now - ts >= 300) + }; + if !should_fetch { 
+ return self.cached_models.read().await.models.clone(); + } + match self.cloud_service.get_available_models(workspace_id).await { + Ok(resp) => { + let models = resp + .models + .into_iter() + .map(AIModel::from) + .collect::<Vec<_>>(); + if let Err(e) = self.update_models_cache(&models, now).await { + error!("Failed to update cache: {}", e); + } + models + }, + Err(err) => { + error!("Failed to fetch models: {}", err); + let cached = self.cached_models.read().await; + if !cached.models.is_empty() { + info!("Returning expired cache due to error"); + return cached.models.clone(); + } + Vec::new() + }, + } + } +} + +#[async_trait] +pub trait UserModelStorage: Send + Sync { + async fn get_selected_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Option<Model>; + async fn set_selected_model( + &self, + workspace_id: &Uuid, + source_key: &SourceKey, + model: Model, + ) -> Result<(), FlowyError>; +} + +pub struct ServerModelStorageImpl(pub Arc<dyn ChatCloudService>); + +#[async_trait] +impl UserModelStorage for ServerModelStorageImpl { + async fn get_selected_model( + &self, + workspace_id: &Uuid, + _source_key: &SourceKey, + ) -> Option<Model> { + let name = self + .0 + .get_workspace_default_model(workspace_id) + .await + .ok()?; + Some(Model::server(name, String::new())) + } + + async fn set_selected_model( + &self, + workspace_id: &Uuid, + source_key: &SourceKey, + model: Model, + ) -> Result<(), FlowyError> { + if model.is_local { + // local model does not need to be set + return Ok(()); + } + + if source_key.key != GLOBAL_ACTIVE_MODEL_KEY { + return Ok(()); + } + + self + .0 + .set_workspace_default_model(workspace_id, &model.name) + .await?; + Ok(()) + } +} + +pub struct LocalModelStorageImpl(pub Arc<KVStorePreferences>); + +#[async_trait] +impl UserModelStorage for LocalModelStorageImpl { + async fn get_selected_model( + &self, + _workspace_id: &Uuid, + source_key: &SourceKey, + ) -> Option<Model> { + self.0.get_object::<Model>(&source_key.storage_id()) + } + + async fn set_selected_model( + &self, + _workspace_id: &Uuid, + source_key: &SourceKey, + model: Model, + ) -> Result<(), FlowyError> { + self + .0 + .set_object::<Model>(&source_key.storage_id(), &model)?; + Ok(()) + } +} diff --git a/frontend/rust-lib/flowy-ai/src/model_select_test.rs b/frontend/rust-lib/flowy-ai/src/model_select_test.rs new file mode 100644 index 0000000000..c5c1f35e80 --- /dev/null +++ b/frontend/rust-lib/flowy-ai/src/model_select_test.rs @@ -0,0 +1,434 @@ +use crate::model_select::{ModelSelectionControl, ModelSource, SourceKey, UserModelStorage}; +use flowy_ai_pub::cloud::AIModel; +use flowy_error::FlowyError; +use lib_infra::async_trait::async_trait; +use tokio::sync::RwLock; +use uuid::Uuid; + +// Mock implementations for testing +struct MockModelSource { + name: &'static str, + models: Vec<AIModel>, +} + +#[async_trait] +impl ModelSource for MockModelSource { + fn source_name(&self) -> &'static str { + self.name + } + + async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<AIModel> { + self.models.clone() + } +} + +struct MockModelStorage { + selected_model: RwLock<Option<AIModel>>, +} + +impl MockModelStorage { + fn new(initial_model: Option<AIModel>) -> Self { + Self { + selected_model: RwLock::new(initial_model), + } + } +} + +#[async_trait] +impl UserModelStorage for MockModelStorage { + async fn get_selected_model( + &self, + _workspace_id: &Uuid, + _source_key: &SourceKey, + ) -> Option<AIModel> { + self.selected_model.read().await.clone() + } + + async fn set_selected_model( + &self, + _workspace_id: &Uuid, + _source_key: &SourceKey, + model: AIModel, + ) -> Result<(), FlowyError> {
*self.selected_model.write().await = Some(model); + Ok(()) + } +} + +#[tokio::test] +async fn test_empty_model_list_returns_default() { + let control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + + let models = control.get_models(&workspace_id).await; + + assert_eq!(models.len(), 1); + assert_eq!(models[0], AIModel::default()); +} + +#[tokio::test] +async fn test_get_models_from_multiple_sources() { + let mut control = ModelSelectionControl::new(); + + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![ + AIModel::local("local-model-1".to_string(), "".to_string()), + AIModel::local("local-model-2".to_string(), "".to_string()), + ], + }); + + let server_source = Box::new(MockModelSource { + name: "server", + models: vec![ + AIModel::server("server-model-1".to_string(), "".to_string()), + AIModel::server("server-model-2".to_string(), "".to_string()), + ], + }); + + control.add_source(local_source); + control.add_source(server_source); + + let workspace_id = Uuid::new_v4(); + let models = control.get_models(&workspace_id).await; + + assert_eq!(models.len(), 4); + assert!(models.iter().any(|m| m.name == "local-model-1")); + assert!(models.iter().any(|m| m.name == "local-model-2")); + assert!(models.iter().any(|m| m.name == "server-model-1")); + assert!(models.iter().any(|m| m.name == "server-model-2")); +} + +#[tokio::test] +async fn test_get_models_with_specific_local_model() { + let mut control = ModelSelectionControl::new(); + + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![ + AIModel::local("local-model-1".to_string(), "".to_string()), + AIModel::local("local-model-2".to_string(), "".to_string()), + ], + }); + + let server_source = Box::new(MockModelSource { + name: "server", + models: vec![ + AIModel::server("server-model-1".to_string(), "".to_string()), + AIModel::server("server-model-2".to_string(), "".to_string()), + ], + }); + + control.add_source(local_source); + control.add_source(server_source); + + let workspace_id = Uuid::new_v4(); + + // Test with specific local model + let models = control + .get_models_with_specific_local_model(&workspace_id, Some("local-model-1".to_string())) + .await; + assert_eq!(models.len(), 3); + assert!(models.iter().any(|m| m.name == "local-model-1")); + assert!(!models.iter().any(|m| m.name == "local-model-2")); + + // Test with non-existent local model + let models = control + .get_models_with_specific_local_model(&workspace_id, Some("non-existent".to_string())) + .await; + assert_eq!(models.len(), 2); // Only server models + + // Test with no specified local model (should include all local models) + let models = control + .get_models_with_specific_local_model(&workspace_id, None) + .await; + assert_eq!(models.len(), 4); +} + +#[tokio::test] +async fn test_get_local_models() { + let mut control = ModelSelectionControl::new(); + + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![ + AIModel::local("local-model-1".to_string(), "".to_string()), + AIModel::local("local-model-2".to_string(), "".to_string()), + ], + }); + + let server_source = Box::new(MockModelSource { + name: "server", + models: vec![AIModel::server( + "server-model-1".to_string(), + "".to_string(), + )], + }); + + control.add_source(local_source); + control.add_source(server_source); + + let workspace_id = Uuid::new_v4(); + let local_models = control.get_local_models(&workspace_id).await; + + assert_eq!(local_models.len(), 2); + assert!(local_models.iter().all(|m| 
m.is_local)); +} + +#[tokio::test] +async fn test_remove_local_source() { + let mut control = ModelSelectionControl::new(); + + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![AIModel::local("local-model-1".to_string(), "".to_string())], + }); + + let server_source = Box::new(MockModelSource { + name: "server", + models: vec![AIModel::server( + "server-model-1".to_string(), + "".to_string(), + )], + }); + + control.add_source(local_source); + control.add_source(server_source); + + let workspace_id = Uuid::new_v4(); + assert_eq!(control.get_models(&workspace_id).await.len(), 2); + + control.remove_local_source(); + let models = control.get_models(&workspace_id).await; + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "server-model-1"); +} + +#[tokio::test] +async fn test_get_active_model_from_local_storage() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("test".to_string()); + + // Add a local source with some models + let local_model = AIModel::local("local-model-1".to_string(), "".to_string()); + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![local_model.clone()], + }); + control.add_source(local_source); + + // Set up local storage with a selected model + let local_storage = MockModelStorage::new(Some(local_model.clone())); + control.set_local_storage(local_storage); + + // Get active model should return the locally stored model + let active = control.get_active_model(&workspace_id, &source_key).await; + assert_eq!(active, local_model); +} + +#[tokio::test] +async fn test_global_active_model_fallback() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("specific_source".to_string()); + + // Add a local source with models + let local_model = AIModel::local("local-model-1".to_string(), "".to_string()); + let local_source = Box::new(MockModelSource { + name: "local", + models: vec![local_model.clone()], + }); + control.add_source(local_source); + + // Set up local storage with global model but not specific source model + let global_storage = MockModelStorage::new(Some(local_model.clone())); + + // Set the local storage + control.set_local_storage(global_storage); + + // Get active model should fall back to the global model since + // there's no model for the specific source key + let active = control.get_active_model(&workspace_id, &source_key).await; + assert_eq!(active, local_model); +} + +#[tokio::test] +async fn test_get_active_model_fallback_to_server_storage() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("test".to_string()); + + // Add a server source with some models + let server_model = AIModel::server("server-model-1".to_string(), "".to_string()); + let server_source = Box::new(MockModelSource { + name: "server", + models: vec![server_model.clone()], + }); + control.add_source(server_source); + + // Set up local storage with no selected model + let local_storage = MockModelStorage::new(None); + control.set_local_storage(local_storage); + + // Set up server storage with a selected model + let server_storage = MockModelStorage::new(Some(server_model.clone())); + control.set_server_storage(server_storage); + + // Get active model should fall back to server storage + let active = control.get_active_model(&workspace_id, &source_key).await; + assert_eq!(active, server_model); +} + 
+#[tokio::test] +async fn test_get_active_model_fallback_to_default() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("test".to_string()); + + // Add sources with some models + let model1 = AIModel::local("model-1".to_string(), "".to_string()); + let model2 = AIModel::server("model-2".to_string(), "".to_string()); + + let source = Box::new(MockModelSource { + name: "test", + models: vec![model1.clone(), model2.clone()], + }); + control.add_source(source); + + // Set up storages with models that don't match available models + let different_model = AIModel::local("non-existent".to_string(), "".to_string()); + let local_storage = MockModelStorage::new(Some(different_model.clone())); + let server_storage = MockModelStorage::new(Some(different_model.clone())); + + control.set_local_storage(local_storage); + control.set_server_storage(server_storage); + + // Should fall back to default model since storages return non-matching models + let active = control.get_active_model(&workspace_id, &source_key).await; + assert_eq!(active, AIModel::default()); +} + +#[tokio::test] +async fn test_set_active_model() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("test".to_string()); + + // Add a source with some models + let model = AIModel::local("model-1".to_string(), "".to_string()); + let source = Box::new(MockModelSource { + name: "test", + models: vec![model.clone()], + }); + control.add_source(source); + + // Set up storage + let local_storage = MockModelStorage::new(None); + let server_storage = MockModelStorage::new(None); + control.set_local_storage(local_storage); + control.set_server_storage(server_storage); + + // Set active model + let result = control + .set_active_model(&workspace_id, &source_key, model.clone()) + .await; + assert!(result.is_ok()); + + // Verify that the active model was set correctly + let active = control.get_active_model(&workspace_id, &source_key).await; + assert_eq!(active, model); +} + +#[tokio::test] +async fn test_set_active_model_invalid_model() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("test".to_string()); + + // Add a source with some models + let available_model = AIModel::local("available-model".to_string(), "".to_string()); + let source = Box::new(MockModelSource { + name: "test", + models: vec![available_model.clone()], + }); + control.add_source(source); + + // Set up storage + let local_storage = MockModelStorage::new(None); + let server_storage = MockModelStorage::new(None); + control.set_local_storage(local_storage); + control.set_server_storage(server_storage); + + // Try to set an invalid model + let invalid_model = AIModel::local("invalid-model".to_string(), "".to_string()); + let result = control + .set_active_model(&workspace_id, &source_key, invalid_model) + .await; + + // Should fail because the model is not in the available list + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_global_active_model_fallback_with_local_source() { + let mut control = ModelSelectionControl::new(); + let workspace_id = Uuid::new_v4(); + let source_key = SourceKey::new("specific_source".to_string()); + + // Add a local source with models + let local_model = AIModel::local("local-model-1".to_string(), "".to_string()); + let local_source = Box::new(MockModelSource { + name: "local", // This is important - the fallback only happens 
when a local source exists models: vec![local_model.clone()], }); control.add_source(local_source); + + // Create a custom storage that only returns a model for the global key + struct GlobalOnlyStorage { + global_model: AIModel, + } + + #[async_trait] + impl UserModelStorage for GlobalOnlyStorage { + async fn get_selected_model( + &self, + _workspace_id: &Uuid, + source_key: &SourceKey, + ) -> Option<AIModel> { + if source_key.storage_id() + == format!("ai_models_{}", crate::model_select::GLOBAL_ACTIVE_MODEL_KEY) + { + Some(self.global_model.clone()) + } else { + None + } + } + + async fn set_selected_model( + &self, + _workspace_id: &Uuid, + _source_key: &SourceKey, + _model: AIModel, + ) -> Result<(), FlowyError> { + Ok(()) + } + } + + // Set up local storage with only the global model + let global_storage = GlobalOnlyStorage { + global_model: local_model.clone(), + }; + control.set_local_storage(global_storage); + + // Get active model for a specific source_key (not the global one) + // Should fall back to the global model since: + // 1. There's no model for the specific source_key + // 2. There is a local source + // 3. There is a global active model set + let active = control.get_active_model(&workspace_id, &source_key).await; + + // Should get the global model + assert_eq!(active, local_model); +} diff --git a/frontend/rust-lib/flowy-ai/src/offline/offline_message_sync.rs b/frontend/rust-lib/flowy-ai/src/offline/offline_message_sync.rs index 55daf6b77f..3369095d94 100644 --- a/frontend/rust-lib/flowy-ai/src/offline/offline_message_sync.rs +++ b/frontend/rust-lib/flowy-ai/src/offline/offline_message_sync.rs @@ -255,4 +255,15 @@ impl ChatCloudService for AutoSyncChatService { .get_workspace_default_model(workspace_id) .await } + + async fn set_workspace_default_model( + &self, + workspace_id: &Uuid, + model: &str, + ) -> Result<(), FlowyError> { + self + .cloud_service + .set_workspace_default_model(workspace_id, model) + .await + } } diff --git a/frontend/rust-lib/flowy-ai/src/util.rs b/frontend/rust-lib/flowy-ai/src/util.rs deleted file mode 100644 index a181d1b1d3..0000000000 --- a/frontend/rust-lib/flowy-ai/src/util.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub fn ai_available_models_key(object_id: &str) -> String { - format!("ai_models_{}", object_id) -} diff --git a/frontend/rust-lib/flowy-core/src/app_life_cycle.rs b/frontend/rust-lib/flowy-core/src/app_life_cycle.rs index 0630bdf25c..bddca3af9a 100644 --- a/frontend/rust-lib/flowy-core/src/app_life_cycle.rs +++ b/frontend/rust-lib/flowy-core/src/app_life_cycle.rs @@ -258,10 +258,14 @@ impl AppLifeCycle for AppLifeCycleImpl { .initialize_after_sign_in(user_id) .await?; - self - .ai_manager()? - .initialize_after_sign_in(workspace_id) - .await?; + let ai_manager = self.ai_manager()?; + let cloned_workspace_id = *workspace_id; + self.runtime.spawn(async move { + ai_manager + .initialize_after_sign_in(&cloned_workspace_id) + .await?; + Ok::<_, FlowyError>(()) + }); self .create_thanvity_state_if_not_exists(user_id, workspace_id, user_paths) @@ -345,10 +349,14 @@ impl AppLifeCycle for AppLifeCycleImpl { .await .context("DocumentManager error")?; - self - .ai_manager()?
- .initialize_after_sign_up(workspace_id) - .await?; + let ai_manager = self.ai_manager()?; + let cloned_workspace_id = *workspace_id; + self.runtime.spawn(async move { + ai_manager + .initialize_after_sign_up(&cloned_workspace_id) + .await?; + Ok::<_, FlowyError>(()) + }); self .create_thanvity_state_if_not_exists(user_profile.uid, workspace_id, user_paths) @@ -422,10 +430,15 @@ impl AppLifeCycle for AppLifeCycleImpl { .document_manager()? .initialize_after_open_workspace(user_id) .await?; - self - .ai_manager()? - .initialize_after_open_workspace(workspace_id) - .await?; + + let ai_manager = self.ai_manager()?; + let cloned_workspace_id = *workspace_id; + self.runtime.spawn(async move { + ai_manager + .initialize_after_open_workspace(&cloned_workspace_id) + .await?; + Ok::<_, FlowyError>(()) + }); self .storage_manager()? .initialize_after_open_workspace(workspace_id) diff --git a/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs b/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs index c49757f735..c05aa3152a 100644 --- a/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs +++ b/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs @@ -818,6 +818,18 @@ impl ChatCloudService for ServerProvider { .get_workspace_default_model(workspace_id) .await } + + async fn set_workspace_default_model( + &self, + workspace_id: &Uuid, + model: &str, + ) -> Result<(), FlowyError> { + self + .get_server()? + .chat_service() + .set_workspace_default_model(workspace_id, model) + .await + } } #[async_trait] diff --git a/frontend/rust-lib/flowy-core/src/server_layer.rs b/frontend/rust-lib/flowy-core/src/server_layer.rs index 5b8b3d5585..de45556b85 100644 --- a/frontend/rust-lib/flowy-core/src/server_layer.rs +++ b/frontend/rust-lib/flowy-core/src/server_layer.rs @@ -79,22 +79,12 @@ impl ServerProvider { } } - pub fn on_launch_if_authenticated(&self, _workspace_type: &WorkspaceType) { - self.local_ai.reload_ollama_client(); - } + pub fn on_launch_if_authenticated(&self, _workspace_type: &WorkspaceType) {} - pub fn on_sign_in(&self, _workspace_type: &WorkspaceType) { - self.local_ai.reload_ollama_client(); - } + pub fn on_sign_in(&self, _workspace_type: &WorkspaceType) {} - pub fn on_sign_up(&self, workspace_type: &WorkspaceType) { - if workspace_type.is_local() { - self.local_ai.reload_ollama_client(); - } - } - pub fn init_after_open_workspace(&self, _workspace_type: &WorkspaceType) { - self.local_ai.reload_ollama_client(); - } + pub fn on_sign_up(&self, _workspace_type: &WorkspaceType) {} + pub fn init_after_open_workspace(&self, _workspace_type: &WorkspaceType) {} pub fn set_auth_type(&self, new_auth_type: AuthType) { let old_type = self.get_auth_type(); diff --git a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs index 6086f7084b..4e9b321a84 100644 --- a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs +++ b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs @@ -8,8 +8,8 @@ use client_api::entity::chat_dto::{ RepeatedChatMessage, }; use flowy_ai_pub::cloud::{ - AIModel, ChatCloudService, ChatMessage, ChatMessageType, ChatSettings, ModelList, StreamAnswer, - StreamComplete, UpdateChatParams, + AFWorkspaceSettingsChange, AIModel, ChatCloudService, ChatMessage, ChatMessageType, ChatSettings, + ModelList, StreamAnswer, StreamComplete, UpdateChatParams, }; use flowy_error::FlowyError; use futures_util::{StreamExt, TryStreamExt}; @@ -267,4 +267,18 @@ where 
.await?; Ok(setting.ai_model) } + + async fn set_workspace_default_model( + &self, + workspace_id: &Uuid, + model: &str, + ) -> Result<(), FlowyError> { + let change = AFWorkspaceSettingsChange::new().ai_model(model.to_string()); + self + .inner + .try_get_client()? + .update_workspace_settings(workspace_id.to_string().as_str(), &change) + .await?; + Ok(()) + } } diff --git a/frontend/rust-lib/flowy-server/src/local_server/impls/chat.rs b/frontend/rust-lib/flowy-server/src/local_server/impls/chat.rs index 845b6dec1c..098de5f8c0 100644 --- a/frontend/rust-lib/flowy-server/src/local_server/impls/chat.rs +++ b/frontend/rust-lib/flowy-server/src/local_server/impls/chat.rs @@ -319,6 +319,15 @@ impl ChatCloudService for LocalChatServiceImpl { async fn get_workspace_default_model(&self, _workspace_id: &Uuid) -> Result<String, FlowyError> { Ok(DEFAULT_AI_MODEL_NAME.to_string()) } + + async fn set_workspace_default_model( + &self, + _workspace_id: &Uuid, + _model: &str, + ) -> Result<(), FlowyError> { + // do nothing + Ok(()) + } } fn chat_message_from_row(row: ChatMessageTable) -> ChatMessage {
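
The lookup order that model_select.rs introduces is easiest to see end to end. Below is a minimal sketch, not part of the patch, written as an extra test that could sit alongside the ones in model_select_test.rs and reusing its MockModelSource/MockModelStorage helpers; the "llama3" and "chat-123" names are illustrative only. It exercises the order implemented by get_active_model: per-source local storage, then the global active model (only while a local source exists), then server storage, then AIModel::default().

#[tokio::test]
async fn sketch_selection_fallback_order() {
  let mut control = ModelSelectionControl::new();
  let workspace_id = Uuid::new_v4();

  // One local source exposing a single chat model (hypothetical name).
  let local_model = AIModel::local("llama3".to_string(), "".to_string());
  control.add_source(Box::new(MockModelSource {
    name: "local",
    models: vec![local_model.clone()],
  }));

  // MockModelStorage returns the stored model for any source key, so the
  // per-source lookup hits first and is validated against get_models().
  control.set_local_storage(MockModelStorage::new(Some(local_model.clone())));

  let chat_key = SourceKey::new("chat-123".to_string());
  assert_eq!(
    control.get_active_model(&workspace_id, &chat_key).await,
    local_model
  );

  // storage_id() namespaces keys the same way the deleted
  // util::ai_available_models_key helper did.
  assert_eq!(chat_key.storage_id(), "ai_models_chat-123");

  // Dropping the local source (what toggle_local_ai does on disable) leaves
  // no matching model in the available list, so the stored selection is
  // rejected and the control degrades to the default model.
  control.remove_local_source();
  assert_eq!(
    control.get_active_model(&workspace_id, &chat_key).await,
    AIModel::default()
  );
}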