Fix after upgrading candle

This commit is contained in:
Clément Renault 2025-01-08 15:59:56 +01:00
parent 68333424c6
commit 0ee4671a91
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
3 changed files with 20 additions and 6 deletions

View File

@ -163,8 +163,10 @@ impl Embedder {
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
let embeddings = self
.model
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
@ -185,8 +187,10 @@ impl Embedder {
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
let embeddings = self
.model
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =

View File

@ -1,3 +1,4 @@
use std::fmt;
use std::time::Instant;
use ordered_float::OrderedFloat;
@ -168,7 +169,6 @@ fn infer_api_key() -> String {
.unwrap_or_default()
}
#[derive(Debug)]
pub struct Embedder {
tokenizer: tiktoken_rs::CoreBPE,
rest_embedder: RestEmbedder,
@ -302,3 +302,13 @@ impl Embedder {
self.options.distribution()
}
}
/// Hand-written `Debug` for `Embedder`.
///
/// The `tokenizer` field is rendered as the fixed placeholder string
/// `"CoreBPE"` rather than its contents (presumably because
/// `tiktoken_rs::CoreBPE` has no usable `Debug` — TODO confirm); the
/// remaining fields delegate to their own `Debug` impls.
impl fmt::Debug for Embedder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Build the struct formatter stepwise instead of chaining calls;
        // the emitted output is identical to the builder-chain form.
        let mut dbg = f.debug_struct("Embedder");
        dbg.field("tokenizer", &"CoreBPE");
        dbg.field("rest_embedder", &self.rest_embedder);
        dbg.field("options", &self.options);
        dbg.finish()
    }
}

View File

@ -175,7 +175,7 @@ impl Embedder {
pub fn embed_tokens(
&self,
tokens: &[usize],
tokens: &[u32],
deadline: Option<Instant>,
) -> Result<Embedding, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;