From 7dd8d06c85e16f9a89dcdceb5e0eb40a0cf8139b Mon Sep 17 00:00:00 2001 From: Nathan Date: Fri, 25 Apr 2025 16:53:32 +0800 Subject: [PATCH 1/6] chore: display all local models --- .../prompt_input/select_model_menu.dart | 62 +++--- .../settings/ai/ollama_setting_bloc.dart | 21 +- frontend/rust-lib/Cargo.lock | 98 +++++--- .../src/persistence/local_model_sql.rs | 54 +++++ .../flowy-ai-pub/src/persistence/mod.rs | 2 + frontend/rust-lib/flowy-ai/Cargo.toml | 6 +- frontend/rust-lib/flowy-ai/src/ai_manager.rs | 103 +++++---- frontend/rust-lib/flowy-ai/src/entities.rs | 6 +- .../rust-lib/flowy-ai/src/event_handler.rs | 6 +- .../flowy-ai/src/local_ai/controller.rs | 209 +++++++++++++----- .../flowy-ai/src/local_ai/resource.rs | 1 - frontend/rust-lib/flowy-error/Cargo.toml | 4 + frontend/rust-lib/flowy-error/src/errors.rs | 7 + .../2025-04-25-071459_local_ai_model/down.sql | 1 + .../2025-04-25-071459_local_ai_model/up.sql | 6 + frontend/rust-lib/flowy-sqlite/src/schema.rs | 32 ++- 16 files changed, 418 insertions(+), 200 deletions(-) create mode 100644 frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs create mode 100644 frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql create mode 100644 frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart index a611d84310..36b389f8c4 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart @@ -90,38 +90,40 @@ class SelectModelPopoverContent extends StatelessWidget { return Padding( padding: const EdgeInsets.all(8.0), - child: Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - if (localModels.isNotEmpty) ...[ - _ModelSectionHeader( - title: LocaleKeys.chat_switchModel_localModel.tr(), + child: SingleChildScrollView( + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (localModels.isNotEmpty) ...[ + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_localModel.tr(), + ), + const VSpace(4.0), + ], + ...localModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), + ), + if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[ + const VSpace(8.0), + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_cloudModel.tr(), + ), + const VSpace(4.0), + ], + ...cloudModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), ), - const VSpace(4.0), ], - ...localModels.map( - (model) => _ModelItem( - model: model, - isSelected: model == selectedModel, - onTap: () => onSelectModel?.call(model), - ), - ), - if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[ - const VSpace(8.0), - _ModelSectionHeader( - title: LocaleKeys.chat_switchModel_cloudModel.tr(), - ), - const VSpace(4.0), - ], - ...cloudModels.map( - (model) => _ModelItem( - model: model, - isSelected: model == selectedModel, - onTap: () => onSelectModel?.call(model), - ), - ), - ], + ), ), ); } diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart 
b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart index f5c4209028..2659292b11 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart @@ -11,6 +11,9 @@ import 'package:equatable/equatable.dart'; part 'ollama_setting_bloc.freezed.dart'; +const kDefaultChatModel = 'llama3.1:latest'; +const kDefaultEmbeddingModel = 'nomic-embed-text:latest'; + class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> { OllamaSettingBloc() : super(const OllamaSettingState()) { on<OllamaSettingEvent>(_handleEvent); } @@ -70,7 +73,7 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> { final setting = LocalAISettingPB(); final settingUpdaters = { SettingType.serverUrl: (value) => setting.serverUrl = value, - SettingType.chatModel: (value) => setting.chatModelName = value, + SettingType.chatModel: (value) => setting.defaultModel = value, SettingType.embeddingModel: (value) => setting.embeddingModelName = value, }; @@ -108,13 +111,13 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> { settingType: SettingType.serverUrl, ), SettingItem( - content: setting.chatModelName, - hintText: 'llama3.1', + content: setting.defaultModel, + hintText: kDefaultChatModel, settingType: SettingType.chatModel, ), SettingItem( content: setting.embeddingModelName, - hintText: 'nomic-embed-text', + hintText: kDefaultEmbeddingModel, settingType: SettingType.embeddingModel, ), ]; @@ -125,7 +128,7 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> { settingType: SettingType.serverUrl, ), SubmittedItem( - content: setting.chatModelName, + content: setting.defaultModel, settingType: SettingType.chatModel, ), SubmittedItem( @@ -203,13 +206,13 @@ class OllamaSettingState with _$OllamaSettingState { settingType: SettingType.serverUrl, ), SettingItem( - content: 'llama3.1', - hintText: 'llama3.1', + content: kDefaultChatModel, + hintText: kDefaultChatModel, settingType: SettingType.chatModel, ), SettingItem( - content: 'nomic-embed-text', - hintText: 'nomic-embed-text', + content: kDefaultEmbeddingModel, + hintText: kDefaultEmbeddingModel, settingType: SettingType.embeddingModel, ), ]) diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index b7d7fbd7f2..4689ed5c4f 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -2210,12 +2210,6 @@ dependencies = [ "litrs", ] -[[package]] -name = "dotenv" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" - [[package]] name = "downcast-rs" version = "2.0.1" @@ -2506,7 +2500,6 @@ dependencies = [ "bytes", "collab-integrate", "dashmap 6.0.1", - "dotenv", "flowy-ai-pub", "flowy-codegen", "flowy-derive", @@ -2520,19 +2513,18 @@ dependencies = [ "lib-infra", "log", "notify", + "ollama-rs", "pin-project", "protobuf", "reqwest 0.11.27", "serde", "serde_json", "sha2", - "simsimd", "strum_macros 0.21.1", "tokio", "tokio-stream", "tokio-util", "tracing", - "tracing-subscriber", "uuid", "validator 0.18.1", ] @@ -2798,6 +2790,7 @@ dependencies = [ "flowy-derive", "flowy-sqlite", "lib-dispatch", + "ollama-rs", "protobuf", "r2d2", "reqwest 0.11.27", @@ -4044,6 +4037,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -4894,6 +4888,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.3.0" +source =
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7" +dependencies = [ + "async-stream", + "log", + "reqwest 0.12.15", + "schemars", + "serde", + "serde_json", + "static_assertions", + "thiserror 2.0.12", + "url", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -5696,7 +5707,7 @@ dependencies = [ "rustc-hash 2.1.0", "rustls 0.23.20", "socket2 0.5.5", - "thiserror 2.0.9", + "thiserror 2.0.12", "tokio", "tracing", ] @@ -5715,7 +5726,7 @@ dependencies = [ "rustls 0.23.20", "rustls-pki-types", "slab", - "thiserror 2.0.9", + "thiserror 2.0.12", "tinyvec", "tracing", "web-time", @@ -6407,6 +6418,31 @@ dependencies = [ "parking_lot 0.12.1", ] +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "indexmap 1.9.3", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals 0.29.1", + "syn 2.0.94", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -6554,6 +6590,17 @@ dependencies = [ "syn 2.0.94", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.94", +] + [[package]] name = "serde_html_form" version = "0.2.7" @@ -6746,15 +6793,6 @@ dependencies = [ "time", ] -[[package]] -name = "simsimd" -version = "4.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4" -dependencies = [ - "cc", -] - [[package]] name = "siphasher" version = "0.3.11" @@ -6841,6 +6879,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "string_cache" version = "0.8.7" @@ -7098,7 +7142,7 @@ dependencies = [ "tantivy-stacker", "tantivy-tokenizer-api", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.12", "time", "uuid", "winapi", @@ -7280,11 +7324,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.12", ] [[package]] @@ -7300,9 +7344,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -7823,7 +7867,7 @@ checksum = 
"7a94b0f0954b3e59bfc2c246b4c8574390d94a4ad4ad246aaf2fb07d7dfd3b47" dependencies = [ "proc-macro2", "quote", - "serde_derive_internals", + "serde_derive_internals 0.28.0", "syn 2.0.94", ] diff --git a/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs b/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs new file mode 100644 index 0000000000..1e1d49b79e --- /dev/null +++ b/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs @@ -0,0 +1,54 @@ +use diesel::sqlite::SqliteConnection; +use flowy_error::FlowyResult; +use flowy_sqlite::upsert::excluded; +use flowy_sqlite::{ + diesel, + query_dsl::*, + schema::{local_ai_model_table, local_ai_model_table::dsl}, + ExpressionMethods, Identifiable, Insertable, Queryable, +}; + +#[derive(Clone, Default, Queryable, Insertable, Identifiable)] +#[diesel(table_name = local_ai_model_table)] +#[diesel(primary_key(name))] +pub struct LocalAIModelTable { + pub name: String, + pub model_type: i16, +} + +#[derive(Clone, Debug, Copy)] +pub enum ModelType { + Embedding = 0, + Chat = 1, +} + +impl From for ModelType { + fn from(value: i16) -> Self { + match value { + 0 => ModelType::Embedding, + 1 => ModelType::Chat, + _ => ModelType::Embedding, + } + } +} + +pub fn select_local_ai_model(conn: &mut SqliteConnection, name: &str) -> Option { + local_ai_model_table::table + .filter(dsl::name.eq(name)) + .first::(conn) + .ok() +} + +pub fn upsert_local_ai_model( + conn: &mut SqliteConnection, + row: &LocalAIModelTable, +) -> FlowyResult<()> { + diesel::insert_into(local_ai_model_table::table) + .values(row) + .on_conflict(local_ai_model_table::name) + .do_update() + .set((local_ai_model_table::model_type.eq(excluded(local_ai_model_table::model_type)),)) + .execute(conn)?; + + Ok(()) +} diff --git a/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs b/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs index b21eb507ae..7ae97148ce 100644 --- a/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs +++ b/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs @@ -1,5 +1,7 @@ mod chat_message_sql; mod chat_sql; +mod local_model_sql; pub use chat_message_sql::*; pub use chat_sql::*; +pub use local_model_sql::*; diff --git a/frontend/rust-lib/flowy-ai/Cargo.toml b/frontend/rust-lib/flowy-ai/Cargo.toml index 3a6aaf5898..fb90714d00 100644 --- a/frontend/rust-lib/flowy-ai/Cargo.toml +++ b/frontend/rust-lib/flowy-ai/Cargo.toml @@ -48,16 +48,16 @@ collab-integrate.workspace = true [target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies] notify = "6.1.1" +ollama-rs = "0.3.0" +#faiss = { version = "0.12.1" } af-mcp = { version = "0.1.0" } [dev-dependencies] -dotenv = "0.15.0" uuid.workspace = true -tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] } -simsimd = "4.4.0" [build-dependencies] flowy-codegen.workspace = true [features] dart = ["flowy-codegen/dart", "flowy-notification/dart"] +local_ai = [] \ No newline at end of file diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index 06b3adaeea..d18cb91355 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -330,14 +330,10 @@ impl AIManager { .get_question_id_from_answer_id(chat_id, answer_message_id) .await?; - let model = model.map_or_else( - || { - self - .store_preferences - .get_object::(&ai_available_models_key(&chat_id.to_string())) - }, - |model| Some(model.into()), - ); + let model 
= match model { + None => self.get_active_model(&chat_id.to_string()).await, + Some(model) => Some(model.into()), + }; chat .stream_regenerate_response(question_message_id, answer_stream_port, format, model) .await?; @@ -354,9 +350,10 @@ impl AIManager { "[AI Plugin] update global active model, previous: {}, current: {}", previous_model, current_model ); - let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); let model = AIModel::local(current_model, "".to_string()); - self.update_selected_model(source_key, model).await?; + self + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) + .await?; } Ok(()) @@ -440,11 +437,11 @@ impl AIManager { } pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> { - info!( - "[Model Selection] update {} selected model: {:?}", - source, model - ); let source_key = ai_available_models_key(&source); + info!( + "[Model Selection] update {} selected model: {:?} for key:{}", + source, model, source_key + ); self .store_preferences .set_object::(&source_key, &model)?; @@ -458,12 +455,13 @@ impl AIManager { #[instrument(skip_all, level = "debug")] pub async fn toggle_local_ai(&self) -> FlowyResult<()> { let enabled = self.local_ai.toggle_local_ai().await?; - let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); if enabled { if let Some(name) = self.local_ai.get_plugin_chat_model() { info!("Set global active model to local ai: {}", name); let model = AIModel::local(name, "".to_string()); - self.update_selected_model(source_key, model).await?; + self + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) + .await?; } } else { info!("Set global active model to default"); @@ -471,7 +469,7 @@ impl AIManager { let models = self.get_server_available_models().await?; if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) { self - .update_selected_model(source_key, AIModel::from(model)) + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model)) .await?; } } @@ -484,21 +482,31 @@ impl AIManager { .store_preferences .get_object::(&ai_available_models_key(source)); - if model.is_none() { - if let Some(local_model) = self.local_ai.get_plugin_chat_model() { - model = Some(AIModel::local(local_model, "".to_string())); - } + match model { + None => { + if let Some(local_model) = self.local_ai.get_plugin_chat_model() { + model = Some(AIModel::local(local_model, "".to_string())); + } + model + }, + Some(mut model) => { + let models = self.local_ai.get_all_chat_local_models().await; + if !models.contains(&model) { + if let Some(local_model) = self.local_ai.get_plugin_chat_model() { + model = AIModel::local(local_model, "".to_string()); + } + } + Some(model) + }, } - - model } pub async fn get_available_models(&self, source: String) -> FlowyResult { let is_local_mode = self.user_service.is_local_model().await?; if is_local_mode { let setting = self.local_ai.get_local_ai_setting(); + let models = self.local_ai.get_all_chat_local_models().await; let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); - let models = vec![selected_model.clone()]; Ok(AvailableModelsPB { models: models.into_iter().map(|m| m.into()).collect(), @@ -506,27 +514,24 @@ impl AIManager { }) } else { // Build the models list from server models and mark them as non-local. - let mut models: Vec = self + let mut all_models: Vec = self .get_server_available_models() .await? 
.into_iter() .map(AIModel::from) .collect(); - trace!("[Model Selection]: Available models: {:?}", models); - let mut current_active_local_ai_model = None; + trace!("[Model Selection]: Available models: {:?}", all_models); // If user enable local ai, then add local ai model to the list. - if let Some(local_model) = self.local_ai.get_plugin_chat_model() { - let model = AIModel::local(local_model, "".to_string()); - current_active_local_ai_model = Some(model.clone()); - trace!("[Model Selection] current local ai model: {}", model.name); - models.push(model); + if self.local_ai.is_enabled() { + let local_models = self.local_ai.get_all_chat_local_models().await; + all_models.extend(local_models.into_iter().map(|m| m)); } - if models.is_empty() { + if all_models.is_empty() { return Ok(AvailableModelsPB { - models: models.into_iter().map(|m| m.into()).collect(), + models: all_models.into_iter().map(|m| m.into()).collect(), selected_model: AIModelPB::default(), }); } @@ -545,37 +550,29 @@ impl AIManager { let mut user_selected_model = server_active_model.clone(); // when current select model is deprecated, reset the model to default - if !models.iter().any(|m| m.name == server_active_model.name) { + if !all_models + .iter() + .any(|m| m.name == server_active_model.name) + { server_active_model = AIModel::default(); } - let source_key = ai_available_models_key(&source); // We use source to identify user selected model. source can be document id or chat id. - match self.store_preferences.get_object::(&source_key) { + match self.get_active_model(&source).await { None => { // when there is selected model and current local ai is active, then use local ai - if let Some(local_ai_model) = models.iter().find(|m| m.is_local) { + if let Some(local_ai_model) = all_models.iter().find(|m| m.is_local) { user_selected_model = local_ai_model.clone(); } }, - Some(mut model) => { + Some(model) => { trace!("[Model Selection] user previous select model: {:?}", model); - // If source is provided, try to get the user-selected model from the store. User selected - // model will be used as the active model if it exists. - if model.is_local { - if let Some(local_ai_model) = ¤t_active_local_ai_model { - if local_ai_model.name != model.name { - model = local_ai_model.clone(); - } - } - } - user_selected_model = model; }, } // If user selected model is not available in the list, use the global active model. 
- let active_model = models + let active_model = all_models .iter() .find(|m| m.name == user_selected_model.name) .cloned() @@ -585,15 +582,15 @@ impl AIManager { if let Some(ref active_model) = active_model { if active_model.name != user_selected_model.name { self - .store_preferences - .set_object::(&source_key, &active_model.clone())?; + .update_selected_model(source, active_model.clone()) + .await?; } } trace!("[Model Selection] final active model: {:?}", active_model); let selected_model = AIModelPB::from(active_model.unwrap_or_default()); Ok(AvailableModelsPB { - models: models.into_iter().map(|m| m.into()).collect(), + models: all_models.into_iter().map(|m| m.into()).collect(), selected_model, }) } diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index 5a4aecbbd7..796664a18f 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -686,7 +686,7 @@ pub struct LocalAISettingPB { #[pb(index = 2)] #[validate(custom(function = "required_not_empty_str"))] - pub chat_model_name: String, + pub default_model: String, #[pb(index = 3)] #[validate(custom(function = "required_not_empty_str"))] @@ -697,7 +697,7 @@ impl From for LocalAISettingPB { fn from(value: LocalAISetting) -> Self { LocalAISettingPB { server_url: value.ollama_server_url, - chat_model_name: value.chat_model_name, + default_model: value.chat_model_name, embedding_model_name: value.embedding_model_name, } } @@ -707,7 +707,7 @@ impl From for LocalAISetting { fn from(value: LocalAISettingPB) -> Self { LocalAISetting { ollama_server_url: value.server_url, - chat_model_name: value.chat_model_name, + chat_model_name: value.default_model, embedding_model_name: value.embedding_model_name, } } diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index f85858b1c2..f778063309 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -1,7 +1,6 @@ use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY}; use crate::completion::AICompletion; use crate::entities::*; -use crate::util::ai_available_models_key; use flowy_ai_pub::cloud::{AIModel, ChatMessageType}; use flowy_error::{ErrorCode, FlowyError, FlowyResult}; use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult}; @@ -82,8 +81,9 @@ pub(crate) async fn get_server_model_list_handler( ai_manager: AFPluginState>, ) -> DataResult { let ai_manager = upgrade_ai_manager(ai_manager)?; - let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); - let models = ai_manager.get_available_models(source_key).await?; + let models = ai_manager + .get_available_models(GLOBAL_ACTIVE_MODEL_KEY.to_string()) + .await?; data_result_ok(models) } diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs index 1ec08854e0..d384ddfb75 100644 --- a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs +++ b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs @@ -16,9 +16,15 @@ use af_local_ai::ollama_plugin::OllamaAIPlugin; use af_plugin::core::path::is_plugin_ready; use af_plugin::core::plugin::RunningState; use arc_swap::ArcSwapOption; +use flowy_ai_pub::cloud::AIModel; +use flowy_ai_pub::persistence::{ + select_local_ai_model, upsert_local_ai_model, LocalAIModelTable, ModelType, +}; use flowy_ai_pub::user_service::AIUserService; use futures_util::SinkExt; use 
lib_infra::util::get_operating_system; +use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; +use ollama_rs::Ollama; use serde::{Deserialize, Serialize}; use std::ops::Deref; use std::path::PathBuf; @@ -39,8 +45,8 @@ impl Default for LocalAISetting { fn default() -> Self { Self { ollama_server_url: "http://localhost:11434".to_string(), - chat_model_name: "llama3.1".to_string(), - embedding_model_name: "nomic-embed-text".to_string(), + chat_model_name: "llama3.1:latest".to_string(), + embedding_model_name: "nomic-embed-text:latest".to_string(), } } } @@ -53,6 +59,7 @@ pub struct LocalAIController { current_chat_id: ArcSwapOption, store_preferences: Weak, user_service: Arc, + ollama: ArcSwapOption, } impl Deref for LocalAIController { @@ -83,69 +90,80 @@ impl LocalAIController { user_service.clone(), res_impl, )); - // Subscribe to state changes - let mut running_state_rx = local_ai.subscribe_running_state(); - let cloned_llm_res = Arc::clone(&local_ai_resource); - let cloned_store_preferences = store_preferences.clone(); - let cloned_local_ai = Arc::clone(&local_ai); - let cloned_user_service = Arc::clone(&user_service); + let ollama = ArcSwapOption::default(); + let sys = get_operating_system(); + if sys.is_desktop() { + let setting = local_ai_resource.get_llm_setting(); + ollama.store( + Ollama::try_new(&setting.ollama_server_url) + .map(Arc::new) + .ok(), + ); - // Spawn a background task to listen for plugin state changes - tokio::spawn(async move { - while let Some(state) = running_state_rx.next().await { - // Skip if we can’t get workspace_id - let Ok(workspace_id) = cloned_user_service.workspace_id() else { - continue; - }; + // Subscribe to state changes + let mut running_state_rx = local_ai.subscribe_running_state(); + let cloned_llm_res = Arc::clone(&local_ai_resource); + let cloned_store_preferences = store_preferences.clone(); + let cloned_local_ai = Arc::clone(&local_ai); + let cloned_user_service = Arc::clone(&user_service); - let key = local_ai_enabled_key(&workspace_id.to_string()); - info!("[AI Plugin] state: {:?}", state); - - // Read whether plugin is enabled from store; default to true - if let Some(store_preferences) = cloned_store_preferences.upgrade() { - let enabled = store_preferences.get_bool(&key).unwrap_or(true); - // Only check resource status if the plugin isn’t in "UnexpectedStop" and is enabled - let (plugin_downloaded, lack_of_resource) = - if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled { - // Possibly check plugin readiness and resource concurrency in parallel, - // but here we do it sequentially for clarity. - let downloaded = is_plugin_ready(); - let resource_lack = cloned_llm_res.get_lack_of_resource().await; - (downloaded, resource_lack) - } else { - (false, None) - }; - - // If plugin is running, retrieve version - let plugin_version = if matches!(state, RunningState::Running { .. 
}) { - match cloned_local_ai.plugin_info().await { - Ok(info) => Some(info.version), - Err(_) => None, - } - } else { - None + // Spawn a background task to listen for plugin state changes + tokio::spawn(async move { + while let Some(state) = running_state_rx.next().await { + // Skip if we can't get workspace_id + let Ok(workspace_id) = cloned_user_service.workspace_id() else { + continue; }; - // Broadcast the new local AI state - let new_state = RunningStatePB::from(state); - chat_notification_builder( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateLocalAIState, - ) - .payload(LocalAIPB { - enabled, - plugin_downloaded, - lack_of_resource, - state: new_state, - plugin_version, - }) - .send(); - } else { - warn!("[AI Plugin] store preferences is dropped"); + let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string()); + info!("[AI Plugin] state: {:?}", state); + + // Read whether plugin is enabled from store; default to true + if let Some(store_preferences) = cloned_store_preferences.upgrade() { + let enabled = store_preferences.get_bool(&key).unwrap_or(true); + // Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled + let (plugin_downloaded, lack_of_resource) = + if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled { + // Possibly check plugin readiness and resource concurrency in parallel, + // but here we do it sequentially for clarity. + let downloaded = is_plugin_ready(); + let resource_lack = cloned_llm_res.get_lack_of_resource().await; + (downloaded, resource_lack) + } else { + (false, None) + }; + + // If plugin is running, retrieve version + let plugin_version = if matches!(state, RunningState::Running { .. }) { + match cloned_local_ai.plugin_info().await { + Ok(info) => Some(info.version), + Err(_) => None, + } + } else { + None + }; + + // Broadcast the new local AI state + let new_state = RunningStatePB::from(state); + chat_notification_builder( + APPFLOWY_AI_NOTIFICATION_KEY, + ChatNotification::UpdateLocalAIState, + ) + .payload(LocalAIPB { + enabled, + plugin_downloaded, + lack_of_resource, + state: new_state, + plugin_version, + }) + .send(); + } else { + warn!("[AI Plugin] store preferences is dropped"); + } } - } - }); + }); + } Self { ai_plugin: local_ai, @@ -153,6 +171,7 @@ impl LocalAIController { current_chat_id: ArcSwapOption::default(), store_preferences, user_service, + ollama, } } #[instrument(level = "debug", skip_all)] @@ -287,6 +306,78 @@ impl LocalAIController { self.resource.get_llm_setting() } + pub async fn get_all_chat_local_models(&self) -> Vec { + self + .get_filtered_local_models(|name| !name.contains("embed")) + .await + } + + pub async fn get_all_embedded_local_models(&self) -> Vec { + self + .get_filtered_local_models(|name| name.contains("embed")) + .await + } + + // Helper function to avoid code duplication in model retrieval + async fn get_filtered_local_models(&self, filter_fn: F) -> Vec + where + F: Fn(&str) -> bool, + { + match self.ollama.load_full() { + None => vec![], + Some(ollama) => ollama + .list_local_models() + .await + .map(|models| { + models + .into_iter() + .filter(|m| filter_fn(&m.name.to_lowercase())) + .map(|m| AIModel::local(m.name, String::new())) + .collect() + }) + .unwrap_or_default(), + } + } + + pub async fn check_model_type(&self, model_name: &str) -> FlowyResult { + let uid = self.user_service.user_id()?; + let mut conn = self.user_service.sqlite_connection(uid)?; + match select_local_ai_model(&mut conn, model_name) { + None => { + let 
ollama = self + .ollama + .load_full() + .ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?; + + let request = GenerateEmbeddingsRequest::new( + model_name.to_string(), + EmbeddingsInput::Single("Hello".to_string()), + ); + + let model_type = match ollama.generate_embeddings(request).await { + Ok(value) => { + if value.embeddings.is_empty() { + ModelType::Chat + } else { + ModelType::Embedding + } + }, + Err(_) => ModelType::Chat, + }; + + upsert_local_ai_model( + &mut conn, + &LocalAIModelTable { + name: model_name.to_string(), + model_type: model_type as i16, + }, + )?; + Ok(model_type) + }, + Some(r) => Ok(ModelType::from(r.model_type)), + } + } + pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> { info!( "[AI Plugin] update local ai setting: {:?}, thread: {:?}", diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs b/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs index 36a56e171d..352778f28f 100644 --- a/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs +++ b/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs @@ -161,7 +161,6 @@ impl LocalAIResourceController { let setting = self.get_llm_setting(); let client = Client::builder().timeout(Duration::from_secs(5)).build()?; - match client.get(&setting.ollama_server_url).send().await { Ok(resp) if resp.status().is_success() => { info!( diff --git a/frontend/rust-lib/flowy-error/Cargo.toml b/frontend/rust-lib/flowy-error/Cargo.toml index 61a7422f17..8bc67ee46c 100644 --- a/frontend/rust-lib/flowy-error/Cargo.toml +++ b/frontend/rust-lib/flowy-error/Cargo.toml @@ -36,6 +36,10 @@ client-api = { workspace = true, optional = true } tantivy = { workspace = true, optional = true } uuid.workspace = true +[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies] +ollama-rs = "0.3.0" + + [features] default = ["impl_from_dispatch_error", "impl_from_serde", "impl_from_reqwest", "impl_from_sqlite"] impl_from_dispatch_error = ["lib-dispatch"] diff --git a/frontend/rust-lib/flowy-error/src/errors.rs b/frontend/rust-lib/flowy-error/src/errors.rs index f76a7d4dda..36240cd08d 100644 --- a/frontend/rust-lib/flowy-error/src/errors.rs +++ b/frontend/rust-lib/flowy-error/src/errors.rs @@ -264,3 +264,10 @@ impl From for FlowyError { FlowyError::internal().with_context(value) } } + +#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))] +impl From for FlowyError { + fn from(value: ollama_rs::error::OllamaError) -> Self { + FlowyError::local_ai().with_context(value) + } +} diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql new file mode 100644 index 0000000000..d9a93fe9a1 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql @@ -0,0 +1 @@ +-- This file should undo anything in `up.sql` diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql new file mode 100644 index 0000000000..243fe61193 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql @@ -0,0 +1,6 @@ +-- Your SQL goes here +CREATE TABLE local_ai_model_table +( + name TEXT PRIMARY KEY NOT NULL, + model_type SMALLINT NOT NULL +); \ No newline at end of file diff --git a/frontend/rust-lib/flowy-sqlite/src/schema.rs 
b/frontend/rust-lib/flowy-sqlite/src/schema.rs index 0236cbf467..bf7f431682 100644 --- a/frontend/rust-lib/flowy-sqlite/src/schema.rs +++ b/frontend/rust-lib/flowy-sqlite/src/schema.rs @@ -54,6 +54,13 @@ diesel::table! { } } +diesel::table! { + local_ai_model_table (name) { + name -> Text, + model_type -> SmallInt, + } +} + diesel::table! { upload_file_part (upload_id, e_tag) { upload_id -> Text, @@ -133,16 +140,17 @@ diesel::table! { } diesel::allow_tables_to_appear_in_same_query!( - af_collab_metadata, - chat_local_setting_table, - chat_message_table, - chat_table, - collab_snapshot, - upload_file_part, - upload_file_table, - user_data_migration_records, - user_table, - user_workspace_table, - workspace_members_table, - workspace_setting_table, + af_collab_metadata, + chat_local_setting_table, + chat_message_table, + chat_table, + collab_snapshot, + local_ai_model_table, + upload_file_part, + upload_file_table, + user_data_migration_records, + user_table, + user_workspace_table, + workspace_members_table, + workspace_setting_table, ); From 86e6845302d92b9fd7ee6d802c0cd8c47d7fec56 Mon Sep 17 00:00:00 2001 From: Nathan Date: Fri, 25 Apr 2025 22:28:41 +0800 Subject: [PATCH 2/6] chore: update ui --- .../ai/service/ai_model_state_notifier.dart | 6 +- .../prompt_input/select_model_menu.dart | 26 +- .../setting/ai/ai_settings_group.dart | 4 +- .../settings/ai/ollama_setting_bloc.dart | 287 ++++++++++-------- .../setting_ai_view/model_selection.dart | 2 +- .../pages/setting_ai_view/ollama_setting.dart | 65 ++++ frontend/resources/translations/en.json | 3 +- frontend/rust-lib/flowy-ai/src/ai_manager.rs | 201 ++++++------ frontend/rust-lib/flowy-ai/src/entities.rs | 14 +- .../rust-lib/flowy-ai/src/event_handler.rs | 9 + frontend/rust-lib/flowy-ai/src/event_map.rs | 4 + .../flowy-ai/src/local_ai/controller.rs | 6 +- 12 files changed, 378 insertions(+), 249 deletions(-) diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart index c175c06d07..c360356240 100644 --- a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart +++ b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart @@ -79,7 +79,7 @@ class AIModelStateNotifier { _aiModelSwitchListener.start( onUpdateSelectedModel: (model) async { final updatedModels = _availableModels?.deepCopy() - ?..selectedModel = model; + ?..globalModel = model; _availableModels = updatedModels; _notifyAvailableModelsChanged(); @@ -161,7 +161,7 @@ class AIModelStateNotifier { ); } - if (!availableModels.selectedModel.isLocal) { + if (!availableModels.globalModel.isLocal) { return AIModelState( type: AiType.cloud, hintText: LocaleKeys.chat_inputMessageHint.tr(), @@ -199,7 +199,7 @@ class AIModelStateNotifier { if (availableModels == null) { return ([], null); } - return (availableModels.models, availableModels.selectedModel); + return (availableModels.models, availableModels.globalModel); } void _notifyAvailableModelsChanged() { diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart index 36b389f8c4..317f90ac21 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart @@ -218,9 +218,10 @@ class _CurrentModelButton extends StatelessWidget { child: SizedBox( height: DesktopAIPromptSizes.actionBarButtonSize, child: 
AnimatedSize( - duration: const Duration(milliseconds: 50), - curve: Curves.easeInOut, + duration: const Duration(milliseconds: 200), + curve: Curves.easeOutCubic, alignment: AlignmentDirectional.centerStart, + clipBehavior: Clip.none, child: FlowyHover( style: const HoverStyle( borderRadius: BorderRadius.all(Radius.circular(8)), @@ -228,6 +229,7 @@ class _CurrentModelButton extends StatelessWidget { child: Padding( padding: const EdgeInsetsDirectional.all(4.0), child: Row( + mainAxisSize: MainAxisSize.min, children: [ Padding( // TODO: remove this after change icon to 20px @@ -239,14 +241,18 @@ class _CurrentModelButton extends StatelessWidget { ), ), if (model != null && !model!.isDefault) - Padding( - padding: EdgeInsetsDirectional.only(end: 2.0), - child: FlowyText( - model!.i18n, - fontSize: 12, - figmaLineHeight: 16, - color: Theme.of(context).hintColor, - overflow: TextOverflow.ellipsis, + AnimatedSize( + duration: const Duration(milliseconds: 150), + curve: Curves.easeOutCubic, + child: Padding( + padding: EdgeInsetsDirectional.only(end: 2.0), + child: FlowyText( + model!.i18n, + fontSize: 12, + figmaLineHeight: 16, + color: Theme.of(context).hintColor, + overflow: TextOverflow.ellipsis, + ), ), ), FlowySvg( diff --git a/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart b/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart index fe0e0a7160..6bed7ed035 100644 --- a/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart +++ b/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart @@ -36,7 +36,7 @@ class AiSettingsGroup extends StatelessWidget { MobileSettingItem( name: LocaleKeys.settings_aiPage_keys_llmModelType.tr(), trailing: MobileSettingTrailing( - text: state.availableModels?.selectedModel.name ?? "", + text: state.availableModels?.globalModel.name ?? 
"", ), onTap: () => _onLLMModelTypeTap(context, state), ), @@ -73,7 +73,7 @@ class AiSettingsGroup extends StatelessWidget { text: entry.value.name, showTopBorder: entry.key == 0, isSelected: - availableModels?.selectedModel.name == entry.value.name, + availableModels?.globalModel.name == entry.value.name, onTap: () { context .read() diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart index 2659292b11..3b9060bb08 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart @@ -6,177 +6,181 @@ import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; import 'package:appflowy_result/appflowy_result.dart'; import 'package:bloc/bloc.dart'; import 'package:collection/collection.dart'; -import 'package:freezed_annotation/freezed_annotation.dart'; import 'package:equatable/equatable.dart'; +import 'package:freezed_annotation/freezed_annotation.dart'; part 'ollama_setting_bloc.freezed.dart'; const kDefaultChatModel = 'llama3.1:latest'; const kDefaultEmbeddingModel = 'nomic-embed-text:latest'; +/// Extension methods to map between PB and UI models class OllamaSettingBloc extends Bloc { OllamaSettingBloc() : super(const OllamaSettingState()) { - on(_handleEvent); + on<_Started>(_handleStarted); + on<_DidLoadLocalModels>(_onLoadLocalModels); + on<_DidLoadSetting>(_onLoadSetting); + on<_UpdateSetting>(_onLoadSetting); + on<_OnEdit>(_onEdit); + on<_OnSubmit>(_onSubmit); + on<_SetDefaultModel>(_onSetDefaultModel); } - Future _handleEvent( - OllamaSettingEvent event, + Future _handleStarted( + _Started event, Emitter emit, ) async { - event.when( - started: () { - AIEventGetLocalAISetting().send().fold( - (setting) { - if (!isClosed) { - add(OllamaSettingEvent.didLoadSetting(setting)); - } - }, - Log.error, - ); - }, - didLoadSetting: (setting) => _updateSetting(setting, emit), - updateSetting: (setting) => _updateSetting(setting, emit), - onEdit: (content, settingType) { - final updatedSubmittedItems = state.submittedItems - .map( - (item) => item.settingType == settingType - ? 
SubmittedItem( - content: content, - settingType: item.settingType, - ) - : item, - ) - .toList(); + try { + final results = await Future.wait([ + AIEventGetLocalAIModels().send().then((r) => r.getOrThrow()), + AIEventGetLocalAISetting().send().then((r) => r.getOrThrow()), + ]); - // Convert both lists to maps: {settingType: content} - final updatedMap = { - for (final item in updatedSubmittedItems) - item.settingType: item.content, - }; + final models = results[0] as AvailableModelsPB; + final setting = results[1] as LocalAISettingPB; - final inputMap = { - for (final item in state.inputItems) item.settingType: item.content, - }; - - // Compare maps instead of lists - final isEdited = !const MapEquality() - .equals(updatedMap, inputMap); - - emit( - state.copyWith( - submittedItems: updatedSubmittedItems, - isEdited: isEdited, - ), - ); - }, - submit: () { - final setting = LocalAISettingPB(); - final settingUpdaters = { - SettingType.serverUrl: (value) => setting.serverUrl = value, - SettingType.chatModel: (value) => setting.defaultModel = value, - SettingType.embeddingModel: (value) => - setting.embeddingModelName = value, - }; - - for (final item in state.submittedItems) { - settingUpdaters[item.settingType]?.call(item.content); - } - add(OllamaSettingEvent.updateSetting(setting)); - AIEventUpdateLocalAISetting(setting).send().fold( - (_) => Log.info('AI setting updated successfully'), - (err) => Log.error("update ai setting failed: $err"), - ); - }, - ); + if (!isClosed) { + add(OllamaSettingEvent.didLoadLocalModels(models)); + add(OllamaSettingEvent.didLoadSetting(setting)); + } + } catch (e, st) { + Log.error('Failed to load initial AI data: $e\n$st'); + } } - void _updateSetting( - LocalAISettingPB setting, + void _onLoadLocalModels( + _DidLoadLocalModels event, Emitter emit, ) { + emit(state.copyWith(localModels: event.models)); + } + + void _onLoadSetting( + dynamic event, + Emitter emit, + ) { + final setting = (event as dynamic).setting as LocalAISettingPB; + final submitted = setting.toSubmittedItems(); emit( state.copyWith( setting: setting, - inputItems: _createInputItems(setting), - submittedItems: _createSubmittedItems(setting), - isEdited: false, // Reset to false when the setting is loaded/updated. + inputItems: setting.toInputItems(), + submittedItems: submitted, + originalMap: { + for (final item in submitted) item.settingType: item.content, + }, + isEdited: false, ), ); } - List _createInputItems(LocalAISettingPB setting) => [ - SettingItem( - content: setting.serverUrl, - hintText: 'http://localhost:11434', - settingType: SettingType.serverUrl, - ), - SettingItem( - content: setting.defaultModel, - hintText: kDefaultChatModel, - settingType: SettingType.chatModel, - ), - SettingItem( - content: setting.embeddingModelName, - hintText: kDefaultEmbeddingModel, - settingType: SettingType.embeddingModel, - ), - ]; + void _onEdit( + _OnEdit event, + Emitter emit, + ) { + final updated = state.submittedItems + .map( + (item) => item.settingType == event.settingType + ? 
item.copyWith(content: event.content) + : item, + ) + .toList(); - List _createSubmittedItems(LocalAISettingPB setting) => [ - SubmittedItem( - content: setting.serverUrl, - settingType: SettingType.serverUrl, - ), - SubmittedItem( - content: setting.defaultModel, - settingType: SettingType.chatModel, - ), - SubmittedItem( - content: setting.embeddingModelName, - settingType: SettingType.embeddingModel, - ), - ]; + final currentMap = {for (final i in updated) i.settingType: i.content}; + final isEdited = !const MapEquality() + .equals(state.originalMap, currentMap); + + emit(state.copyWith(submittedItems: updated, isEdited: isEdited)); + } + + void _onSubmit( + _OnSubmit event, + Emitter emit, + ) { + final pb = LocalAISettingPB(); + for (final item in state.submittedItems) { + switch (item.settingType) { + case SettingType.serverUrl: + pb.serverUrl = item.content; + break; + case SettingType.chatModel: + pb.globalChatModel = state.selectedModel?.name ?? item.content; + break; + case SettingType.embeddingModel: + pb.embeddingModelName = item.content; + break; + } + } + add(OllamaSettingEvent.updateSetting(pb)); + AIEventUpdateLocalAISetting(pb).send().fold( + (_) => Log.info('AI setting updated successfully'), + (err) => Log.error('Update AI setting failed: $err'), + ); + } + + void _onSetDefaultModel( + _SetDefaultModel event, + Emitter emit, + ) { + emit(state.copyWith(selectedModel: event.model, isEdited: true)); + } } -// Create an enum for setting type. +/// Setting types for mapping enum SettingType { serverUrl, chatModel, - embeddingModel; // semicolon needed after the enum values + embeddingModel; String get title { switch (this) { case SettingType.serverUrl: return 'Ollama server url'; case SettingType.chatModel: - return 'Chat model name'; + return 'Default model name'; case SettingType.embeddingModel: return 'Embedding model name'; } } } +/// Input field representation class SettingItem extends Equatable { const SettingItem({ required this.content, required this.hintText, required this.settingType, }); + final String content; final String hintText; final SettingType settingType; + @override List get props => [content, settingType]; } +/// Items pending submission class SubmittedItem extends Equatable { const SubmittedItem({ required this.content, required this.settingType, }); + final String content; final SettingType settingType; + /// Returns a copy of this SubmittedItem with given fields updated. + SubmittedItem copyWith({ + String? content, + SettingType? settingType, + }) { + return SubmittedItem( + content: content ?? this.content, + settingType: settingType ?? 
this.settingType, + ); + } + @override List get props => [content, settingType]; } @@ -184,10 +188,18 @@ class SubmittedItem extends Equatable { @freezed class OllamaSettingEvent with _$OllamaSettingEvent { const factory OllamaSettingEvent.started() = _Started; - const factory OllamaSettingEvent.didLoadSetting(LocalAISettingPB setting) = - _DidLoadSetting; - const factory OllamaSettingEvent.updateSetting(LocalAISettingPB setting) = - _UpdateSetting; + const factory OllamaSettingEvent.didLoadLocalModels( + AvailableModelsPB models, + ) = _DidLoadLocalModels; + const factory OllamaSettingEvent.didLoadSetting( + LocalAISettingPB setting, + ) = _DidLoadSetting; + const factory OllamaSettingEvent.updateSetting( + LocalAISettingPB setting, + ) = _UpdateSetting; + const factory OllamaSettingEvent.setDefaultModel( + AIModelPB model, + ) = _SetDefaultModel; const factory OllamaSettingEvent.onEdit( String content, SettingType settingType, @@ -199,25 +211,42 @@ class OllamaSettingEvent with _$OllamaSettingEvent { class OllamaSettingState with _$OllamaSettingState { const factory OllamaSettingState({ LocalAISettingPB? setting, - @Default([ - SettingItem( - content: 'http://localhost:11434', - hintText: 'http://localhost:11434', - settingType: SettingType.serverUrl, - ), - SettingItem( - content: kDefaultChatModel, - hintText: kDefaultChatModel, - settingType: SettingType.chatModel, - ), - SettingItem( - content: kDefaultEmbeddingModel, - hintText: kDefaultEmbeddingModel, - settingType: SettingType.embeddingModel, - ), - ]) - List inputItems, + @Default([]) List inputItems, + AIModelPB? selectedModel, + AvailableModelsPB? localModels, + AIModelPB? defaultModel, @Default([]) List submittedItems, @Default(false) bool isEdited, - }) = _PluginStateState; + @Default({}) Map originalMap, + }) = _OllamaSettingState; +} + +extension on LocalAISettingPB { + List toInputItems() => [ + SettingItem( + content: serverUrl, + hintText: 'http://localhost:11434', + settingType: SettingType.serverUrl, + ), + SettingItem( + content: embeddingModelName, + hintText: kDefaultEmbeddingModel, + settingType: SettingType.embeddingModel, + ), + ]; + + List toSubmittedItems() => [ + SubmittedItem( + content: serverUrl, + settingType: SettingType.serverUrl, + ), + SubmittedItem( + content: globalChatModel, + settingType: SettingType.chatModel, + ), + SubmittedItem( + content: embeddingModelName, + settingType: SettingType.embeddingModel, + ), + ]; } diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart index 7357c2951c..83f4ff603e 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart @@ -30,7 +30,7 @@ class AIModelSelection extends StatelessWidget { final localModels = models.where((model) => model.isLocal).toList(); final cloudModels = models.where((model) => !model.isLocal).toList(); - final selectedModel = state.availableModels!.selectedModel; + final selectedModel = state.availableModels!.globalModel; return Padding( padding: const EdgeInsets.symmetric(vertical: 6), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart 
index 6f38043927..fc56fc61e7 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart @@ -7,6 +7,14 @@ import 'package:flowy_infra_ui/widget/spacing.dart'; import 'package:flutter/material.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; +import 'package:appflowy/ai/ai.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; + +import 'package:appflowy/generated/locale_keys.g.dart'; +import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart'; +import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart'; +import 'package:easy_localization/easy_localization.dart'; + class OllamaSettingPage extends StatelessWidget { const OllamaSettingPage({super.key}); @@ -32,6 +40,7 @@ class OllamaSettingPage extends StatelessWidget { children: [ for (final item in state.inputItems) _SettingItemWidget(item: item), + const LocalAIModelSelection(), _SaveButton(isEdited: state.isEdited), ], ), @@ -113,3 +122,59 @@ class _SaveButton extends StatelessWidget { ); } } + +class LocalAIModelSelection extends StatelessWidget { + const LocalAIModelSelection({super.key}); + static const double height = 49; + + @override + Widget build(BuildContext context) { + return BlocBuilder<OllamaSettingBloc, OllamaSettingState>( + buildWhen: (previous, current) => + previous.localModels != current.localModels, + builder: (context, state) { + final models = state.localModels; + if (models == null) { + return const SizedBox( + // Using same height as SettingsDropdown to avoid layout shift + height: height, + ); + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + FlowyText.medium( + LocaleKeys.settings_aiPage_keys_globalLLMModel.tr(), + fontSize: 12, + figmaLineHeight: 16, + ), + const VSpace(4), + SizedBox( + height: 40, + child: SettingsDropdown<AIModelPB>( + key: const Key('_AIModelSelection'), + onChanged: (model) => context + .read<OllamaSettingBloc>() + .add(OllamaSettingEvent.setDefaultModel(model)), + selectedOption: models.globalModel, + selectOptionCompare: (left, right) => left?.name == right?.name, + options: models.models + .map( + (model) => buildDropdownMenuEntry<AIModelPB>( + context, + value: model, + label: model.i18n, + subLabel: model.desc, + maximumHeight: height, + ), + ) + .toList(), + ), + ), + ], + ); + }, + ); + } +} diff --git a/frontend/resources/translations/en.json b/frontend/resources/translations/en.json index 91abe94cc5..27940824a9 100644 --- a/frontend/resources/translations/en.json +++ b/frontend/resources/translations/en.json @@ -866,6 +866,7 @@ "aiSettingsDescription": "Choose your preferred model to power AppFlowy AI. Now includes GPT-4o, GPT-o3-mini, DeepSeek R1, Claude 3.5 Sonnet, and models available in Ollama", "loginToEnableAIFeature": "AI features are only enabled after logging in with @:appName Cloud. If you don't have an @:appName account, go to 'My Account' to sign up", "llmModel": "Language Model", + "globalLLMModel": "Global Language Model", "llmModelType": "Language Model Type", "downloadLLMPrompt": "Download {}", "downloadAppFlowyOfflineAI": "Downloading AI offline package will enable AI to run on your device.
Do you want to continue?", @@ -3342,4 +3343,4 @@ "rewrite": "Rewrite", "insertBelow": "Insert below" } -} +} \ No newline at end of file diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index d18cb91355..bfddebcd8e 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -341,14 +341,14 @@ impl AIManager { } pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> { - let previous_model = self.local_ai.get_local_ai_setting().chat_model_name; + let old_settings = self.local_ai.get_local_ai_setting(); + let need_restart = old_settings.ollama_server_url != setting.ollama_server_url; self.local_ai.update_local_ai_setting(setting).await?; let current_model = self.local_ai.get_local_ai_setting().chat_model_name; - - if previous_model != current_model { + if old_settings.chat_model_name != current_model { info!( "[AI Plugin] update global active model, previous: {}, current: {}", - previous_model, current_model + old_settings.chat_model_name, current_model ); let model = AIModel::local(current_model, "".to_string()); self @@ -356,6 +356,9 @@ impl AIManager { .await?; } + if need_restart { + self.local_ai.restart_plugin().await; + } Ok(()) } @@ -446,7 +449,7 @@ impl AIManager { .store_preferences .set_object::(&source_key, &model)?; - chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel) + chat_notification_builder(&source_key, ChatNotification::DidUpdateSelectedModel) .payload(AIModelPB::from(model)) .send(); Ok(()) @@ -501,99 +504,109 @@ impl AIManager { } } + pub async fn get_local_available_models(&self) -> FlowyResult { + let setting = self.local_ai.get_local_ai_setting(); + let models = self.local_ai.get_all_chat_local_models().await; + let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); + + Ok(AvailableModelsPB { + models: models.into_iter().map(AIModelPB::from).collect(), + global_model: AIModelPB::from(selected_model), + }) + } + pub async fn get_available_models(&self, source: String) -> FlowyResult { let is_local_mode = self.user_service.is_local_model().await?; if is_local_mode { - let setting = self.local_ai.get_local_ai_setting(); - let models = self.local_ai.get_all_chat_local_models().await; - let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); - - Ok(AvailableModelsPB { - models: models.into_iter().map(|m| m.into()).collect(), - selected_model: AIModelPB::from(selected_model), - }) - } else { - // Build the models list from server models and mark them as non-local. - let mut all_models: Vec = self - .get_server_available_models() - .await? - .into_iter() - .map(AIModel::from) - .collect(); - - trace!("[Model Selection]: Available models: {:?}", all_models); - - // If user enable local ai, then add local ai model to the list. - if self.local_ai.is_enabled() { - let local_models = self.local_ai.get_all_chat_local_models().await; - all_models.extend(local_models.into_iter().map(|m| m)); - } - - if all_models.is_empty() { - return Ok(AvailableModelsPB { - models: all_models.into_iter().map(|m| m.into()).collect(), - selected_model: AIModelPB::default(), - }); - } - - // Global active model is the model selected by the user in the workspace settings. 
-    let mut server_active_model = self
-      .get_workspace_select_model()
-      .await
-      .map(|m| AIModel::server(m, "".to_string()))
-      .unwrap_or_else(|_| AIModel::default());
-
-    trace!(
-      "[Model Selection] server active model: {:?}",
-      server_active_model
-    );
-
-    let mut user_selected_model = server_active_model.clone();
-    // when current select model is deprecated, reset the model to default
-    if !all_models
-      .iter()
-      .any(|m| m.name == server_active_model.name)
-    {
-      server_active_model = AIModel::default();
-    }
-
-    // We use source to identify user selected model. source can be document id or chat id.
-    match self.get_active_model(&source).await {
-      None => {
-        // when there is selected model and current local ai is active, then use local ai
-        if let Some(local_ai_model) = all_models.iter().find(|m| m.is_local) {
-          user_selected_model = local_ai_model.clone();
-        }
-      },
-      Some(model) => {
-        trace!("[Model Selection] user previous select model: {:?}", model);
-        user_selected_model = model;
-      },
-    }
-
-    // If user selected model is not available in the list, use the global active model.
-    let active_model = all_models
-      .iter()
-      .find(|m| m.name == user_selected_model.name)
-      .cloned()
-      .or(Some(server_active_model.clone()));
-
-    // Update the stored preference if a different model is used.
-    if let Some(ref active_model) = active_model {
-      if active_model.name != user_selected_model.name {
-        self
-          .update_selected_model(source, active_model.clone())
-          .await?;
-      }
-    }
-
-    trace!("[Model Selection] final active model: {:?}", active_model);
-    let selected_model = AIModelPB::from(active_model.unwrap_or_default());
-    Ok(AvailableModelsPB {
-      models: all_models.into_iter().map(|m| m.into()).collect(),
-      selected_model,
-    })
+      return self.get_local_available_models().await;
     }
+
+    // Fetch server models
+    let mut all_models: Vec<AIModel> = self
+      .get_server_available_models()
+      .await?
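+    // Map each server-side entry into the shared AIModel type so cloud and
+    // local models can be merged into one list.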
+ .into_iter() + .map(AIModel::from) + .collect(); + + trace!("[Model Selection]: Available models: {:?}", all_models); + + // Add local models if enabled + if self.local_ai.is_enabled() { + let setting = self.local_ai.get_local_ai_setting(); + all_models.push(AIModel::local(setting.chat_model_name, "".to_string()).into()); + } + + // Return early if no models available + if all_models.is_empty() { + return Ok(AvailableModelsPB { + models: Vec::new(), + global_model: AIModelPB::default(), + }); + } + + // Get server active model (only once) + let server_active_model = self + .get_workspace_select_model() + .await + .map(|m| AIModel::server(m, "".to_string())) + .unwrap_or_else(|_| AIModel::default()); + + trace!( + "[Model Selection] server active model: {:?}", + server_active_model + ); + + // Use server model as default if it exists in available models + let default_model = if all_models + .iter() + .any(|m| m.name == server_active_model.name) + { + server_active_model.clone() + } else { + AIModel::default() + }; + + // Get user's previously selected model + let user_selected_model = match self.get_active_model(&source).await { + Some(model) => { + trace!("[Model Selection] user previous select model: {:?}", model); + model + }, + None => { + // When no selected model and local AI is active, use local AI model + all_models + .iter() + .find(|m| m.is_local) + .cloned() + .unwrap_or_else(|| default_model.clone()) + }, + }; + + // Determine final active model - use user's selection if available, otherwise default + let active_model = all_models + .iter() + .find(|m| m.name == user_selected_model.name) + .cloned() + .unwrap_or(default_model.clone()); + + // Update stored preference if changed + if active_model.name != user_selected_model.name { + if let Err(err) = self + .update_selected_model(source, active_model.clone()) + .await + { + error!("[Model Selection] failed to update selected model: {}", err); + } + } + + trace!("[Model Selection] final active model: {:?}", active_model); + + // Create response with one transformation pass + Ok(AvailableModelsPB { + models: all_models.into_iter().map(AIModelPB::from).collect(), + global_model: AIModelPB::from(active_model), + }) } pub async fn get_or_create_chat_instance(&self, chat_id: &Uuid) -> Result, FlowyError> { diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index 796664a18f..5e03fadf56 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -222,7 +222,13 @@ pub struct AvailableModelsPB { pub models: Vec, #[pb(index = 2)] - pub selected_model: AIModelPB, + pub global_model: AIModelPB, +} + +#[derive(Default, ProtoBuf, Clone, Debug)] +pub struct RepeatedAIModelPB { + #[pb(index = 1)] + pub items: Vec, } #[derive(Default, ProtoBuf, Clone, Debug)] @@ -686,7 +692,7 @@ pub struct LocalAISettingPB { #[pb(index = 2)] #[validate(custom(function = "required_not_empty_str"))] - pub default_model: String, + pub global_chat_model: String, #[pb(index = 3)] #[validate(custom(function = "required_not_empty_str"))] @@ -697,7 +703,7 @@ impl From for LocalAISettingPB { fn from(value: LocalAISetting) -> Self { LocalAISettingPB { server_url: value.ollama_server_url, - default_model: value.chat_model_name, + global_chat_model: value.chat_model_name, embedding_model_name: value.embedding_model_name, } } @@ -707,7 +713,7 @@ impl From for LocalAISetting { fn from(value: LocalAISettingPB) -> Self { LocalAISetting { ollama_server_url: value.server_url, - 
chat_model_name: value.default_model,
+      chat_model_name: value.global_chat_model,
       embedding_model_name: value.embedding_model_name,
     }
   }
diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs
index f778063309..160fbe6928 100644
--- a/frontend/rust-lib/flowy-ai/src/event_handler.rs
+++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs
@@ -340,6 +340,15 @@ pub(crate) async fn get_local_ai_setting_handler(
   data_result_ok(pb)
 }
 
+#[tracing::instrument(level = "debug", skip_all)]
+pub(crate) async fn get_local_ai_models_handler(
+  ai_manager: AFPluginState<Weak<AIManager>>,
+) -> DataResult<AvailableModelsPB, FlowyError> {
+  let ai_manager = upgrade_ai_manager(ai_manager)?;
+  let data = ai_manager.get_local_available_models().await?;
+  data_result_ok(data)
+}
+
 #[tracing::instrument(level = "debug", skip_all, err)]
 pub(crate) async fn update_local_ai_setting_handler(
   ai_manager: AFPluginState<Weak<AIManager>>,
diff --git a/frontend/rust-lib/flowy-ai/src/event_map.rs b/frontend/rust-lib/flowy-ai/src/event_map.rs
index 5020836a30..ee77f454a5 100644
--- a/frontend/rust-lib/flowy-ai/src/event_map.rs
+++ b/frontend/rust-lib/flowy-ai/src/event_map.rs
@@ -31,6 +31,7 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
     .event(AIEvent::ToggleLocalAI, toggle_local_ai_handler)
     .event(AIEvent::GetLocalAIState, get_local_ai_state_handler)
     .event(AIEvent::GetLocalAISetting, get_local_ai_setting_handler)
+    .event(AIEvent::GetLocalAIModels, get_local_ai_models_handler)
     .event(
       AIEvent::UpdateLocalAISetting,
       update_local_ai_setting_handler,
@@ -121,4 +122,7 @@ pub enum AIEvent {
 
   #[event(input = "UpdateSelectedModelPB")]
   UpdateSelectedModel = 32,
+
+  #[event(output = "AvailableModelsPB")]
+  GetLocalAIModels = 33,
 }
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
index d384ddfb75..6b147b9c0c 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
@@ -384,11 +384,7 @@ impl LocalAIController {
       setting,
       std::thread::current().id()
     );
-
-    if self.resource.set_llm_setting(setting).await.is_ok() {
-      let is_enabled = self.is_enabled();
-      self.toggle_plugin(is_enabled).await?;
-    }
+    self.resource.set_llm_setting(setting).await?;
     Ok(())
   }
 
From 90000add22bf1b47ef57eee0e1015a02debb7407 Mon Sep 17 00:00:00 2001
From: Nathan 
Date: Fri, 25 Apr 2025 23:41:35 +0800
Subject: [PATCH 3/6] chore: update ui

---
 frontend/rust-lib/flowy-ai/src/ai_manager.rs | 28 +++++++++++++++-----
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs
index bfddebcd8e..c21ccef080 100644
--- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs
+++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs
@@ -342,15 +342,24 @@ impl AIManager {
 
   pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
     let old_settings = self.local_ai.get_local_ai_setting();
-    let need_restart = old_settings.ollama_server_url != setting.ollama_server_url;
-    self.local_ai.update_local_ai_setting(setting).await?;
-    let current_model = self.local_ai.get_local_ai_setting().chat_model_name;
-    if old_settings.chat_model_name != current_model {
+    // Only restart if the server URL has changed and local AI is not running
+    let need_restart =
+      old_settings.ollama_server_url != setting.ollama_server_url && !self.local_ai.is_running();
+
+    // Update settings first
+    self
+      .local_ai
+      .update_local_ai_setting(setting.clone())
+      .await?;
+
+
// Handle model change if needed + let model_changed = old_settings.chat_model_name != setting.chat_model_name; + if model_changed { info!( "[AI Plugin] update global active model, previous: {}, current: {}", - old_settings.chat_model_name, current_model + old_settings.chat_model_name, setting.chat_model_name ); - let model = AIModel::local(current_model, "".to_string()); + let model = AIModel::local(setting.chat_model_name, "".to_string()); self .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) .await?; @@ -359,6 +368,7 @@ impl AIManager { if need_restart { self.local_ai.restart_plugin().await; } + Ok(()) } @@ -506,9 +516,13 @@ impl AIManager { pub async fn get_local_available_models(&self) -> FlowyResult { let setting = self.local_ai.get_local_ai_setting(); - let models = self.local_ai.get_all_chat_local_models().await; + let mut models = self.local_ai.get_all_chat_local_models().await; let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); + if models.is_empty() { + models.push(selected_model.clone()); + } + Ok(AvailableModelsPB { models: models.into_iter().map(AIModelPB::from).collect(), global_model: AIModelPB::from(selected_model), From 3bc0cc7b43db809718b82aee27299a8482614bcb Mon Sep 17 00:00:00 2001 From: Nathan Date: Sat, 26 Apr 2025 00:56:45 +0800 Subject: [PATCH 4/6] chore: rename --- .../ai/service/ai_model_state_notifier.dart | 34 +++++++++---------- .../lib/ai/service/select_model_bloc.dart | 2 +- .../setting/ai/ai_settings_group.dart | 4 +-- .../message/ai_message_action_bar.dart | 2 +- .../message/ai_message_bubble.dart | 2 +- .../settings/ai/ollama_setting_bloc.dart | 8 ++--- .../settings/ai/settings_ai_bloc.dart | 9 ++--- .../setting_ai_view/model_selection.dart | 2 +- .../pages/setting_ai_view/ollama_setting.dart | 2 +- frontend/rust-lib/flowy-ai/src/ai_manager.rs | 32 ++++++++++------- frontend/rust-lib/flowy-ai/src/entities.rs | 8 ++--- .../rust-lib/flowy-ai/src/event_handler.rs | 22 ++++++------ frontend/rust-lib/flowy-ai/src/event_map.rs | 29 +++++++++------- frontend/rust-lib/flowy-sqlite/src/schema.rs | 26 +++++++------- 14 files changed, 97 insertions(+), 85 deletions(-) diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart index c360356240..783046a3d0 100644 --- a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart +++ b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart @@ -53,7 +53,7 @@ class AIModelStateNotifier { final LocalAIStateListener? _localAIListener; final AIModelSwitchListener _aiModelSwitchListener; LocalAIPB? _localAIState; - AvailableModelsPB? _availableModels; + ModelSelectionPB? 
_sourceModelSelection; // callbacks final List _stateChangedCallbacks = []; @@ -69,7 +69,7 @@ class AIModelStateNotifier { if (state.state == RunningStatePB.Running || state.state == RunningStatePB.Stopped) { - await _loadAvailableModels(); + await _loadModelSelection(); _notifyAvailableModelsChanged(); } }, @@ -78,11 +78,11 @@ class AIModelStateNotifier { _aiModelSwitchListener.start( onUpdateSelectedModel: (model) async { - final updatedModels = _availableModels?.deepCopy() - ?..globalModel = model; - _availableModels = updatedModels; - _notifyAvailableModelsChanged(); + final updatedModels = _sourceModelSelection?.deepCopy() + ?..selectedModel = model; + _sourceModelSelection = updatedModels; + _notifyAvailableModelsChanged(); if (model.isLocal && UniversalPlatform.isDesktop) { await _loadLocalAiState(); } @@ -92,7 +92,7 @@ class AIModelStateNotifier { } void _init() async { - await Future.wait([_loadLocalAiState(), _loadAvailableModels()]); + await Future.wait([_loadLocalAiState(), _loadModelSelection()]); _notifyStateChanged(); _notifyAvailableModelsChanged(); } @@ -139,7 +139,7 @@ class AIModelStateNotifier { ); } - final availableModels = _availableModels; + final availableModels = _sourceModelSelection; final localAiState = _localAIState; if (availableModels == null) { @@ -161,7 +161,7 @@ class AIModelStateNotifier { ); } - if (!availableModels.globalModel.isLocal) { + if (!availableModels.selectedModel.isLocal) { return AIModelState( type: AiType.cloud, hintText: LocaleKeys.chat_inputMessageHint.tr(), @@ -194,16 +194,16 @@ class AIModelStateNotifier { ); } - (List, AIModelPB?) getAvailableModels() { - final availableModels = _availableModels; + (List, AIModelPB?) getModelSelection() { + final availableModels = _sourceModelSelection; if (availableModels == null) { return ([], null); } - return (availableModels.models, availableModels.globalModel); + return (availableModels.models, availableModels.selectedModel); } void _notifyAvailableModelsChanged() { - final (models, selectedModel) = getAvailableModels(); + final (models, selectedModel) = getModelSelection(); for (final callback in _availableModelsChangedCallbacks) { callback(models, selectedModel); } @@ -216,10 +216,10 @@ class AIModelStateNotifier { } } - Future _loadAvailableModels() { - final payload = AvailableModelsQueryPB(source: objectId); - return AIEventGetAvailableModels(payload).send().fold( - (models) => _availableModels = models, + Future _loadModelSelection() { + final payload = ModelSourcePB(source: objectId); + return AIEventGetSourceModelSelection(payload).send().fold( + (models) => _sourceModelSelection = models, (err) => Log.error("Failed to get available models: $err"), ); } diff --git a/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart b/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart index 7ad52b9ec4..4044227fb1 100644 --- a/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart +++ b/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart @@ -83,7 +83,7 @@ class SelectModelState with _$SelectModelState { }) = _SelectModelState; factory SelectModelState.initial(AIModelStateNotifier notifier) { - final (models, selectedModel) = notifier.getAvailableModels(); + final (models, selectedModel) = notifier.getModelSelection(); return SelectModelState( models: models, selectedModel: selectedModel, diff --git a/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart 
b/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart index 6bed7ed035..fe0e0a7160 100644 --- a/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart +++ b/frontend/appflowy_flutter/lib/mobile/presentation/setting/ai/ai_settings_group.dart @@ -36,7 +36,7 @@ class AiSettingsGroup extends StatelessWidget { MobileSettingItem( name: LocaleKeys.settings_aiPage_keys_llmModelType.tr(), trailing: MobileSettingTrailing( - text: state.availableModels?.globalModel.name ?? "", + text: state.availableModels?.selectedModel.name ?? "", ), onTap: () => _onLLMModelTypeTap(context, state), ), @@ -73,7 +73,7 @@ class AiSettingsGroup extends StatelessWidget { text: entry.value.name, showTopBorder: entry.key == 0, isSelected: - availableModels?.globalModel.name == entry.value.name, + availableModels?.selectedModel.name == entry.value.name, onTap: () { context .read() diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart index 738e8c9574..cab3d486cf 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart @@ -448,7 +448,7 @@ class _ChangeModelButtonState extends State { child: buildButton(context), popupBuilder: (_) { final bloc = context.read(); - final (models, _) = bloc.aiModelStateNotifier.getAvailableModels(); + final (models, _) = bloc.aiModelStateNotifier.getModelSelection(); return SelectModelPopoverContent( models: models, selectedModel: null, diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart index 85796cae65..03f0d61ebc 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart @@ -407,7 +407,7 @@ class ChatAIMessagePopup extends StatelessWidget { return MobileQuickActionButton( onTap: () async { final bloc = context.read(); - final (models, _) = bloc.aiModelStateNotifier.getAvailableModels(); + final (models, _) = bloc.aiModelStateNotifier.getModelSelection(); final result = await showChangeModelBottomSheet(context, models); if (result != null) { onChangeModel?.call(result); diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart index 3b9060bb08..2a058c6f37 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart @@ -32,11 +32,11 @@ class OllamaSettingBloc extends Bloc { ) async { try { final results = await Future.wait([ - AIEventGetLocalAIModels().send().then((r) => r.getOrThrow()), + AIEventGetLocalModelSelection().send().then((r) => r.getOrThrow()), AIEventGetLocalAISetting().send().then((r) => r.getOrThrow()), ]); - final models = results[0] as AvailableModelsPB; + final models = results[0] as ModelSelectionPB; final setting = results[1] as LocalAISettingPB; if (!isClosed) { @@ -189,7 +189,7 @@ class SubmittedItem extends Equatable { class OllamaSettingEvent with _$OllamaSettingEvent { 
const factory OllamaSettingEvent.started() = _Started; const factory OllamaSettingEvent.didLoadLocalModels( - AvailableModelsPB models, + ModelSelectionPB models, ) = _DidLoadLocalModels; const factory OllamaSettingEvent.didLoadSetting( LocalAISettingPB setting, @@ -213,7 +213,7 @@ class OllamaSettingState with _$OllamaSettingState { LocalAISettingPB? setting, @Default([]) List inputItems, AIModelPB? selectedModel, - AvailableModelsPB? localModels, + ModelSelectionPB? localModels, AIModelPB? defaultModel, @Default([]) List submittedItems, @Default(false) bool isEdited, diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart index 0141283765..9409494b57 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart @@ -93,7 +93,7 @@ class SettingsAIBloc extends Bloc { ), ); }, - didLoadAvailableModels: (AvailableModelsPB models) { + didLoadAvailableModels: (ModelSelectionPB models) { emit( state.copyWith( availableModels: models, @@ -134,7 +134,8 @@ class SettingsAIBloc extends Bloc { ); void _loadModelList() { - AIEventGetServerAvailableModels().send().then((result) { + final payload = ModelSourcePB(source: aiModelsGlobalActiveModel); + AIEventGetSettingModelSelection(payload).send().then((result) { result.fold((models) { if (!isClosed) { add(SettingsAIEvent.didLoadAvailableModels(models)); @@ -175,7 +176,7 @@ class SettingsAIEvent with _$SettingsAIEvent { ) = _DidReceiveUserProfile; const factory SettingsAIEvent.didLoadAvailableModels( - AvailableModelsPB models, + ModelSelectionPB models, ) = _DidLoadAvailableModels; } @@ -184,7 +185,7 @@ class SettingsAIState with _$SettingsAIState { const factory SettingsAIState({ required UserProfilePB userProfile, WorkspaceSettingsPB? aiSettings, - AvailableModelsPB? availableModels, + ModelSelectionPB? 
availableModels, @Default(true) bool enableSearchIndexing, }) = _SettingsAIState; } diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart index 83f4ff603e..7357c2951c 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart @@ -30,7 +30,7 @@ class AIModelSelection extends StatelessWidget { final localModels = models.where((model) => model.isLocal).toList(); final cloudModels = models.where((model) => !model.isLocal).toList(); - final selectedModel = state.availableModels!.globalModel; + final selectedModel = state.availableModels!.selectedModel; return Padding( padding: const EdgeInsets.symmetric(vertical: 6), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart index fc56fc61e7..10e804e65e 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/ollama_setting.dart @@ -157,7 +157,7 @@ class LocalAIModelSelection extends StatelessWidget { onChanged: (model) => context .read() .add(OllamaSettingEvent.setDefaultModel(model)), - selectedOption: models.globalModel, + selectedOption: models.selectedModel, selectOptionCompare: (left, right) => left?.name == right?.name, options: models.models .map( diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index c21ccef080..eceab53ad9 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -1,7 +1,7 @@ use crate::chat::Chat; use crate::entities::{ - AIModelPB, AvailableModelsPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, - FilePB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams, + AIModelPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, FilePB, + ModelSelectionPB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams, }; use crate::local_ai::controller::{LocalAIController, LocalAISetting}; use crate::middleware::chat_service_mw::ChatServiceMiddleware; @@ -514,7 +514,7 @@ impl AIManager { } } - pub async fn get_local_available_models(&self) -> FlowyResult { + pub async fn get_local_available_models(&self) -> FlowyResult { let setting = self.local_ai.get_local_ai_setting(); let mut models = self.local_ai.get_all_chat_local_models().await; let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); @@ -523,13 +523,17 @@ impl AIManager { models.push(selected_model.clone()); } - Ok(AvailableModelsPB { + Ok(ModelSelectionPB { models: models.into_iter().map(AIModelPB::from).collect(), - global_model: AIModelPB::from(selected_model), + selected_model: AIModelPB::from(selected_model), }) } - pub async fn get_available_models(&self, source: String) -> FlowyResult { + pub async fn get_available_models( + &self, + source: String, + setting_only: bool, + ) -> FlowyResult { let is_local_mode = self.user_service.is_local_model().await?; if is_local_mode { return self.get_local_available_models().await; @@ -547,15 +551,19 @@ 
impl AIManager { // Add local models if enabled if self.local_ai.is_enabled() { - let setting = self.local_ai.get_local_ai_setting(); - all_models.push(AIModel::local(setting.chat_model_name, "".to_string()).into()); + if setting_only { + let setting = self.local_ai.get_local_ai_setting(); + all_models.push(AIModel::local(setting.chat_model_name, "".to_string()).into()); + } else { + all_models.extend(self.local_ai.get_all_chat_local_models().await); + } } // Return early if no models available if all_models.is_empty() { - return Ok(AvailableModelsPB { + return Ok(ModelSelectionPB { models: Vec::new(), - global_model: AIModelPB::default(), + selected_model: AIModelPB::default(), }); } @@ -617,9 +625,9 @@ impl AIManager { trace!("[Model Selection] final active model: {:?}", active_model); // Create response with one transformation pass - Ok(AvailableModelsPB { + Ok(ModelSelectionPB { models: all_models.into_iter().map(AIModelPB::from).collect(), - global_model: AIModelPB::from(active_model), + selected_model: AIModelPB::from(active_model), }) } diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index 5e03fadf56..9ffa2acf33 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -182,7 +182,7 @@ pub struct ChatMessageListPB { } #[derive(Default, ProtoBuf, Clone, Debug)] -pub struct ServerAvailableModelsPB { +pub struct ServerModelSelectionPB { #[pb(index = 1)] pub models: Vec, } @@ -200,7 +200,7 @@ pub struct AvailableModelPB { } #[derive(Default, ProtoBuf, Validate, Clone, Debug)] -pub struct AvailableModelsQueryPB { +pub struct ModelSourcePB { #[pb(index = 1)] #[validate(custom(function = "required_not_empty_str"))] pub source: String, @@ -217,12 +217,12 @@ pub struct UpdateSelectedModelPB { } #[derive(Default, ProtoBuf, Clone, Debug)] -pub struct AvailableModelsPB { +pub struct ModelSelectionPB { #[pb(index = 1)] pub models: Vec, #[pb(index = 2)] - pub global_model: AIModelPB, + pub selected_model: AIModelPB, } #[derive(Default, ProtoBuf, Clone, Debug)] diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index 160fbe6928..fd7bab3298 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -1,4 +1,4 @@ -use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY}; +use crate::ai_manager::AIManager; use crate::completion::AICompletion; use crate::entities::*; use flowy_ai_pub::cloud::{AIModel, ChatMessageType}; @@ -77,24 +77,24 @@ pub(crate) async fn regenerate_response_handler( } #[tracing::instrument(level = "debug", skip_all, err)] -pub(crate) async fn get_server_model_list_handler( +pub(crate) async fn get_setting_model_selection_handler( + data: AFPluginData, ai_manager: AFPluginState>, -) -> DataResult { +) -> DataResult { + let data = data.try_into_inner()?; let ai_manager = upgrade_ai_manager(ai_manager)?; - let models = ai_manager - .get_available_models(GLOBAL_ACTIVE_MODEL_KEY.to_string()) - .await?; + let models = ai_manager.get_available_models(data.source, true).await?; data_result_ok(models) } #[tracing::instrument(level = "debug", skip_all, err)] -pub(crate) async fn get_chat_models_handler( - data: AFPluginData, +pub(crate) async fn get_source_model_selection_handler( + data: AFPluginData, ai_manager: AFPluginState>, -) -> DataResult { +) -> DataResult { let data = data.try_into_inner()?; let ai_manager = upgrade_ai_manager(ai_manager)?; - let models = 
ai_manager.get_available_models(data.source).await?; + let models = ai_manager.get_available_models(data.source, false).await?; data_result_ok(models) } @@ -343,7 +343,7 @@ pub(crate) async fn get_local_ai_setting_handler( #[tracing::instrument(level = "debug", skip_all)] pub(crate) async fn get_local_ai_models_handler( ai_manager: AFPluginState>, -) -> DataResult { +) -> DataResult { let ai_manager = upgrade_ai_manager(ai_manager)?; let data = ai_manager.get_local_available_models().await?; data_result_ok(data) diff --git a/frontend/rust-lib/flowy-ai/src/event_map.rs b/frontend/rust-lib/flowy-ai/src/event_map.rs index ee77f454a5..df0721b361 100644 --- a/frontend/rust-lib/flowy-ai/src/event_map.rs +++ b/frontend/rust-lib/flowy-ai/src/event_map.rs @@ -31,21 +31,24 @@ pub fn init(ai_manager: Weak) -> AFPlugin { .event(AIEvent::ToggleLocalAI, toggle_local_ai_handler) .event(AIEvent::GetLocalAIState, get_local_ai_state_handler) .event(AIEvent::GetLocalAISetting, get_local_ai_setting_handler) - .event(AIEvent::GetLocalAIModels, get_local_ai_models_handler) + .event(AIEvent::GetLocalModelSelection, get_local_ai_models_handler) + .event( + AIEvent::GetSourceModelSelection, + get_source_model_selection_handler, + ) .event( AIEvent::UpdateLocalAISetting, update_local_ai_setting_handler, ) - .event( - AIEvent::GetServerAvailableModels, - get_server_model_list_handler, - ) .event(AIEvent::CreateChatContext, create_chat_context_handler) .event(AIEvent::GetChatInfo, create_chat_context_handler) .event(AIEvent::GetChatSettings, get_chat_settings_handler) .event(AIEvent::UpdateChatSettings, update_chat_settings_handler) .event(AIEvent::RegenerateResponse, regenerate_response_handler) - .event(AIEvent::GetAvailableModels, get_chat_models_handler) + .event( + AIEvent::GetSettingModelSelection, + get_setting_model_selection_handler, + ) .event(AIEvent::UpdateSelectedModel, update_selected_model_handler) } @@ -108,21 +111,21 @@ pub enum AIEvent { #[event(input = "RegenerateResponsePB")] RegenerateResponse = 27, - #[event(output = "AvailableModelsPB")] - GetServerAvailableModels = 28, - #[event(output = "LocalAISettingPB")] GetLocalAISetting = 29, #[event(input = "LocalAISettingPB")] UpdateLocalAISetting = 30, - #[event(input = "AvailableModelsQueryPB", output = "AvailableModelsPB")] - GetAvailableModels = 31, + #[event(input = "ModelSourcePB", output = "ModelSelectionPB")] + GetSettingModelSelection = 31, #[event(input = "UpdateSelectedModelPB")] UpdateSelectedModel = 32, - #[event(output = "AvailableModelsPB")] - GetLocalAIModels = 33, + #[event(output = "ModelSelectionPB")] + GetLocalModelSelection = 33, + + #[event(input = "ModelSourcePB", output = "ModelSelectionPB")] + GetSourceModelSelection = 34, } diff --git a/frontend/rust-lib/flowy-sqlite/src/schema.rs b/frontend/rust-lib/flowy-sqlite/src/schema.rs index bf7f431682..1271589da6 100644 --- a/frontend/rust-lib/flowy-sqlite/src/schema.rs +++ b/frontend/rust-lib/flowy-sqlite/src/schema.rs @@ -140,17 +140,17 @@ diesel::table! 
{ } diesel::allow_tables_to_appear_in_same_query!( - af_collab_metadata, - chat_local_setting_table, - chat_message_table, - chat_table, - collab_snapshot, - local_ai_model_table, - upload_file_part, - upload_file_table, - user_data_migration_records, - user_table, - user_workspace_table, - workspace_members_table, - workspace_setting_table, + af_collab_metadata, + chat_local_setting_table, + chat_message_table, + chat_table, + collab_snapshot, + local_ai_model_table, + upload_file_part, + upload_file_table, + user_data_migration_records, + user_table, + user_workspace_table, + workspace_members_table, + workspace_setting_table, ); From 549e8aee03b9649d8850a690620434416f761159 Mon Sep 17 00:00:00 2001 From: Nathan Date: Sat, 26 Apr 2025 10:04:03 +0800 Subject: [PATCH 5/6] chore: clippy --- frontend/rust-lib/flowy-ai/src/ai_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index eceab53ad9..9996947ed8 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -553,7 +553,7 @@ impl AIManager { if self.local_ai.is_enabled() { if setting_only { let setting = self.local_ai.get_local_ai_setting(); - all_models.push(AIModel::local(setting.chat_model_name, "".to_string()).into()); + all_models.push(AIModel::local(setting.chat_model_name, "".to_string())); } else { all_models.extend(self.local_ai.get_all_chat_local_models().await); } From f374ca157423b0c7e8bbca118e55abffc898460b Mon Sep 17 00:00:00 2001 From: Nathan Date: Sat, 26 Apr 2025 10:40:33 +0800 Subject: [PATCH 6/6] chore: update --- .../ai/service/ai_model_state_notifier.dart | 212 ++++++++---------- .../prompt_input/select_model_menu.dart | 76 +++---- .../menu/sidebar/shared/sidebar_setting.dart | 13 +- 3 files changed, 135 insertions(+), 166 deletions(-) diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart index 783046a3d0..1d9ba5bf0e 100644 --- a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart +++ b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart @@ -7,7 +7,6 @@ import 'package:appflowy_backend/log.dart'; import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; import 'package:appflowy_result/appflowy_result.dart'; import 'package:easy_localization/easy_localization.dart'; -import 'package:protobuf/protobuf.dart'; import 'package:universal_platform/universal_platform.dart'; typedef OnModelStateChangedCallback = void Function(AIModelState state); @@ -52,25 +51,29 @@ class AIModelStateNotifier { final String objectId; final LocalAIStateListener? _localAIListener; final AIModelSwitchListener _aiModelSwitchListener; - LocalAIPB? _localAIState; - ModelSelectionPB? _sourceModelSelection; - // callbacks + LocalAIPB? _localAIState; + ModelSelectionPB? _modelSelection; + + AIModelState _currentState = _defaultState(); + List _availableModels = []; + AIModelPB? 
_selectedModel; + final List _stateChangedCallbacks = []; final List _availableModelsChangedCallbacks = []; + /// Starts platform-specific listeners void _startListening() { if (UniversalPlatform.isDesktop) { _localAIListener?.start( stateCallback: (state) async { _localAIState = state; - _notifyStateChanged(); - + _updateAll(); if (state.state == RunningStatePB.Running || state.state == RunningStatePB.Stopped) { await _loadModelSelection(); - _notifyAvailableModelsChanged(); + _updateAll(); } }, ); @@ -78,25 +81,25 @@ class AIModelStateNotifier { _aiModelSwitchListener.start( onUpdateSelectedModel: (model) async { - final updatedModels = _sourceModelSelection?.deepCopy() - ?..selectedModel = model; - _sourceModelSelection = updatedModels; - - _notifyAvailableModelsChanged(); + _selectedModel = model; + _updateAll(); if (model.isLocal && UniversalPlatform.isDesktop) { - await _loadLocalAiState(); + await _loadLocalState(); + _updateAll(); } - _notifyStateChanged(); }, ); } - void _init() async { - await Future.wait([_loadLocalAiState(), _loadModelSelection()]); - _notifyStateChanged(); - _notifyAvailableModelsChanged(); + Future _init() async { + await Future.wait([ + if (UniversalPlatform.isDesktop) _loadLocalState(), + _loadModelSelection(), + ]); + _updateAll(); } + /// Register callbacks for state or available-models changes void addListener({ OnModelStateChangedCallback? onStateChanged, OnAvailableModelsChangedCallback? onAvailableModelsChanged, @@ -109,6 +112,7 @@ class AIModelStateNotifier { } } + /// Remove previously registered callbacks void removeListener({ OnModelStateChangedCallback? onStateChanged, OnAvailableModelsChangedCallback? onAvailableModelsChanged, @@ -128,116 +132,88 @@ class AIModelStateNotifier { await _aiModelSwitchListener.stop(); } - AIModelState getState() { - if (UniversalPlatform.isMobile) { - return AIModelState( - type: AiType.cloud, - hintText: LocaleKeys.chat_inputMessageHint.tr(), - tooltip: null, - isEditable: true, - localAIEnabled: false, - ); + /// Returns current AIModelState + AIModelState getState() => _currentState; + + /// Returns available models and the selected model + (List, AIModelPB?) getModelSelection() => + (_availableModels, _selectedModel); + + void _updateAll() { + _currentState = _computeState(); + for (final cb in _stateChangedCallbacks) { + cb(_currentState); } - - final availableModels = _sourceModelSelection; - final localAiState = _localAIState; - - if (availableModels == null) { - return AIModelState( - type: AiType.cloud, - hintText: LocaleKeys.chat_inputMessageHint.tr(), - isEditable: true, - tooltip: null, - localAIEnabled: false, - ); - } - if (localAiState == null) { - return AIModelState( - type: AiType.cloud, - hintText: LocaleKeys.chat_inputMessageHint.tr(), - tooltip: null, - isEditable: true, - localAIEnabled: false, - ); + for (final cb in _availableModelsChangedCallbacks) { + cb(_availableModels, _selectedModel); } + } - if (!availableModels.selectedModel.isLocal) { - return AIModelState( - type: AiType.cloud, - hintText: LocaleKeys.chat_inputMessageHint.tr(), - tooltip: null, - isEditable: true, - localAIEnabled: false, - ); - } - - final editable = localAiState.state == RunningStatePB.Running; - final tooltip = localAiState.enabled - ? (editable - ? null - : LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt - .tr()) - : LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt.tr(); - - final hintText = localAiState.enabled - ? (editable - ? 
LocaleKeys.chat_inputLocalAIMessageHint.tr()
-          : LocaleKeys.settings_aiPage_keys_localAIInitializing.tr())
-        : LocaleKeys.settings_aiPage_keys_localAIDisabled.tr();
-
-    return AIModelState(
-      type: AiType.local,
-      hintText: hintText,
-      tooltip: tooltip,
-      isEditable: editable,
-      localAIEnabled: localAiState.enabled,
+  Future<void> _loadModelSelection() async {
+    await AIEventGetSourceModelSelection(
+      ModelSourcePB(source: objectId),
+    ).send().fold(
+      (ms) {
+        _modelSelection = ms;
+        _availableModels = ms.models;
+        _selectedModel = ms.selectedModel;
+      },
+      (e) => Log.error("Failed to fetch models: $e"),
     );
   }
 
-  (List<AIModelPB>, AIModelPB?) getModelSelection() {
-    final availableModels = _sourceModelSelection;
-    if (availableModels == null) {
-      return ([], null);
-    }
-    return (availableModels.models, availableModels.selectedModel);
-  }
-
-  void _notifyAvailableModelsChanged() {
-    final (models, selectedModel) = getModelSelection();
-    for (final callback in _availableModelsChangedCallbacks) {
-      callback(models, selectedModel);
-    }
-  }
-
-  void _notifyStateChanged() {
-    final state = getState();
-    for (final callback in _stateChangedCallbacks) {
-      callback(state);
-    }
-  }
-
-  Future<void> _loadModelSelection() {
-    final payload = ModelSourcePB(source: objectId);
-    return AIEventGetSourceModelSelection(payload).send().fold(
-      (models) => _sourceModelSelection = models,
-      (err) => Log.error("Failed to get available models: $err"),
+  Future<void> _loadLocalState() async {
+    await AIEventGetLocalAIState().send().fold(
+      (s) => _localAIState = s,
+      (e) => Log.error("Failed to fetch local AI state: $e"),
     );
   }
 
-  Future<void> _loadLocalAiState() {
-    return AIEventGetLocalAIState().send().fold(
-      (localAIState) => _localAIState = localAIState,
-      (error) => Log.error("Failed to get local AI state: $error"),
-    );
+  static AIModelState _defaultState() => AIModelState(
+        type: AiType.cloud,
+        hintText: LocaleKeys.chat_inputMessageHint.tr(),
+        tooltip: null,
+        isEditable: true,
+        localAIEnabled: false,
+      );
+
+  /// Core logic computing the state from local and selection data
+  AIModelState _computeState() {
+    if (UniversalPlatform.isMobile) return _defaultState();
+
+    if (_modelSelection == null || _localAIState == null) {
+      return _defaultState();
+    }
+
+    if (_selectedModel == null || !_selectedModel!.isLocal) {
+      return _defaultState();
+    }
+
+    final enabled = _localAIState!.enabled;
+    final running = _localAIState!.state == RunningStatePB.Running;
+    final hintKey = enabled
+        ? (running
+            ? LocaleKeys.chat_inputLocalAIMessageHint
+            : LocaleKeys.settings_aiPage_keys_localAIInitializing)
+        : LocaleKeys.settings_aiPage_keys_localAIDisabled;
+    final tooltipKey = enabled
+        ? (running
+            ? null
+            : LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt)
+        : LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt;
+
+    return AIModelState(
+      type: AiType.local,
+      hintText: hintKey.tr(),
+      tooltip: tooltipKey?.tr(),
+      isEditable: running,
+      localAIEnabled: enabled,
+    );
   }
 }
 
-extension AiModelExtension on AIModelPB {
-  bool get isDefault {
-    return name == "Auto";
-  }
-
-  String get i18n {
-    return isDefault ? LocaleKeys.chat_switchModel_autoModel.tr() : name;
-  }
+extension AIModelPBExtension on AIModelPB {
+  bool get isDefault => name == 'Auto';
+  String get i18n =>
+      isDefault ?
LocaleKeys.chat_switchModel_autoModel.tr() : name; } diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart index 317f90ac21..d826be78d8 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart @@ -217,51 +217,41 @@ class _CurrentModelButton extends StatelessWidget { behavior: HitTestBehavior.opaque, child: SizedBox( height: DesktopAIPromptSizes.actionBarButtonSize, - child: AnimatedSize( - duration: const Duration(milliseconds: 200), - curve: Curves.easeOutCubic, - alignment: AlignmentDirectional.centerStart, - clipBehavior: Clip.none, - child: FlowyHover( - style: const HoverStyle( - borderRadius: BorderRadius.all(Radius.circular(8)), - ), - child: Padding( - padding: const EdgeInsetsDirectional.all(4.0), - child: Row( - mainAxisSize: MainAxisSize.min, - children: [ - Padding( - // TODO: remove this after change icon to 20px - padding: EdgeInsets.all(2), - child: FlowySvg( - FlowySvgs.ai_sparks_s, - color: Theme.of(context).hintColor, - size: Size.square(16), - ), - ), - if (model != null && !model!.isDefault) - AnimatedSize( - duration: const Duration(milliseconds: 150), - curve: Curves.easeOutCubic, - child: Padding( - padding: EdgeInsetsDirectional.only(end: 2.0), - child: FlowyText( - model!.i18n, - fontSize: 12, - figmaLineHeight: 16, - color: Theme.of(context).hintColor, - overflow: TextOverflow.ellipsis, - ), - ), - ), - FlowySvg( - FlowySvgs.ai_source_drop_down_s, + child: FlowyHover( + style: const HoverStyle( + borderRadius: BorderRadius.all(Radius.circular(8)), + ), + child: Padding( + padding: const EdgeInsetsDirectional.all(4.0), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Padding( + // TODO: remove this after change icon to 20px + padding: EdgeInsets.all(2), + child: FlowySvg( + FlowySvgs.ai_sparks_s, color: Theme.of(context).hintColor, - size: const Size.square(8), + size: Size.square(16), ), - ], - ), + ), + if (model != null && !model!.isDefault) + Padding( + padding: EdgeInsetsDirectional.only(end: 2.0), + child: FlowyText( + model!.i18n, + fontSize: 12, + figmaLineHeight: 16, + color: Theme.of(context).hintColor, + overflow: TextOverflow.ellipsis, + ), + ), + FlowySvg( + FlowySvgs.ai_source_drop_down_s, + color: Theme.of(context).hintColor, + size: const Size.square(8), + ), + ], ), ), ), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/shared/sidebar_setting.dart b/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/shared/sidebar_setting.dart index d1e32985fa..d21605bc00 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/shared/sidebar_setting.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/shared/sidebar_setting.dart @@ -31,7 +31,10 @@ HotKeyItem openSettingsHotKey( ), keyDownHandler: (_) { if (_settingsDialogKey.currentContext == null) { - showSettingsDialog(context); + showSettingsDialog( + context, + userWorkspaceBloc: context.read(), + ); } else { Navigator.of(context, rootNavigator: true) .popUntil((route) => route.isFirst); @@ -110,7 +113,7 @@ class _UserSettingButtonState extends State { void showSettingsDialog( BuildContext context, { - UserWorkspaceBloc? userWorkspaceBloc, + required UserWorkspaceBloc userWorkspaceBloc, PasswordBloc? passwordBloc, SettingsPage? 
initPage,
 }) {
@@ -126,7 +129,7 @@ void showSettingsDialog(
         )
       : BlocProvider(
           create: (context) => PasswordBloc(
-            context.read<UserWorkspaceBloc>().state.userProfile,
+            userWorkspaceBloc.state.userProfile,
           )
             ..add(PasswordEvent.init())
             ..add(PasswordEvent.checkHasPassword()),
@@ -135,11 +138,11 @@ void showSettingsDialog(
             value: BlocProvider.of(dialogContext),
           ),
           BlocProvider.value(
-            value: userWorkspaceBloc ?? context.read<UserWorkspaceBloc>(),
+            value: userWorkspaceBloc,
          ),
        ],
        child: SettingsDialog(
-          context.read<UserWorkspaceBloc>().state.userProfile,
+          userWorkspaceBloc.state.userProfile,
          initPage: initPage,
          didLogout: () async {
            // Pop the dialog using the dialog context
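            // NOTE: with `userWorkspaceBloc` now a required parameter, call
            // sites follow the pattern used by openSettingsHotKey above, e.g.:
            //
            //   showSettingsDialog(
            //     context,
            //     userWorkspaceBloc: context.read<UserWorkspaceBloc>(),
            //   );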