mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-23 13:24:27 +01:00
Fix hf embedder
This commit is contained in:
parent
e32677999f
commit
bef8fc6cf1
@ -183,14 +183,17 @@ impl Embedder {
|
|||||||
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
|
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
|
||||||
let token_ids =
|
let token_ids =
|
||||||
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
||||||
|
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
||||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
let embeddings =
|
let embeddings =
|
||||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?;
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||||
.map_err(EmbedError::tensor_shape)?;
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
|
||||||
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
||||||
Ok(embedding)
|
Ok(embedding)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user