refactor: Model select (#7875)

* refactor: model select

* refactor: add test

* fix: add source

* fix: add source

* chore: notify all unset source

* chore: fix test
Nathan.fooo 2025-05-02 15:37:40 +08:00 committed by Nathan
parent 4e2723f917
commit 529b5e5080
17 changed files with 1198 additions and 270 deletions

View File

@@ -11,7 +11,7 @@ import 'package:freezed_annotation/freezed_annotation.dart';
part 'settings_ai_bloc.freezed.dart';
const String aiModelsGlobalActiveModel = "ai_models_global_active_model";
const String aiModelsGlobalActiveModel = "global_active_model";
class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
SettingsAIBloc(
@@ -75,9 +75,6 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
);
},
selectModel: (AIModelPB model) async {
if (!model.isLocal) {
await _updateUserWorkspaceSetting(model: model.name);
}
await AIEventUpdateSelectedModel(
UpdateSelectedModelPB(
source: aiModelsGlobalActiveModel,

View File

@@ -17,8 +17,6 @@ class AIModelSelection extends StatelessWidget {
@override
Widget build(BuildContext context) {
return BlocBuilder<SettingsAIBloc, SettingsAIState>(
buildWhen: (previous, current) =>
previous.availableModels != current.availableModels,
builder: (context, state) {
final models = state.availableModels?.models;
if (models == null) {
@@ -44,7 +42,7 @@ class AIModelSelection extends StatelessWidget {
),
Flexible(
child: SettingsDropdown<AIModelPB>(
key: const Key('_AIModelSelection'),
key: ValueKey(selectedModel.name),
onChanged: (model) => context
.read<SettingsAIBloc>()
.add(SettingsAIEvent.selectModel(model)),

View File

@@ -33,6 +33,17 @@ pub struct AIModel {
pub desc: String,
}
impl AIModel {
/// Create a new model instance
pub fn new(name: impl Into<String>, description: impl Into<String>, is_local: bool) -> Self {
Self {
name: name.into(),
desc: description.into(),
is_local,
}
}
}
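Elsewhere in this diff the code calls AIModel::local(...) and AIModel::server(...), whose bodies are not part of this change. A plausible sketch, assuming they are thin wrappers over new:

// Sketch only: these constructors are used throughout the diff but their
// bodies are not shown here; assumed to delegate to `new`.
impl AIModel {
  pub fn local(name: String, desc: String) -> Self {
    Self::new(name, desc, true)
  }
  pub fn server(name: String, desc: String) -> Self {
    Self::new(name, desc, false)
  }
}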
impl From<AvailableModel> for AIModel {
fn from(value: AvailableModel) -> Self {
let desc = value
@@ -175,4 +186,9 @@ pub trait ChatCloudService: Send + Sync + 'static {
async fn get_available_models(&self, workspace_id: &Uuid) -> Result<ModelList, FlowyError>;
async fn get_workspace_default_model(&self, workspace_id: &Uuid) -> Result<String, FlowyError>;
async fn set_workspace_default_model(
&self,
workspace_id: &Uuid,
model: &str,
) -> Result<(), FlowyError>;
}

View File

@@ -9,27 +9,26 @@ use flowy_ai_pub::persistence::read_chat_metadata;
use std::collections::HashMap;
use dashmap::DashMap;
use flowy_ai_pub::cloud::{
AIModel, ChatCloudService, ChatSettings, UpdateChatParams, DEFAULT_AI_MODEL_NAME,
};
use flowy_ai_pub::cloud::{AIModel, ChatCloudService, ChatSettings, UpdateChatParams};
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use crate::model_select::{
LocalAiSource, LocalModelStorageImpl, ModelSelectionControl, ServerAiSource,
ServerModelStorageImpl, SourceKey, GLOBAL_ACTIVE_MODEL_KEY,
};
use crate::notification::{chat_notification_builder, ChatNotification};
use crate::util::ai_available_models_key;
use collab_integrate::persistence::collab_metadata_sql::{
batch_insert_collab_metadata, batch_select_collab_metadata, AFCollabMetadata,
};
use flowy_ai_pub::cloud::ai_dto::AvailableModel;
use flowy_ai_pub::user_service::AIUserService;
use flowy_storage_pub::storage::StorageService;
use lib_infra::async_trait::async_trait;
use lib_infra::util::timestamp;
use serde_json::json;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::{Arc, Weak};
use tokio::sync::RwLock;
use tokio::sync::Mutex;
use tracing::{error, info, instrument, trace};
use uuid::Uuid;
@@ -52,14 +51,6 @@ pub trait AIExternalService: Send + Sync + 'static {
async fn notify_did_send_message(&self, chat_id: &Uuid, message: &str) -> Result<(), FlowyError>;
}
#[derive(Debug, Default)]
struct ServerModelsCache {
models: Vec<AvailableModel>,
timestamp: Option<i64>,
}
pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model";
pub struct AIManager {
pub cloud_service_wm: Arc<ChatServiceMiddleware>,
pub user_service: Arc<dyn AIUserService>,
@@ -67,7 +58,7 @@ pub struct AIManager {
chats: Arc<DashMap<Uuid, Arc<Chat>>>,
pub local_ai: Arc<LocalAIController>,
pub store_preferences: Arc<KVStorePreferences>,
server_models: Arc<RwLock<ServerModelsCache>>,
model_control: Mutex<ModelSelectionControl>,
}
impl Drop for AIManager {
fn drop(&mut self) {
@@ -97,6 +88,10 @@ impl AIManager {
local_ai.clone(),
storage_service,
));
let mut model_control = ModelSelectionControl::new();
model_control.set_local_storage(LocalModelStorageImpl(store_preferences.clone()));
model_control.set_server_storage(ServerModelStorageImpl(cloud_service_wm.clone()));
model_control.add_source(Box::new(ServerAiSource::new(cloud_service_wm.clone())));
Self {
cloud_service_wm,
@@ -105,7 +100,7 @@ impl AIManager {
local_ai,
external_service,
store_preferences,
server_models: Arc::new(Default::default()),
model_control: Mutex::new(model_control),
}
}
@@ -134,10 +129,6 @@ impl AIManager {
info!("[AI Manager] Local AI is running but not enabled, shutting it down");
let local_ai = self.local_ai.clone();
tokio::spawn(async move {
// Wait for 5 seconds to allow other services to initialize
// TODO: pick a right time to start plugin service. Maybe [UserStatusCallback::did_launch]
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
if let Err(err) = local_ai.toggle_plugin(false).await {
error!("[AI Manager] failed to shutdown local AI: {:?}", err);
}
@@ -150,10 +141,6 @@ impl AIManager {
info!("[AI Manager] Local AI is enabled but not running, starting it now");
let local_ai = self.local_ai.clone();
tokio::spawn(async move {
// Wait for 5 seconds to allow other services to initialize
// TODO: pick a right time to start plugin service. Maybe [UserStatusCallback::did_launch]
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
if let Err(err) = local_ai.toggle_plugin(true).await {
error!("[AI Manager] failed to start local AI: {:?}", err);
}
@@ -167,19 +154,41 @@ impl AIManager {
}
}
async fn prepare_local_ai(&self, workspace_id: &Uuid) {
self
.local_ai
.reload_ollama_client(&workspace_id.to_string());
self
.model_control
.lock()
.await
.add_source(Box::new(LocalAiSource::new(self.local_ai.clone())));
}
#[instrument(skip_all, err)]
pub async fn on_launch_if_authenticated(&self, workspace_id: &Uuid) -> Result<(), FlowyError> {
let is_enabled = self
.local_ai
.is_enabled_on_workspace(&workspace_id.to_string());
info!("local is enabled: {}", is_enabled);
if is_enabled {
self.prepare_local_ai(workspace_id).await;
} else {
self.model_control.lock().await.remove_local_source();
}
self.reload_with_workspace_id(workspace_id).await;
Ok(())
}
pub async fn initialize_after_sign_in(&self, workspace_id: &Uuid) -> Result<(), FlowyError> {
self.reload_with_workspace_id(workspace_id).await;
self.on_launch_if_authenticated(workspace_id).await?;
Ok(())
}
pub async fn initialize_after_sign_up(&self, workspace_id: &Uuid) -> Result<(), FlowyError> {
self.reload_with_workspace_id(workspace_id).await;
self.on_launch_if_authenticated(workspace_id).await?;
Ok(())
}
@@ -188,7 +197,7 @@ impl AIManager {
&self,
workspace_id: &Uuid,
) -> Result<(), FlowyError> {
self.reload_with_workspace_id(workspace_id).await;
self.on_launch_if_authenticated(workspace_id).await?;
Ok(())
}
@@ -309,7 +318,7 @@ impl AIManager {
) -> Result<ChatMessagePB, FlowyError> {
let chat = self.get_or_create_chat_instance(&params.chat_id).await?;
let ai_model = self.get_active_model(&params.chat_id.to_string()).await;
let question = chat.stream_chat_message(&params, ai_model).await?;
let question = chat.stream_chat_message(&params, Some(ai_model)).await?;
let _ = self
.external_service
.notify_did_send_message(&params.chat_id, &params.message)
@@ -332,15 +341,16 @@ impl AIManager {
let model = match model {
None => self.get_active_model(&chat_id.to_string()).await,
Some(model) => Some(model.into()),
Some(model) => model.into(),
};
chat
.stream_regenerate_response(question_message_id, answer_stream_port, format, model)
.stream_regenerate_response(question_message_id, answer_stream_port, format, Some(model))
.await?;
Ok(())
}
pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
let workspace_id = self.user_service.workspace_id()?;
let old_settings = self.local_ai.get_local_ai_setting();
// Only restart if the server URL has changed and local AI is not running
let need_restart =
@@ -353,172 +363,137 @@ impl AIManager {
.await?;
// Handle model change if needed
let model_changed = old_settings.chat_model_name != setting.chat_model_name;
if model_changed {
info!(
"[AI Plugin] update global active model, previous: {}, current: {}",
old_settings.chat_model_name, setting.chat_model_name
);
let model = AIModel::local(setting.chat_model_name, "".to_string());
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
info!(
"[AI Plugin] update global active model, previous: {}, current: {}",
old_settings.chat_model_name, setting.chat_model_name
);
let model = AIModel::local(setting.chat_model_name, "".to_string());
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
if need_restart {
self
.local_ai
.reload_ollama_client(&workspace_id.to_string());
self.local_ai.restart_plugin().await;
}
Ok(())
}
async fn get_workspace_select_model(&self) -> FlowyResult<String> {
#[instrument(skip_all, level = "debug")]
pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> {
let workspace_id = self.user_service.workspace_id()?;
let model = self
.cloud_service_wm
.get_workspace_default_model(&workspace_id)
let source_key = SourceKey::new(source.clone());
self
.model_control
.lock()
.await
.set_active_model(&workspace_id, &source_key, model.clone())
.await?;
if model.is_empty() {
return Ok(DEFAULT_AI_MODEL_NAME.to_string());
}
Ok(model)
}
async fn get_server_available_models(&self) -> FlowyResult<Vec<AvailableModel>> {
let workspace_id = self.user_service.workspace_id()?;
let now = timestamp();
// First, try reading from the cache with expiration check
let should_fetch = {
let cached_models = self.server_models.read().await;
cached_models.models.is_empty() || cached_models.timestamp.map_or(true, |ts| now - ts >= 300)
};
if !should_fetch {
// Cache is still valid, return cached data
let cached_models = self.server_models.read().await;
return Ok(cached_models.models.clone());
}
// Cache miss or expired: fetch from the cloud.
match self
.cloud_service_wm
.get_available_models(&workspace_id)
.await
{
Ok(list) => {
let models = list.models;
if let Err(err) = self.update_models_cache(&models, now).await {
error!("Failed to update models cache: {}", err);
}
Ok(models)
},
Err(err) => {
error!("Failed to fetch available models: {}", err);
// Return cached data if available, even if expired
let cached_models = self.server_models.read().await;
if !cached_models.models.is_empty() {
info!("Returning expired cached models due to fetch failure");
return Ok(cached_models.models.clone());
}
// If no cached data, return empty list
Ok(Vec::new())
},
}
}
async fn update_models_cache(
&self,
models: &[AvailableModel],
timestamp: i64,
) -> FlowyResult<()> {
match self.server_models.try_write() {
Ok(mut cache) => {
cache.models = models.to_vec();
cache.timestamp = Some(timestamp);
Ok(())
},
Err(_) => {
// Handle lock acquisition failure
Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache"))
},
}
}
pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> {
let source_key = ai_available_models_key(&source);
info!(
"[Model Selection] update {} selected model: {:?} for key:{}",
source, model, source_key
"[Model Selection] selected model: {:?} for key:{}",
model,
source_key.storage_id()
);
self
.store_preferences
.set_object::<AIModel>(&source_key, &model)?;
chat_notification_builder(&source_key, ChatNotification::DidUpdateSelectedModel)
.payload(AIModelPB::from(model))
.send();
let mut notify_source = vec![source.clone()];
if source == GLOBAL_ACTIVE_MODEL_KEY {
let ids = self
.model_control
.lock()
.await
.get_all_unset_sources()
.await;
info!("[Model Selection] notify all unset sources: {:?}", ids);
notify_source.extend(ids);
}
trace!("[Model Selection] notify sources: {:?}", notify_source);
for source in notify_source {
chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
.payload(AIModelPB::from(model.clone()))
.send();
}
Ok(())
}
#[instrument(skip_all, level = "debug")]
#[instrument(skip_all, level = "debug", err)]
pub async fn toggle_local_ai(&self) -> FlowyResult<()> {
let enabled = self.local_ai.toggle_local_ai().await?;
let workspace_id = self.user_service.workspace_id()?;
if enabled {
self.prepare_local_ai(&workspace_id).await;
if let Some(name) = self.local_ai.get_plugin_chat_model() {
info!("Set global active model to local ai: {}", name);
let model = AIModel::local(name, "".to_string());
self
info!(
"[Model Selection] Set global active model to local ai: {}",
model.name
);
if let Err(err) = self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
.await
{
error!(
"[Model Selection] Failed to set global active model: {}",
err
);
}
}
} else {
info!("Set global active model to default");
let global_active_model = self.get_workspace_select_model().await?;
let models = self.get_server_available_models().await?;
if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) {
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model))
.await?;
let mut model_control = self.model_control.lock().await;
model_control.remove_local_source();
let model = model_control.get_global_active_model(&workspace_id).await;
let mut notify_source = model_control.get_all_unset_sources().await;
notify_source.push(GLOBAL_ACTIVE_MODEL_KEY.to_string());
drop(model_control);
trace!(
"[Model Selection] notify sources: {:?}, model:{}, when disable local ai",
notify_source,
model.name
);
for source in notify_source {
chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
.payload(AIModelPB::from(model.clone()))
.send();
}
}
Ok(())
}
pub async fn get_active_model(&self, source: &str) -> Option<AIModel> {
let mut model = self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(source));
match model {
None => {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
model
},
Some(mut model) => {
let models = self.local_ai.get_all_chat_local_models().await;
if !models.contains(&model) {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = AIModel::local(local_model, "".to_string());
}
}
Some(model)
pub async fn get_active_model(&self, source: &str) -> AIModel {
match self.user_service.workspace_id() {
Ok(workspace_id) => {
let source_key = SourceKey::new(source.to_string());
self
.model_control
.lock()
.await
.get_active_model(&workspace_id, &source_key)
.await
},
Err(_) => AIModel::default(),
}
}
pub async fn get_local_available_models(&self) -> FlowyResult<ModelSelectionPB> {
let setting = self.local_ai.get_local_ai_setting();
let mut models = self.local_ai.get_all_chat_local_models().await;
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
let workspace_id = self.user_service.workspace_id()?;
let mut models = self
.model_control
.lock()
.await
.get_local_models(&workspace_id)
.await;
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
if models.is_empty() {
models.push(selected_model.clone());
}
@@ -539,92 +514,23 @@ impl AIManager {
return self.get_local_available_models().await;
}
// Fetch server models
let mut all_models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();
trace!("[Model Selection]: Available models: {:?}", all_models);
// Add local models if enabled
if self.local_ai.is_enabled() {
if setting_only {
let setting = self.local_ai.get_local_ai_setting();
all_models.push(AIModel::local(setting.chat_model_name, "".to_string()));
} else {
all_models.extend(self.local_ai.get_all_chat_local_models().await);
}
}
// Return early if no models available
if all_models.is_empty() {
return Ok(ModelSelectionPB {
models: Vec::new(),
selected_model: AIModelPB::default(),
});
}
// Get server active model (only once)
let server_active_model = self
.get_workspace_select_model()
.await
.map(|m| AIModel::server(m, "".to_string()))
.unwrap_or_else(|_| AIModel::default());
trace!(
"[Model Selection] server active model: {:?}",
server_active_model
);
// Use server model as default if it exists in available models
let default_model = if all_models
.iter()
.any(|m| m.name == server_active_model.name)
{
server_active_model.clone()
let workspace_id = self.user_service.workspace_id()?;
let local_model_name = if setting_only {
Some(self.local_ai.get_local_ai_setting().chat_model_name)
} else {
AIModel::default()
None
};
// Get user's previously selected model
let user_selected_model = match self.get_active_model(&source).await {
Some(model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
model
},
None => {
// When no selected model and local AI is active, use local AI model
all_models
.iter()
.find(|m| m.is_local)
.cloned()
.unwrap_or_else(|| default_model.clone())
},
};
let source_key = SourceKey::new(source);
let model_control = self.model_control.lock().await;
let active_model = model_control
.get_active_model(&workspace_id, &source_key)
.await;
let all_models = model_control
.get_models_with_specific_local_model(&workspace_id, local_model_name)
.await;
drop(model_control);
// Determine final active model - use user's selection if available, otherwise default
let active_model = all_models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
.unwrap_or(default_model.clone());
// Update stored preference if changed
if active_model.name != user_selected_model.name {
if let Err(err) = self
.update_selected_model(source, active_model.clone())
.await
{
error!("[Model Selection] failed to update selected model: {}", err);
}
}
trace!("[Model Selection] final active model: {:?}", active_model);
// Create response with one transformation pass
Ok(ModelSelectionPB {
models: all_models.into_iter().map(AIModelPB::from).collect(),
selected_model: AIModelPB::from(active_model),
@@ -696,7 +602,9 @@ impl AIManager {
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let ai_model = self.get_active_model(&chat_id.to_string()).await;
let resp = chat.get_related_question(message_id, ai_model).await?;
let resp = chat
.get_related_question(message_id, Some(ai_model))
.await?;
Ok(resp)
}

View File

@@ -98,6 +98,7 @@ pub(crate) async fn get_source_model_selection_handler(
data_result_ok(models)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn update_selected_model_handler(
data: AFPluginData<UpdateSelectedModelPB>,
ai_manager: AFPluginState<Weak<AIManager>>,
@@ -192,7 +193,7 @@ pub(crate) async fn start_complete_text_handler(
let data = data.into_inner();
let ai_manager = upgrade_ai_manager(ai_manager)?;
let ai_model = ai_manager.get_active_model(&data.object_id).await;
let task = tools.create_complete_task(data, ai_model).await?;
let task = tools.create_complete_task(data, Some(ai_model)).await?;
data_result_ok(task)
}

View File

@@ -11,8 +11,10 @@ pub mod local_ai;
// pub mod mcp;
mod middleware;
mod model_select;
#[cfg(test)]
mod model_select_test;
pub mod notification;
pub mod offline;
mod protobuf;
mod stream_message;
mod util;

View File

@@ -59,7 +59,7 @@ pub struct LocalAIController {
current_chat_id: ArcSwapOption<Uuid>,
store_preferences: Weak<KVStorePreferences>,
user_service: Arc<dyn AIUserService>,
ollama: ArcSwapOption<Ollama>,
pub(crate) ollama: ArcSwapOption<Ollama>,
}
impl Deref for LocalAIController {
@@ -174,6 +174,33 @@ impl LocalAIController {
ollama,
}
}
pub fn reload_ollama_client(&self, workspace_id: &str) {
if !self.is_enabled_on_workspace(workspace_id) {
return;
}
let setting = self.resource.get_llm_setting();
if let Some(ollama) = self.ollama.load_full() {
if ollama.url_str() == setting.ollama_server_url {
info!("[Local AI] ollama client is already initialized");
return;
}
}
info!("[Local AI] reloading ollama client");
match Ollama::try_new(setting.ollama_server_url).map(Arc::new) {
Ok(new_ollama) => {
self.ollama.store(Some(new_ollama.clone()));
},
Err(err) => error!(
"failed to create ollama client: {:?}, thread: {:?}",
err,
std::thread::current().id()
),
}
}
#[instrument(level = "debug", skip_all)]
pub async fn observe_plugin_resource(&self) {
let sys = get_operating_system();
@@ -244,11 +271,11 @@ impl LocalAIController {
}
pub fn is_enabled_on_workspace(&self, workspace_id: &str) -> bool {
let key = local_ai_enabled_key(workspace_id);
if !get_operating_system().is_desktop() {
return false;
}
let key = local_ai_enabled_key(workspace_id);
match self.upgrade_store_preferences() {
Ok(store) => store.get_bool(&key).unwrap_or(false),
Err(_) => false,

View File

@@ -372,4 +372,15 @@ impl ChatCloudService for ChatServiceMiddleware {
.get_workspace_default_model(workspace_id)
.await
}
async fn set_workspace_default_model(
&self,
workspace_id: &Uuid,
model: &str,
) -> Result<(), FlowyError> {
self
.cloud_service
.set_workspace_default_model(workspace_id, model)
.await
}
}

View File

@@ -0,0 +1,471 @@
use crate::local_ai::controller::LocalAIController;
use arc_swap::ArcSwapOption;
use flowy_ai_pub::cloud::{AIModel, ChatCloudService};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use lib_infra::async_trait::async_trait;
use lib_infra::util::timestamp;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info, trace};
use uuid::Uuid;
type Model = AIModel;
pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model";
/// Manages multiple sources and provides operations for model selection
pub struct ModelSelectionControl {
sources: Vec<Box<dyn ModelSource>>,
default_model: Model,
local_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
server_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
unset_sources: RwLock<HashSet<String>>,
}
impl ModelSelectionControl {
/// Create an empty manager; storage backends and sources are attached via the setters below
pub fn new() -> Self {
let default_model = Model::default();
Self {
sources: Vec::new(),
default_model,
local_storage: ArcSwapOption::new(None),
server_storage: ArcSwapOption::new(None),
unset_sources: Default::default(),
}
}
/// Replace the local storage backend at runtime
pub fn set_local_storage(&self, storage: impl UserModelStorage + 'static) {
self.local_storage.store(Some(Arc::new(Box::new(storage))));
}
/// Replace the server storage backend at runtime
pub fn set_server_storage(&self, storage: impl UserModelStorage + 'static) {
self.server_storage.store(Some(Arc::new(Box::new(storage))));
}
/// Add a new model source at runtime
pub fn add_source(&mut self, source: Box<dyn ModelSource>) {
info!("[Model Selection] Adding source: {}", source.source_name());
// remove existing source with the same name
self
.sources
.retain(|s| s.source_name() != source.source_name());
self.sources.push(source);
}
/// Remove the local source, if present
pub fn remove_local_source(&mut self) {
info!("[Model Selection] Removing local source");
self
.sources
.retain(|source| source.source_name() != "local");
}
/// Asynchronously aggregate models from all sources, or return the default if none found
pub async fn get_models(&self, workspace_id: &Uuid) -> Vec<Model> {
let mut models = Vec::new();
for source in &self.sources {
let mut list = source.list_chat_models(workspace_id).await;
models.append(&mut list);
}
if models.is_empty() {
vec![self.default_model.clone()]
} else {
models
}
}
/// Fetches all server-side models and, if specified, a single local model by name.
///
/// It first collects models from any source named `"server"`, then fetches all local models
/// (from the `"local"` source) and:
/// - If `local_model_name` is `Some(name)`, it will append exactly that local model
/// if it exists.
/// - If `local_model_name` is `None`, it will append *all* local models.
///
pub async fn get_models_with_specific_local_model(
&self,
workspace_id: &Uuid,
local_model_name: Option<String>,
) -> Vec<Model> {
let mut models = Vec::new();
// add server models
for source in &self.sources {
if source.source_name() == "server" {
let mut list = source.list_chat_models(workspace_id).await;
models.append(&mut list);
}
}
// check input local model present in local models
let local_models = self.get_local_models(workspace_id).await;
match local_model_name {
Some(name) => {
models.extend(local_models.into_iter().filter(|model| model.name == name));
},
None => {
models.extend(local_models);
},
}
models
}
pub async fn get_local_models(&self, workspace_id: &Uuid) -> Vec<Model> {
for source in &self.sources {
if source.source_name() == "local" {
return source.list_chat_models(workspace_id).await;
}
}
vec![]
}
pub async fn get_all_unset_sources(&self) -> Vec<String> {
let unset_sources = self.unset_sources.read().await;
unset_sources.iter().cloned().collect()
}
pub async fn get_global_active_model(&self, workspace_id: &Uuid) -> Model {
self
.get_active_model(
workspace_id,
&SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()),
)
.await
}
/// Retrieves the active model: tries local storage first, then server storage, validating each candidate against the available model list.
/// If neither storage yields a valid model, falls back to the default.
pub async fn get_active_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Model {
let available = self.get_models(workspace_id).await;
// Try local storage
if let Some(storage) = self.local_storage.load_full() {
trace!("[Model Selection] Checking local storage");
if let Some(local_model) = storage.get_selected_model(workspace_id, source_key).await {
trace!("[Model Selection] Found local model: {}", local_model.name);
if available.iter().any(|m| m.name == local_model.name) {
return local_model;
} else {
trace!(
"[Model Selection] Local {} not found in available list, available: {:?}",
local_model.name,
available.iter().map(|m| &m.name).collect::<Vec<_>>()
);
}
} else {
self
.unset_sources
.write()
.await
.insert(source_key.key.clone());
}
}
// Fall back to the global active model when a local source exists and the user hasn't set a model for this source
if self
.sources
.iter()
.any(|source| source.source_name() == "local")
{
trace!("[Model Selection] Checking global active model");
let global_source = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
if let Some(storage) = self.local_storage.load_full() {
if let Some(local_model) = storage
.get_selected_model(workspace_id, &global_source)
.await
{
trace!(
"[Model Selection] Found global active model: {}",
local_model.name
);
if available.iter().any(|m| m.name == local_model.name) {
return local_model;
}
}
}
}
// Try server storage
if let Some(storage) = self.server_storage.load_full() {
trace!("[Model Selection] Checking server storage");
if let Some(server_model) = storage.get_selected_model(workspace_id, source_key).await {
trace!(
"[Model Selection] Found server model: {}",
server_model.name
);
if available.iter().any(|m| m.name == server_model.name) {
return server_model;
} else {
trace!(
"[Model Selection] Server {} not found in available list, available: {:?}",
server_model.name,
available.iter().map(|m| &m.name).collect::<Vec<_>>()
);
}
}
}
// Fallback: default
info!(
"[Model Selection] No active model found, using default: {}",
self.default_model.name
);
self.default_model.clone()
}
/// Sets the active model in both local and server storage
pub async fn set_active_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
info!(
"[Model Selection] active model: {} for source: {}",
model.name, source_key.key
);
self.unset_sources.write().await.remove(&source_key.key);
let available = self.get_models(workspace_id).await;
if available.contains(&model) {
// Update local storage
if let Some(storage) = self.local_storage.load_full() {
storage
.set_selected_model(workspace_id, source_key, model.clone())
.await?;
}
// Update server storage
if let Some(storage) = self.server_storage.load_full() {
storage
.set_selected_model(workspace_id, source_key, model)
.await?;
}
Ok(())
} else {
Err(
FlowyError::internal()
.with_context(format!("Model '{:?}' not found in available list", model)),
)
}
}
}
/// Namespaced key for model selection storage
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SourceKey {
key: String,
}
impl SourceKey {
/// Create a new SourceKey
pub fn new(key: String) -> Self {
Self { key }
}
/// Prefix the key with "ai_models_" to produce a storage identifier
pub fn storage_id(&self) -> String {
format!("ai_models_{}", self.key)
}
}
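For reference, storage_id() reproduces the key format of the removed ai_available_models_key helper (its deletion appears later in this diff), so previously stored selections keep resolving. A minimal illustration:

// Illustrative: the global key maps to the same storage id the old
// ai_available_models_key("global_active_model") helper produced.
let key = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
assert_eq!(key.storage_id(), "ai_models_global_active_model");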
/// A trait that defines an asynchronous source of AI models
#[async_trait]
pub trait ModelSource: Send + Sync {
/// Identifier for this source (e.g., "local" or "server")
fn source_name(&self) -> &'static str;
/// Asynchronously returns a list of available models from this source
async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model>;
}
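A minimal implementor of this trait only has to name itself and return a model list; the mocks in the test file later in this diff follow the same shape. A sketch, with an illustrative model name:

// Sketch: a fixed, in-memory source. `Model` is this file's alias for AIModel,
// and "model-a" is an illustrative name, not a real model.
struct FixedSource;

#[async_trait]
impl ModelSource for FixedSource {
  fn source_name(&self) -> &'static str {
    "server"
  }
  async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<Model> {
    vec![AIModel::server("model-a".to_string(), String::new())]
  }
}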
pub struct LocalAiSource {
controller: Arc<LocalAIController>,
}
impl LocalAiSource {
pub fn new(controller: Arc<LocalAIController>) -> Self {
Self { controller }
}
}
#[async_trait]
impl ModelSource for LocalAiSource {
fn source_name(&self) -> &'static str {
"local"
}
async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<Model> {
match self.controller.ollama.load_full() {
None => vec![],
Some(ollama) => ollama
.list_local_models()
.await
.map(|models| {
models
.into_iter()
.filter(|m| !m.name.contains("embed"))
.map(|m| AIModel::local(m.name, String::new()))
.collect()
})
.unwrap_or_default(),
}
}
}
/// A server-side AI source (e.g., cloud API)
#[derive(Debug, Default)]
struct ServerModelsCache {
models: Vec<Model>,
timestamp: Option<i64>,
}
pub struct ServerAiSource {
cached_models: Arc<RwLock<ServerModelsCache>>,
cloud_service: Arc<dyn ChatCloudService>,
}
impl ServerAiSource {
pub fn new(cloud_service: Arc<dyn ChatCloudService>) -> Self {
Self {
cached_models: Arc::new(Default::default()),
cloud_service,
}
}
async fn update_models_cache(&self, models: &[Model], timestamp: i64) -> FlowyResult<()> {
match self.cached_models.try_write() {
Ok(mut cache) => {
cache.models = models.to_vec();
cache.timestamp = Some(timestamp);
Ok(())
},
Err(_) => {
Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache"))
},
}
}
}
#[async_trait]
impl ModelSource for ServerAiSource {
fn source_name(&self) -> &'static str {
"server"
}
async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model> {
let now = timestamp();
let should_fetch = {
let cached = self.cached_models.read().await;
cached.models.is_empty() || cached.timestamp.map_or(true, |ts| now - ts >= 300)
};
if !should_fetch {
return self.cached_models.read().await.models.clone();
}
match self.cloud_service.get_available_models(workspace_id).await {
Ok(resp) => {
let models = resp
.models
.into_iter()
.map(AIModel::from)
.collect::<Vec<_>>();
if let Err(e) = self.update_models_cache(&models, now).await {
error!("Failed to update cache: {}", e);
}
models
},
Err(err) => {
error!("Failed to fetch models: {}", err);
let cached = self.cached_models.read().await;
if !cached.models.is_empty() {
info!("Returning expired cache due to error");
return cached.models.clone();
}
Vec::new()
},
}
}
}
#[async_trait]
pub trait UserModelStorage: Send + Sync {
async fn get_selected_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Option<Model>;
async fn set_selected_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError>;
}
pub struct ServerModelStorageImpl(pub Arc<dyn ChatCloudService>);
#[async_trait]
impl UserModelStorage for ServerModelStorageImpl {
async fn get_selected_model(
&self,
workspace_id: &Uuid,
_source_key: &SourceKey,
) -> Option<Model> {
let name = self
.0
.get_workspace_default_model(workspace_id)
.await
.ok()?;
Some(Model::server(name, String::new()))
}
async fn set_selected_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
if model.is_local {
// local model selections are never persisted to the server
return Ok(());
}
if source_key.key != GLOBAL_ACTIVE_MODEL_KEY {
return Ok(());
}
self
.0
.set_workspace_default_model(workspace_id, &model.name)
.await?;
Ok(())
}
}
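The two early returns above mean only a non-local model under the global key ever reaches the cloud; local and per-source selections stay on the device. A sketch with illustrative model names and chat id, assuming an async context with cloud_service and workspace_id in scope:

// Sketch: only the first call reaches set_workspace_default_model.
let storage = ServerModelStorageImpl(cloud_service.clone());
let global = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
let chat = SourceKey::new("chat-123".to_string());
// Persisted: a server model under the global key.
storage
  .set_selected_model(&workspace_id, &global, AIModel::server("model-a".to_string(), String::new()))
  .await?;
// No-ops: a local model, or any non-global source key.
storage
  .set_selected_model(&workspace_id, &global, AIModel::local("model-b".to_string(), String::new()))
  .await?;
storage
  .set_selected_model(&workspace_id, &chat, AIModel::server("model-a".to_string(), String::new()))
  .await?;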
pub struct LocalModelStorageImpl(pub Arc<KVStorePreferences>);
#[async_trait]
impl UserModelStorage for LocalModelStorageImpl {
async fn get_selected_model(
&self,
_workspace_id: &Uuid,
source_key: &SourceKey,
) -> Option<Model> {
self.0.get_object::<AIModel>(&source_key.storage_id())
}
async fn set_selected_model(
&self,
_workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
self
.0
.set_object::<AIModel>(&source_key.storage_id(), &model)?;
Ok(())
}
}
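Putting the pieces together: the wiring that AIManager::new performs earlier in this diff follows this pattern (a sketch; the Arc'd dependencies cloud_service, store_preferences, and local_ai are assumed in scope):

// Sketch of the wiring shown in AIManager::new earlier in this diff.
let mut control = ModelSelectionControl::new();
control.set_local_storage(LocalModelStorageImpl(store_preferences.clone()));
control.set_server_storage(ServerModelStorageImpl(cloud_service.clone()));
control.add_source(Box::new(ServerAiSource::new(cloud_service.clone())));
// The local source is attached later, in prepare_local_ai, and only while
// local AI is enabled; disabling it calls remove_local_source().
control.add_source(Box::new(LocalAiSource::new(local_ai.clone())));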

View File

@@ -0,0 +1,434 @@
use crate::model_select::{ModelSelectionControl, ModelSource, SourceKey, UserModelStorage};
use flowy_ai_pub::cloud::AIModel;
use flowy_error::FlowyError;
use lib_infra::async_trait::async_trait;
use tokio::sync::RwLock;
use uuid::Uuid;
// Mock implementations for testing
struct MockModelSource {
name: &'static str,
models: Vec<AIModel>,
}
#[async_trait]
impl ModelSource for MockModelSource {
fn source_name(&self) -> &'static str {
self.name
}
async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<AIModel> {
self.models.clone()
}
}
struct MockModelStorage {
selected_model: RwLock<Option<AIModel>>,
}
impl MockModelStorage {
fn new(initial_model: Option<AIModel>) -> Self {
Self {
selected_model: RwLock::new(initial_model),
}
}
}
#[async_trait]
impl UserModelStorage for MockModelStorage {
async fn get_selected_model(
&self,
_workspace_id: &Uuid,
_source_key: &SourceKey,
) -> Option<AIModel> {
self.selected_model.read().await.clone()
}
async fn set_selected_model(
&self,
_workspace_id: &Uuid,
_source_key: &SourceKey,
model: AIModel,
) -> Result<(), FlowyError> {
*self.selected_model.write().await = Some(model);
Ok(())
}
}
#[tokio::test]
async fn test_empty_model_list_returns_default() {
let control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let models = control.get_models(&workspace_id).await;
assert_eq!(models.len(), 1);
assert_eq!(models[0], AIModel::default());
}
#[tokio::test]
async fn test_get_models_from_multiple_sources() {
let mut control = ModelSelectionControl::new();
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![
AIModel::local("local-model-1".to_string(), "".to_string()),
AIModel::local("local-model-2".to_string(), "".to_string()),
],
});
let server_source = Box::new(MockModelSource {
name: "server",
models: vec![
AIModel::server("server-model-1".to_string(), "".to_string()),
AIModel::server("server-model-2".to_string(), "".to_string()),
],
});
control.add_source(local_source);
control.add_source(server_source);
let workspace_id = Uuid::new_v4();
let models = control.get_models(&workspace_id).await;
assert_eq!(models.len(), 4);
assert!(models.iter().any(|m| m.name == "local-model-1"));
assert!(models.iter().any(|m| m.name == "local-model-2"));
assert!(models.iter().any(|m| m.name == "server-model-1"));
assert!(models.iter().any(|m| m.name == "server-model-2"));
}
#[tokio::test]
async fn test_get_models_with_specific_local_model() {
let mut control = ModelSelectionControl::new();
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![
AIModel::local("local-model-1".to_string(), "".to_string()),
AIModel::local("local-model-2".to_string(), "".to_string()),
],
});
let server_source = Box::new(MockModelSource {
name: "server",
models: vec![
AIModel::server("server-model-1".to_string(), "".to_string()),
AIModel::server("server-model-2".to_string(), "".to_string()),
],
});
control.add_source(local_source);
control.add_source(server_source);
let workspace_id = Uuid::new_v4();
// Test with specific local model
let models = control
.get_models_with_specific_local_model(&workspace_id, Some("local-model-1".to_string()))
.await;
assert_eq!(models.len(), 3);
assert!(models.iter().any(|m| m.name == "local-model-1"));
assert!(!models.iter().any(|m| m.name == "local-model-2"));
// Test with non-existent local model
let models = control
.get_models_with_specific_local_model(&workspace_id, Some("non-existent".to_string()))
.await;
assert_eq!(models.len(), 2); // Only server models
// Test with no specified local model (should include all local models)
let models = control
.get_models_with_specific_local_model(&workspace_id, None)
.await;
assert_eq!(models.len(), 4);
}
#[tokio::test]
async fn test_get_local_models() {
let mut control = ModelSelectionControl::new();
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![
AIModel::local("local-model-1".to_string(), "".to_string()),
AIModel::local("local-model-2".to_string(), "".to_string()),
],
});
let server_source = Box::new(MockModelSource {
name: "server",
models: vec![AIModel::server(
"server-model-1".to_string(),
"".to_string(),
)],
});
control.add_source(local_source);
control.add_source(server_source);
let workspace_id = Uuid::new_v4();
let local_models = control.get_local_models(&workspace_id).await;
assert_eq!(local_models.len(), 2);
assert!(local_models.iter().all(|m| m.is_local));
}
#[tokio::test]
async fn test_remove_local_source() {
let mut control = ModelSelectionControl::new();
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![AIModel::local("local-model-1".to_string(), "".to_string())],
});
let server_source = Box::new(MockModelSource {
name: "server",
models: vec![AIModel::server(
"server-model-1".to_string(),
"".to_string(),
)],
});
control.add_source(local_source);
control.add_source(server_source);
let workspace_id = Uuid::new_v4();
assert_eq!(control.get_models(&workspace_id).await.len(), 2);
control.remove_local_source();
let models = control.get_models(&workspace_id).await;
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "server-model-1");
}
#[tokio::test]
async fn test_get_active_model_from_local_storage() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("test".to_string());
// Add a local source with some models
let local_model = AIModel::local("local-model-1".to_string(), "".to_string());
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![local_model.clone()],
});
control.add_source(local_source);
// Set up local storage with a selected model
let local_storage = MockModelStorage::new(Some(local_model.clone()));
control.set_local_storage(local_storage);
// Get active model should return the locally stored model
let active = control.get_active_model(&workspace_id, &source_key).await;
assert_eq!(active, local_model);
}
#[tokio::test]
async fn test_global_active_model_fallback() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("specific_source".to_string());
// Add a local source with models
let local_model = AIModel::local("local-model-1".to_string(), "".to_string());
let local_source = Box::new(MockModelSource {
name: "local",
models: vec![local_model.clone()],
});
control.add_source(local_source);
// Set up local storage with global model but not specific source model
let global_storage = MockModelStorage::new(Some(local_model.clone()));
// Set the local storage
control.set_local_storage(global_storage);
// Get active model should fall back to the global model since
// there's no model for the specific source key
let active = control.get_active_model(&workspace_id, &source_key).await;
assert_eq!(active, local_model);
}
#[tokio::test]
async fn test_get_active_model_fallback_to_server_storage() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("test".to_string());
// Add a server source with some models
let server_model = AIModel::server("server-model-1".to_string(), "".to_string());
let server_source = Box::new(MockModelSource {
name: "server",
models: vec![server_model.clone()],
});
control.add_source(server_source);
// Set up local storage with no selected model
let local_storage = MockModelStorage::new(None);
control.set_local_storage(local_storage);
// Set up server storage with a selected model
let server_storage = MockModelStorage::new(Some(server_model.clone()));
control.set_server_storage(server_storage);
// Get active model should fall back to server storage
let active = control.get_active_model(&workspace_id, &source_key).await;
assert_eq!(active, server_model);
}
#[tokio::test]
async fn test_get_active_model_fallback_to_default() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("test".to_string());
// Add sources with some models
let model1 = AIModel::local("model-1".to_string(), "".to_string());
let model2 = AIModel::server("model-2".to_string(), "".to_string());
let source = Box::new(MockModelSource {
name: "test",
models: vec![model1.clone(), model2.clone()],
});
control.add_source(source);
// Set up storages with models that don't match available models
let different_model = AIModel::local("non-existent".to_string(), "".to_string());
let local_storage = MockModelStorage::new(Some(different_model.clone()));
let server_storage = MockModelStorage::new(Some(different_model.clone()));
control.set_local_storage(local_storage);
control.set_server_storage(server_storage);
// Should fall back to default model since storages return non-matching models
let active = control.get_active_model(&workspace_id, &source_key).await;
assert_eq!(active, AIModel::default());
}
#[tokio::test]
async fn test_set_active_model() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("test".to_string());
// Add a source with some models
let model = AIModel::local("model-1".to_string(), "".to_string());
let source = Box::new(MockModelSource {
name: "test",
models: vec![model.clone()],
});
control.add_source(source);
// Set up storage
let local_storage = MockModelStorage::new(None);
let server_storage = MockModelStorage::new(None);
control.set_local_storage(local_storage);
control.set_server_storage(server_storage);
// Set active model
let result = control
.set_active_model(&workspace_id, &source_key, model.clone())
.await;
assert!(result.is_ok());
// Verify that the active model was set correctly
let active = control.get_active_model(&workspace_id, &source_key).await;
assert_eq!(active, model);
}
#[tokio::test]
async fn test_set_active_model_invalid_model() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("test".to_string());
// Add a source with some models
let available_model = AIModel::local("available-model".to_string(), "".to_string());
let source = Box::new(MockModelSource {
name: "test",
models: vec![available_model.clone()],
});
control.add_source(source);
// Set up storage
let local_storage = MockModelStorage::new(None);
let server_storage = MockModelStorage::new(None);
control.set_local_storage(local_storage);
control.set_server_storage(server_storage);
// Try to set an invalid model
let invalid_model = AIModel::local("invalid-model".to_string(), "".to_string());
let result = control
.set_active_model(&workspace_id, &source_key, invalid_model)
.await;
// Should fail because the model is not in the available list
assert!(result.is_err());
}
#[tokio::test]
async fn test_global_active_model_fallback_with_local_source() {
let mut control = ModelSelectionControl::new();
let workspace_id = Uuid::new_v4();
let source_key = SourceKey::new("specific_source".to_string());
// Add a local source with models
let local_model = AIModel::local("local-model-1".to_string(), "".to_string());
let local_source = Box::new(MockModelSource {
name: "local", // This is important - the fallback only happens when a local source exists
models: vec![local_model.clone()],
});
control.add_source(local_source);
// Create a custom storage that only returns a model for the global key
struct GlobalOnlyStorage {
global_model: AIModel,
}
#[async_trait]
impl UserModelStorage for GlobalOnlyStorage {
async fn get_selected_model(
&self,
_workspace_id: &Uuid,
source_key: &SourceKey,
) -> Option<AIModel> {
if source_key.storage_id()
== format!("ai_models_{}", crate::model_select::GLOBAL_ACTIVE_MODEL_KEY)
{
Some(self.global_model.clone())
} else {
None
}
}
async fn set_selected_model(
&self,
_workspace_id: &Uuid,
_source_key: &SourceKey,
_model: AIModel,
) -> Result<(), FlowyError> {
Ok(())
}
}
// Set up local storage with only the global model
let global_storage = GlobalOnlyStorage {
global_model: local_model.clone(),
};
control.set_local_storage(global_storage);
// Get active model for a specific source_key (not the global one)
// Should fall back to the global model since:
// 1. There's no model for the specific source_key
// 2. There is a local source
// 3. There is a global active model set
let active = control.get_active_model(&workspace_id, &source_key).await;
// Should get the global model
assert_eq!(active, local_model);
}

View File

@@ -255,4 +255,15 @@ impl ChatCloudService for AutoSyncChatService {
.get_workspace_default_model(workspace_id)
.await
}
async fn set_workspace_default_model(
&self,
workspace_id: &Uuid,
model: &str,
) -> Result<(), FlowyError> {
self
.cloud_service
.set_workspace_default_model(workspace_id, model)
.await
}
}

View File

@@ -1,3 +0,0 @@
pub fn ai_available_models_key(object_id: &str) -> String {
format!("ai_models_{}", object_id)
}

View File

@@ -818,6 +818,18 @@ impl ChatCloudService for ServerProvider {
.get_workspace_default_model(workspace_id)
.await
}
async fn set_workspace_default_model(
&self,
workspace_id: &Uuid,
model: &str,
) -> Result<(), FlowyError> {
self
.get_server()?
.chat_service()
.set_workspace_default_model(workspace_id, model)
.await
}
}
#[async_trait]

View File

@@ -77,6 +77,12 @@ impl ServerProvider {
local_ai,
}
}
pub fn on_launch_if_authenticated(&self, _workspace_type: &WorkspaceType) {}
pub fn on_sign_in(&self, _workspace_type: &WorkspaceType) {}
pub fn on_sign_up(&self, _workspace_type: &WorkspaceType) {}
pub fn init_after_open_workspace(&self, _workspace_type: &WorkspaceType) {}
pub fn set_auth_type(&self, new_auth_type: AuthType) {
let old_type = self.get_auth_type();

View File

@@ -197,10 +197,14 @@ impl UserStatusCallback for UserStatusCallbackImpl {
.initialize_after_sign_in(user_id)
.await?;
self
.ai_manager()?
.initialize_after_sign_in(workspace_id)
.await?;
let ai_manager = self.ai_manager()?;
let cloned_workspace_id = *workspace_id;
self.runtime.spawn(async move {
ai_manager
.initialize_after_sign_in(&cloned_workspace_id)
.await?;
Ok::<_, FlowyError>(())
});
Ok(())
}
@@ -248,10 +252,15 @@ impl UserStatusCallback for UserStatusCallbackImpl {
.await
.context("DocumentManager error")?;
self
.ai_manager()?
.initialize_after_sign_up(workspace_id)
.await?;
let ai_manager = self.ai_manager()?;
let cloned_workspace_id = *workspace_id;
self.runtime.spawn(async move {
ai_manager
.initialize_after_sign_up(&cloned_workspace_id)
.await?;
Ok::<_, FlowyError>(())
});
Ok(())
}
@@ -283,10 +292,15 @@ impl UserStatusCallback for UserStatusCallbackImpl {
.document_manager()?
.initialize_after_open_workspace(user_id)
.await?;
self
.ai_manager()?
.initialize_after_open_workspace(workspace_id)
.await?;
let ai_manager = self.ai_manager()?;
let cloned_workspace_id = *workspace_id;
self.runtime.spawn(async move {
ai_manager
.initialize_after_open_workspace(&cloned_workspace_id)
.await?;
Ok::<_, FlowyError>(())
});
self
.storage_manager()?
.initialize_after_open_workspace(workspace_id)

View File

@@ -8,8 +8,8 @@ use client_api::entity::chat_dto::{
RepeatedChatMessage,
};
use flowy_ai_pub::cloud::{
AIModel, ChatCloudService, ChatMessage, ChatMessageType, ChatSettings, ModelList, StreamAnswer,
StreamComplete, UpdateChatParams,
AFWorkspaceSettingsChange, AIModel, ChatCloudService, ChatMessage, ChatMessageType, ChatSettings,
ModelList, StreamAnswer, StreamComplete, UpdateChatParams,
};
use flowy_error::FlowyError;
use futures_util::{StreamExt, TryStreamExt};
@@ -267,4 +267,18 @@ where
.await?;
Ok(setting.ai_model)
}
async fn set_workspace_default_model(
&self,
workspace_id: &Uuid,
model: &str,
) -> Result<(), FlowyError> {
let change = AFWorkspaceSettingsChange::new().ai_model(model.to_string());
self
.inner
.try_get_client()?
.update_workspace_settings(workspace_id.to_string().as_str(), &change)
.await?;
Ok(())
}
}

View File

@@ -319,6 +319,15 @@ impl ChatCloudService for LocalChatServiceImpl {
async fn get_workspace_default_model(&self, _workspace_id: &Uuid) -> Result<String, FlowyError> {
Ok(DEFAULT_AI_MODEL_NAME.to_string())
}
async fn set_workspace_default_model(
&self,
_workspace_id: &Uuid,
_model: &str,
) -> Result<(), FlowyError> {
// do nothing
Ok(())
}
}
fn chat_message_from_row(row: ChatMessageTable) -> ChatMessage {