From 294cf39cad33a127537c6fdd82331e8eba3b19ba Mon Sep 17 00:00:00 2001
From: Louis Dureuil <louis@meilisearch.com>
Date: Thu, 20 Feb 2025 11:37:27 +0100
Subject: [PATCH] Integrate composite embedder

---
 crates/milli/src/vector/mod.rs | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs
index d5569a8e6..a253963d2 100644
--- a/crates/milli/src/vector/mod.rs
+++ b/crates/milli/src/vector/mod.rs
@@ -538,6 +538,8 @@ pub enum Embedder {
     Ollama(ollama::Embedder),
     /// An embedder based on making embedding queries against a generic JSON/REST embedding server.
     Rest(rest::Embedder),
+    /// An embedder composed of an embedder at search time and an embedder at indexing time.
+    Composite(composite::Embedder),
 }
 
 /// Configuration for an embedder.
@@ -607,6 +609,7 @@ pub enum EmbedderOptions {
     Ollama(ollama::EmbedderOptions),
     UserProvided(manual::EmbedderOptions),
     Rest(rest::EmbedderOptions),
+    Composite(composite::EmbedderOptions),
 }
 
 impl Default for EmbedderOptions {
@@ -648,6 +651,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.embed(&texts, deadline),
             Embedder::UserProvided(embedder) => embedder.embed(&texts),
             Embedder::Rest(embedder) => embedder.embed(texts, deadline),
+            Embedder::Composite(embedder) => embedder.search.embed(texts, deadline),
         }
     }
 
@@ -676,6 +680,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
             Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
             Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
+            Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads),
         }
     }
 
@@ -691,6 +696,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
             Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
             Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
+            Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads),
         }
     }
 
@@ -702,6 +708,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
             Embedder::UserProvided(_) => 100,
             Embedder::Rest(embedder) => embedder.chunk_count_hint(),
+            Embedder::Composite(embedder) => embedder.index.chunk_count_hint(),
         }
     }
 
@@ -713,6 +720,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
             Embedder::UserProvided(_) => 1,
             Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
+            Embedder::Composite(embedder) => embedder.index.prompt_count_in_chunk_hint(),
         }
     }
 
@@ -724,6 +732,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.dimensions(),
             Embedder::UserProvided(embedder) => embedder.dimensions(),
             Embedder::Rest(embedder) => embedder.dimensions(),
+            Embedder::Composite(embedder) => embedder.dimensions(),
         }
     }
 
@@ -735,6 +744,7 @@ impl Embedder {
             Embedder::Ollama(embedder) => embedder.distribution(),
             Embedder::UserProvided(embedder) => embedder.distribution(),
             Embedder::Rest(embedder) => embedder.distribution(),
+            Embedder::Composite(embedder) => embedder.distribution(),
         }
     }
 
@@ -745,6 +755,7 @@ impl Embedder {
             | Embedder::Ollama(_)
             | Embedder::Rest(_) => true,
             Embedder::UserProvided(_) => false,
+            Embedder::Composite(embedder) => embedder.index.uses_document_template(),
         }
     }
 }