mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-10 21:44:34 +01:00
Fix after upgrading candle
This commit is contained in:
parent
68333424c6
commit
0ee4671a91
@ -163,8 +163,10 @@ impl Embedder {
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
let embeddings = self
|
||||
.model
|
||||
.forward(&token_ids, &token_type_ids, None)
|
||||
.map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
@ -185,8 +187,10 @@ impl Embedder {
|
||||
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
||||
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
let embeddings = self
|
||||
.model
|
||||
.forward(&token_ids, &token_type_ids, None)
|
||||
.map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
|
@ -1,3 +1,4 @@
|
||||
use std::fmt;
|
||||
use std::time::Instant;
|
||||
|
||||
use ordered_float::OrderedFloat;
|
||||
@ -168,7 +169,6 @@ fn infer_api_key() -> String {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
rest_embedder: RestEmbedder,
|
||||
@ -302,3 +302,13 @@ impl Embedder {
|
||||
self.options.distribution()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Embedder {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Embedder")
|
||||
.field("tokenizer", &"CoreBPE")
|
||||
.field("rest_embedder", &self.rest_embedder)
|
||||
.field("options", &self.options)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -175,7 +175,7 @@ impl Embedder {
|
||||
|
||||
pub fn embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
tokens: &[u32],
|
||||
deadline: Option<Instant>,
|
||||
) -> Result<Embedding, EmbedError> {
|
||||
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
|
||||
|
Loading…
x
Reference in New Issue
Block a user