Simplify the analytics chat completions aggregator

This commit is contained in:
Kerollmops 2025-06-25 11:50:26 +02:00
parent 5f50fc9464
commit adc9976615
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
2 changed files with 12 additions and 16 deletions

View file

@ -2,15 +2,10 @@ use std::collections::BinaryHeap;
use serde_json::{json, Value}; use serde_json::{json, Value};
use crate::aggregate_methods; use crate::analytics::Aggregate;
use crate::analytics::{Aggregate, AggregateMethod};
aggregate_methods!(
ChatCompletionPOST => "Chat Completion POST",
);
#[derive(Default)] #[derive(Default)]
pub struct ChatCompletionAggregator<Method: AggregateMethod> { pub struct ChatCompletionAggregator {
// requests // requests
total_received: usize, total_received: usize,
total_succeeded: usize, total_succeeded: usize,
@ -23,22 +18,23 @@ pub struct ChatCompletionAggregator<Method: AggregateMethod> {
// model usage tracking // model usage tracking
models_used: std::collections::HashMap<String, usize>, models_used: std::collections::HashMap<String, usize>,
_method: std::marker::PhantomData<Method>,
} }
impl<Method: AggregateMethod> ChatCompletionAggregator<Method> { impl ChatCompletionAggregator {
pub fn from_request(model: &str, message_count: usize, is_stream: bool) -> Self { pub fn from_request(model: &str, message_count: usize, is_stream: bool) -> Self {
let mut models_used = std::collections::HashMap::new(); let mut models_used = std::collections::HashMap::new();
models_used.insert(model.to_string(), 1); models_used.insert(model.to_string(), 1);
Self { Self {
total_received: 1, total_received: 1,
total_succeeded: 0,
time_spent: BinaryHeap::new(),
total_messages: message_count, total_messages: message_count,
total_streamed_requests: if is_stream { 1 } else { 0 }, total_streamed_requests: if is_stream { 1 } else { 0 },
total_non_streamed_requests: if is_stream { 0 } else { 1 }, total_non_streamed_requests: if is_stream { 0 } else { 1 },
models_used, models_used,
..Default::default()
} }
} }
@ -48,9 +44,9 @@ impl<Method: AggregateMethod> ChatCompletionAggregator<Method> {
} }
} }
impl<Method: AggregateMethod> Aggregate for ChatCompletionAggregator<Method> { impl Aggregate for ChatCompletionAggregator {
fn event_name(&self) -> &'static str { fn event_name(&self) -> &'static str {
Method::event_name() "Chat Completion POST"
} }
fn aggregate(mut self: Box<Self>, new: Box<Self>) -> Box<Self> { fn aggregate(mut self: Box<Self>, new: Box<Self>) -> Box<Self> {

View file

@ -36,7 +36,7 @@ use serde_json::json;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use super::chat_completion_analytics::{ChatCompletionAggregator, ChatCompletionPOST}; use super::chat_completion_analytics::ChatCompletionAggregator;
use super::config::Config; use super::config::Config;
use super::errors::{MistralError, OpenAiOutsideError, StreamErrorEvent}; use super::errors::{MistralError, OpenAiOutsideError, StreamErrorEvent};
use super::utils::format_documents; use super::utils::format_documents;
@ -325,7 +325,7 @@ async fn non_streamed_chat(
index_scheduler.features().check_chat_completions("using the /chats chat completions route")?; index_scheduler.features().check_chat_completions("using the /chats chat completions route")?;
// Create analytics aggregator // Create analytics aggregator
let aggregate = ChatCompletionAggregator::<ChatCompletionPOST>::from_request( let aggregate = ChatCompletionAggregator::from_request(
&chat_completion.model, &chat_completion.model,
chat_completion.messages.len(), chat_completion.messages.len(),
false, // non_streamed_chat is not streaming false, // non_streamed_chat is not streaming
@ -466,7 +466,7 @@ async fn streamed_chat(
}; };
// Create analytics aggregator // Create analytics aggregator
let mut aggregate = ChatCompletionAggregator::<ChatCompletionPOST>::from_request( let mut aggregate = ChatCompletionAggregator::from_request(
&chat_completion.model, &chat_completion.model,
chat_completion.messages.len(), chat_completion.messages.len(),
true, // streamed_chat is always streaming true, // streamed_chat is always streaming