mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 04:17:10 +02:00
Introduce the support of Azure, Gemini, vLLM
This commit is contained in:
parent
4dfb89168b
commit
70670c3be4
8 changed files with 261 additions and 18 deletions
|
@ -7,7 +7,6 @@ use std::time::Duration;
|
|||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{Event, Sse};
|
||||
use async_openai::config::{Config, OpenAIConfig};
|
||||
use async_openai::types::{
|
||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||
|
@ -35,6 +34,7 @@ use serde_json::json;
|
|||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
|
||||
use super::config::Config;
|
||||
use super::errors::StreamErrorEvent;
|
||||
use super::utils::format_documents;
|
||||
use super::{
|
||||
|
@ -312,15 +312,8 @@ async fn non_streamed_chat(
|
|||
}
|
||||
};
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
let config = Config::new(&chat_settings);
|
||||
let client = Client::with_config(config);
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
// TODO do function support later
|
||||
let _function_support =
|
||||
|
@ -413,14 +406,7 @@ async fn streamed_chat(
|
|||
};
|
||||
drop(rtxn);
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
let config = Config::new(&chat_settings);
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
let function_support =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
||||
|
@ -465,7 +451,7 @@ async fn streamed_chat(
|
|||
/// Updates the chat completion with the new messages, streams the LLM tokens,
|
||||
/// and report progress and errors.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_conversation<C: Config>(
|
||||
async fn run_conversation<C: async_openai::config::Config>(
|
||||
index_scheduler: &GuardedData<
|
||||
ActionPolicy<{ actions::CHAT_COMPLETIONS }>,
|
||||
Data<IndexScheduler>,
|
||||
|
|
87
crates/meilisearch/src/routes/chats/config.rs
Normal file
87
crates/meilisearch/src/routes/chats/config.rs
Normal file
|
@ -0,0 +1,87 @@
|
|||
use async_openai::config::{AzureConfig, OpenAIConfig};
|
||||
use meilisearch_types::features::ChatCompletionSettings as DbChatSettings;
|
||||
use reqwest::header::HeaderMap;
|
||||
use secrecy::SecretString;
|
||||
|
||||
#[derive(Debug, Clone)]
/// Client configuration for the chat-completions backend, selected from the
/// stored chat settings (see `Config::new`).
///
/// All OpenAI-compatible providers (OpenAI, Mistral, Gemini, vLLM) share the
/// `OpenAiCompatible` variant; Azure OpenAI uses a dedicated variant because
/// its configuration differs (API version + deployment id).
pub enum Config {
    /// Wraps `async_openai`'s standard `OpenAIConfig` (api key, base URL,
    /// org id, project id).
    OpenAiCompatible(OpenAIConfig),
    /// Wraps `async_openai`'s `AzureConfig` (api key, base URL, api version,
    /// deployment id).
    AzureOpenAiCompatible(AzureConfig),
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(chat_settings: &DbChatSettings) -> Self {
|
||||
use meilisearch_types::features::ChatCompletionSource::*;
|
||||
match chat_settings.source {
|
||||
OpenAi | Mistral | Gemini | VLlm => {
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Some(org_id) = chat_settings.org_id.as_ref() {
|
||||
config = config.with_org_id(org_id);
|
||||
}
|
||||
if let Some(project_id) = chat_settings.project_id.as_ref() {
|
||||
config = config.with_project_id(project_id);
|
||||
}
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
Self::OpenAiCompatible(config)
|
||||
}
|
||||
AzureOpenAi => {
|
||||
let mut config = AzureConfig::default();
|
||||
if let Some(version) = chat_settings.api_version.as_ref() {
|
||||
config = config.with_api_version(version);
|
||||
}
|
||||
if let Some(deployment_id) = chat_settings.deployment_id.as_ref() {
|
||||
config = config.with_deployment_id(deployment_id);
|
||||
}
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
Self::AzureOpenAiCompatible(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl async_openai::config::Config for Config {
|
||||
fn headers(&self) -> HeaderMap {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.headers(),
|
||||
Config::AzureOpenAiCompatible(config) => config.headers(),
|
||||
}
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.url(path),
|
||||
Config::AzureOpenAiCompatible(config) => config.url(path),
|
||||
}
|
||||
}
|
||||
|
||||
fn query(&self) -> Vec<(&str, &str)> {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.query(),
|
||||
Config::AzureOpenAiCompatible(config) => config.query(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_base(&self) -> &str {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.api_base(),
|
||||
Config::AzureOpenAiCompatible(config) => config.api_base(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &SecretString {
|
||||
match self {
|
||||
Config::OpenAiCompatible(config) => config.api_key(),
|
||||
Config::AzureOpenAiCompatible(config) => config.api_key(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ use crate::extractors::authentication::GuardedData;
|
|||
use crate::routes::PAGINATION_DEFAULT_LIMIT;
|
||||
|
||||
pub mod chat_completions;
|
||||
mod config;
|
||||
mod errors;
|
||||
pub mod settings;
|
||||
mod utils;
|
||||
|
|
|
@ -109,6 +109,26 @@ async fn patch_settings(
|
|||
Setting::Reset => DbChatCompletionSource::default(),
|
||||
Setting::NotSet => old_settings.source,
|
||||
},
|
||||
org_id: match new.org_id {
|
||||
Setting::Set(new_org_id) => Some(new_org_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.org_id,
|
||||
},
|
||||
project_id: match new.project_id {
|
||||
Setting::Set(new_project_id) => Some(new_project_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.project_id,
|
||||
},
|
||||
api_version: match new.api_version {
|
||||
Setting::Set(new_api_version) => Some(new_api_version),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.api_version,
|
||||
},
|
||||
deployment_id: match new.deployment_id {
|
||||
Setting::Set(new_deployment_id) => Some(new_deployment_id),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.deployment_id,
|
||||
},
|
||||
base_api: match new.base_api {
|
||||
Setting::Set(new_base_api) => Some(new_base_api),
|
||||
Setting::Reset => None,
|
||||
|
@ -171,6 +191,22 @@ pub struct GlobalChatSettings {
|
|||
#[schema(value_type = Option<ChatCompletionSource>)]
|
||||
pub source: Setting<ChatCompletionSource>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionOrgId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("dcba4321..."))]
|
||||
pub org_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionProjectId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("4321dcba..."))]
|
||||
pub project_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionApiVersion>)]
|
||||
#[schema(value_type = Option<String>, example = json!("2024-02-01"))]
|
||||
pub api_version: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionDeploymentId>)]
|
||||
#[schema(value_type = Option<String>, example = json!("1234abcd..."))]
|
||||
pub deployment_id: Setting<String>,
|
||||
#[serde(default)]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidChatCompletionBaseApi>)]
|
||||
#[schema(value_type = Option<String>, example = json!("https://api.mistral.ai/v1"))]
|
||||
pub base_api: Setting<String>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue