Update ollama and openai impls to use the rest embedder internally

This commit is contained in:
Louis Dureuil 2024-03-19 15:41:37 +01:00
parent 8708cbef25
commit ac52c857e8
No known key found for this signature in database
8 changed files with 394 additions and 779 deletions

View file

@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
request_threads: &rayon::ThreadPool,
) -> Result<grenad::Reader<BufReader<File>>> {
puffin::profile_function!();
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if chunks.len() == chunks.capacity() {
let chunked_embeds = embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
.embed_chunks(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
request_threads,
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = embedder
.embed_chunks(std::mem::take(&mut chunks))
.embed_chunks(std::mem::take(&mut chunks), request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if !current_chunk.is_empty() {
let embeds = embedder
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;

View file

@ -238,7 +238,15 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
let request_threads = rayon::ThreadPoolBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}"))
.build()
.unwrap();
rayon::spawn(move || {
/// FIXME: unwrap
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
documents_chunk_cloned.clone(),
@ -249,7 +257,12 @@ fn send_original_documents_data(
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
let embeddings = match extract_embeddings(
prompts,
indexer,
embedder.clone(),
&request_threads,
) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));