Simplify the analytics chat completions aggragetor

This commit is contained in:
Kerollmops 2025-06-25 11:50:26 +02:00
parent 5f50fc9464
commit adc9976615
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
2 changed files with 12 additions and 16 deletions

View File

@ -2,15 +2,10 @@ use std::collections::BinaryHeap;
use serde_json::{json, Value};
use crate::aggregate_methods;
use crate::analytics::{Aggregate, AggregateMethod};
aggregate_methods!(
ChatCompletionPOST => "Chat Completion POST",
);
use crate::analytics::Aggregate;
#[derive(Default)]
pub struct ChatCompletionAggregator<Method: AggregateMethod> {
pub struct ChatCompletionAggregator {
// requests
total_received: usize,
total_succeeded: usize,
@ -23,22 +18,23 @@ pub struct ChatCompletionAggregator<Method: AggregateMethod> {
// model usage tracking
models_used: std::collections::HashMap<String, usize>,
_method: std::marker::PhantomData<Method>,
}
impl<Method: AggregateMethod> ChatCompletionAggregator<Method> {
impl ChatCompletionAggregator {
pub fn from_request(model: &str, message_count: usize, is_stream: bool) -> Self {
let mut models_used = std::collections::HashMap::new();
models_used.insert(model.to_string(), 1);
Self {
total_received: 1,
total_succeeded: 0,
time_spent: BinaryHeap::new(),
total_messages: message_count,
total_streamed_requests: if is_stream { 1 } else { 0 },
total_non_streamed_requests: if is_stream { 0 } else { 1 },
models_used,
..Default::default()
}
}
@ -48,9 +44,9 @@ impl<Method: AggregateMethod> ChatCompletionAggregator<Method> {
}
}
impl<Method: AggregateMethod> Aggregate for ChatCompletionAggregator<Method> {
impl Aggregate for ChatCompletionAggregator {
fn event_name(&self) -> &'static str {
Method::event_name()
"Chat Completion POST"
}
fn aggregate(mut self: Box<Self>, new: Box<Self>) -> Box<Self> {

View File

@ -36,7 +36,7 @@ use serde_json::json;
use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use super::chat_completion_analytics::{ChatCompletionAggregator, ChatCompletionPOST};
use super::chat_completion_analytics::ChatCompletionAggregator;
use super::config::Config;
use super::errors::{MistralError, OpenAiOutsideError, StreamErrorEvent};
use super::utils::format_documents;
@ -325,7 +325,7 @@ async fn non_streamed_chat(
index_scheduler.features().check_chat_completions("using the /chats chat completions route")?;
// Create analytics aggregator
let aggregate = ChatCompletionAggregator::<ChatCompletionPOST>::from_request(
let aggregate = ChatCompletionAggregator::from_request(
&chat_completion.model,
chat_completion.messages.len(),
false, // non_streamed_chat is not streaming
@ -466,7 +466,7 @@ async fn streamed_chat(
};
// Create analytics aggregator
let mut aggregate = ChatCompletionAggregator::<ChatCompletionPOST>::from_request(
let mut aggregate = ChatCompletionAggregator::from_request(
&chat_completion.model,
chat_completion.messages.len(),
true, // streamed_chat is always streaming