mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-25 12:47:28 +01:00
Fix after upgrading candle
This commit is contained in:
parent
68333424c6
commit
0ee4671a91
@ -163,8 +163,10 @@ impl Embedder {
|
|||||||
|
|
||||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
let embeddings =
|
let embeddings = self
|
||||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
.model
|
||||||
|
.forward(&token_ids, &token_type_ids, None)
|
||||||
|
.map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) =
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
@ -185,8 +187,10 @@ impl Embedder {
|
|||||||
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
let embeddings =
|
let embeddings = self
|
||||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
.model
|
||||||
|
.forward(&token_ids, &token_type_ids, None)
|
||||||
|
.map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) =
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use std::fmt;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
@ -168,7 +169,6 @@ fn infer_api_key() -> String {
|
|||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
tokenizer: tiktoken_rs::CoreBPE,
|
tokenizer: tiktoken_rs::CoreBPE,
|
||||||
rest_embedder: RestEmbedder,
|
rest_embedder: RestEmbedder,
|
||||||
@ -302,3 +302,13 @@ impl Embedder {
|
|||||||
self.options.distribution()
|
self.options.distribution()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Embedder {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("Embedder")
|
||||||
|
.field("tokenizer", &"CoreBPE")
|
||||||
|
.field("rest_embedder", &self.rest_embedder)
|
||||||
|
.field("options", &self.options)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -175,7 +175,7 @@ impl Embedder {
|
|||||||
|
|
||||||
pub fn embed_tokens(
|
pub fn embed_tokens(
|
||||||
&self,
|
&self,
|
||||||
tokens: &[usize],
|
tokens: &[u32],
|
||||||
deadline: Option<Instant>,
|
deadline: Option<Instant>,
|
||||||
) -> Result<Embedding, EmbedError> {
|
) -> Result<Embedding, EmbedError> {
|
||||||
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
|
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user